blob: 2ea69e252a2403f4263b690610720745d56804ed [file] [log] [blame]
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
package x509
import (
"bytes"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
// Fixture names shared by the tests in this file.
const (
	// testFile is the certificate file path the tests supply either through
	// the file environment variable or as a default file location;
	// testFileCN is the common name expected on the certificate loaded from it.
	testFile   = "test-file.crt"
	testFileCN = "test-file"

	// testDir is the certificate directory the tests supply either through
	// the directory environment variable or as a default directory location;
	// testDirCN is the common name expected on the certificate loaded from it.
	testDir   = "testdata"
	testDirCN = "test-dir"

	// testMissing is used as an environment override in the case that
	// expects no certificates to be loaded at all.
	testMissing = "missing"
)
// TestEnvVars checks how the certFileEnv and certDirEnv environment
// variables interact with the package-level certFiles and certDirectories
// defaults when loadSystemRoots runs.
//
// Each case sets both environment variables and both default lists, calls
// loadSystemRoots, and compares the common names of the returned
// certificates, in order, against the case's cns list.
func TestEnvVars(t *testing.T) {
	testCases := []struct {
		name    string
		fileEnv string   // value for certFileEnv
		dirEnv  string   // value for certDirEnv
		files   []string // value for certFiles
		dirs    []string // value for certDirectories
		cns     []string // expected common names, in order
	}{
		{
			// Environment variables override the default locations preventing fall through.
			name:    "override-defaults",
			fileEnv: testMissing,
			dirEnv:  testMissing,
			files:   []string{testFile},
			dirs:    []string{testDir},
			cns:     nil,
		},
		{
			// File environment overrides default file locations.
			name:    "file",
			fileEnv: testFile,
			dirEnv:  "",
			files:   nil,
			dirs:    nil,
			cns:     []string{testFileCN},
		},
		{
			// Directory environment overrides default directory locations.
			name:    "dir",
			fileEnv: "",
			dirEnv:  testDir,
			files:   nil,
			dirs:    nil,
			cns:     []string{testDirCN},
		},
		{
			// File & directory environment overrides both default locations.
			name:    "file+dir",
			fileEnv: testFile,
			dirEnv:  testDir,
			files:   nil,
			dirs:    nil,
			cns:     []string{testFileCN, testDirCN},
		},
		{
			// Environment variable empty / unset uses default locations.
			name:    "empty-fall-through",
			fileEnv: "",
			dirEnv:  "",
			files:   []string{testFile},
			dirs:    []string{testDir},
			cns:     []string{testFileCN, testDirCN},
		},
	}

	// Save the package-level defaults so we can restore them before the
	// test ends. The environment variables are handled by t.Setenv below.
	origCertFiles, origCertDirectories := certFiles, certDirectories
	defer func() {
		certFiles = origCertFiles
		certDirectories = origCertDirectories
	}()

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// t.Setenv fails the test if the variable cannot be set and
			// restores the previous state (including "unset") when the
			// subtest finishes.
			t.Setenv(certFileEnv, tc.fileEnv)
			t.Setenv(certDirEnv, tc.dirEnv)

			certFiles, certDirectories = tc.files, tc.dirs

			r, err := loadSystemRoots()
			if err != nil {
				t.Fatal("unexpected failure:", err)
			}
			if r == nil {
				t.Fatal("nil roots")
			}

			// Verify that the returned certs match, otherwise report where the mismatch is.
			for i, cn := range tc.cns {
				if i >= r.len() {
					t.Errorf("missing cert %v @ %v", cn, i)
					continue
				}
				subject := r.mustCert(t, i).Subject
				if subject.CommonName != cn {
					// Log the full subject of the mismatching certificate
					// through the test log rather than stdout.
					t.Logf("%#v", subject)
					t.Errorf("unexpected cert common name %q, want %q", subject.CommonName, cn)
				}
			}
			if r.len() > len(tc.cns) {
				t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns))
			}
		})
	}
}
// Ensure that "SSL_CERT_DIR" when used as the environment
// variable delimited by colons, allows loadSystemRoots to
// load all the roots from the respective directories.
// See https://golang.org/issue/35325.
func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) {
origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
origCertFiles := certFiles[:]
// To prevent any other certs from being loaded in
// through "SSL_CERT_FILE" or from known "certFiles",
// clear them all, and they'll be reverting on defer.
certFiles = certFiles[:0]
os.Setenv(certFileEnv, "")
defer func() {
certFiles = origCertFiles[:]
os.Setenv(certDirEnv, origDir)
os.Setenv(certFileEnv, origFile)
}()
tmpDir := t.TempDir()
rootPEMs := []string{
gtsRoot,
googleLeaf,
startComRoot,
}
var certDirs []string
for i, certPEM := range rootPEMs {
certDir := filepath.Join(tmpDir, fmt.Sprintf("cert-%d", i))
if err := os.MkdirAll(certDir, 0755); err != nil {
t.Fatalf("Failed to create certificate dir: %v", err)
}
certOutFile := filepath.Join(certDir, "cert.crt")
if err := os.WriteFile(certOutFile, []byte(certPEM), 0655); err != nil {
t.Fatalf("Failed to write certificate to file: %v", err)
}
certDirs = append(certDirs, certDir)
}
// Sanity check: the number of certDirs should be equal to the number of roots.
if g, w := len(certDirs), len(rootPEMs); g != w {
t.Fatalf("Failed sanity check: len(certsDir)=%d is not equal to len(rootsPEMS)=%d", g, w)
}
// Now finally concatenate them with a colon.
colonConcatCertDirs := strings.Join(certDirs, ":")
os.Setenv(certDirEnv, colonConcatCertDirs)
gotPool, err := loadSystemRoots()
if err != nil {
t.Fatalf("Failed to load system roots: %v", err)
}
subjects := gotPool.Subjects()
// We expect exactly len(rootPEMs) subjects back.
if g, w := len(subjects), len(rootPEMs); g != w {
t.Fatalf("Invalid number of subjects: got %d want %d", g, w)
}
wantPool := NewCertPool()
for _, certPEM := range rootPEMs {
wantPool.AppendCertsFromPEM([]byte(certPEM))
}
strCertPool := func(p *CertPool) string {
return string(bytes.Join(p.Subjects(), []byte("\n")))
}
if !certPoolEqual(gotPool, wantPool) {
g, w := strCertPool(gotPool), strCertPool(wantPool)
t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
}
}
// TestReadUniqueDirectoryEntries populates a temporary directory with a
// regular file and two symlinks, then checks which entry names
// readUniqueDirectoryEntries reports.
//
// "link-in" has a target with no directory component and "link-out" has a
// target in the parent directory. The expected result keeps "file" and
// "link-out" and omits "link-in".
func TestReadUniqueDirectoryEntries(t *testing.T) {
	dir := t.TempDir()
	pathIn := func(name string) string { return filepath.Join(dir, name) }

	// A plain, empty regular file.
	f, err := os.Create(pathIn("file"))
	if err != nil {
		t.Fatal(err)
	}
	f.Close()

	// The symlink targets are never created; only the link text matters here.
	symlinks := []struct{ target, name string }{
		{target: "target-in", name: "link-in"},
		{target: "../target-out", name: "link-out"},
	}
	for _, l := range symlinks {
		if err := os.Symlink(l.target, pathIn(l.name)); err != nil {
			t.Fatal(err)
		}
	}

	entries, err := readUniqueDirectoryEntries(dir)
	if err != nil {
		t.Fatal(err)
	}

	names := make([]string, 0, len(entries))
	for _, entry := range entries {
		names = append(names, entry.Name())
	}

	expected := []string{"file", "link-out"}
	if !slices.Equal(names, expected) {
		t.Errorf("got %q; want %q", names, expected)
	}
}