|  | // Copyright 2017 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | // +build dragonfly freebsd linux netbsd openbsd solaris | 
|  |  | 
|  | package x509 | 
|  |  | 
|  | import ( | 
|  | "bytes" | 
|  | "fmt" | 
|  | "os" | 
|  | "path/filepath" | 
|  | "reflect" | 
|  | "strings" | 
|  | "testing" | 
|  | ) | 
|  |  | 
// Fixtures used by TestEnvVars. Each *CN constant is the subject common name
// TestEnvVars expects for the certificate loaded from the matching location.
const (
	// testDir is a certificate directory; testDirCN is the common name of the
	// certificate TestEnvVars expects to load from it.
	testDir   = "testdata"
	testDirCN = "test-dir"

	// testFile is a certificate file; testFileCN is the common name of the
	// certificate TestEnvVars expects to load from it.
	testFile   = "test-file.crt"
	testFileCN = "test-file"

	// testMissing is used as an environment-variable value for which
	// TestEnvVars expects no certificates to be loaded.
	testMissing = "missing"
)
|  |  | 
|  | func TestEnvVars(t *testing.T) { | 
|  | testCases := []struct { | 
|  | name    string | 
|  | fileEnv string | 
|  | dirEnv  string | 
|  | files   []string | 
|  | dirs    []string | 
|  | cns     []string | 
|  | }{ | 
|  | { | 
|  | // Environment variables override the default locations preventing fall through. | 
|  | name:    "override-defaults", | 
|  | fileEnv: testMissing, | 
|  | dirEnv:  testMissing, | 
|  | files:   []string{testFile}, | 
|  | dirs:    []string{testDir}, | 
|  | cns:     nil, | 
|  | }, | 
|  | { | 
|  | // File environment overrides default file locations. | 
|  | name:    "file", | 
|  | fileEnv: testFile, | 
|  | dirEnv:  "", | 
|  | files:   nil, | 
|  | dirs:    nil, | 
|  | cns:     []string{testFileCN}, | 
|  | }, | 
|  | { | 
|  | // Directory environment overrides default directory locations. | 
|  | name:    "dir", | 
|  | fileEnv: "", | 
|  | dirEnv:  testDir, | 
|  | files:   nil, | 
|  | dirs:    nil, | 
|  | cns:     []string{testDirCN}, | 
|  | }, | 
|  | { | 
|  | // File & directory environment overrides both default locations. | 
|  | name:    "file+dir", | 
|  | fileEnv: testFile, | 
|  | dirEnv:  testDir, | 
|  | files:   nil, | 
|  | dirs:    nil, | 
|  | cns:     []string{testFileCN, testDirCN}, | 
|  | }, | 
|  | { | 
|  | // Environment variable empty / unset uses default locations. | 
|  | name:    "empty-fall-through", | 
|  | fileEnv: "", | 
|  | dirEnv:  "", | 
|  | files:   []string{testFile}, | 
|  | dirs:    []string{testDir}, | 
|  | cns:     []string{testFileCN, testDirCN}, | 
|  | }, | 
|  | } | 
|  |  | 
|  | // Save old settings so we can restore before the test ends. | 
|  | origCertFiles, origCertDirectories := certFiles, certDirectories | 
|  | origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv) | 
|  | defer func() { | 
|  | certFiles = origCertFiles | 
|  | certDirectories = origCertDirectories | 
|  | os.Setenv(certFileEnv, origFile) | 
|  | os.Setenv(certDirEnv, origDir) | 
|  | }() | 
|  |  | 
|  | for _, tc := range testCases { | 
|  | t.Run(tc.name, func(t *testing.T) { | 
|  | if err := os.Setenv(certFileEnv, tc.fileEnv); err != nil { | 
|  | t.Fatalf("setenv %q failed: %v", certFileEnv, err) | 
|  | } | 
|  | if err := os.Setenv(certDirEnv, tc.dirEnv); err != nil { | 
|  | t.Fatalf("setenv %q failed: %v", certDirEnv, err) | 
|  | } | 
|  |  | 
|  | certFiles, certDirectories = tc.files, tc.dirs | 
|  |  | 
|  | r, err := loadSystemRoots() | 
|  | if err != nil { | 
|  | t.Fatal("unexpected failure:", err) | 
|  | } | 
|  |  | 
|  | if r == nil { | 
|  | t.Fatal("nil roots") | 
|  | } | 
|  |  | 
|  | // Verify that the returned certs match, otherwise report where the mismatch is. | 
|  | for i, cn := range tc.cns { | 
|  | if i >= r.len() { | 
|  | t.Errorf("missing cert %v @ %v", cn, i) | 
|  | } else if r.mustCert(t, i).Subject.CommonName != cn { | 
|  | fmt.Printf("%#v\n", r.mustCert(t, 0).Subject) | 
|  | t.Errorf("unexpected cert common name %q, want %q", r.mustCert(t, i).Subject.CommonName, cn) | 
|  | } | 
|  | } | 
|  | if r.len() > len(tc.cns) { | 
|  | t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns)) | 
|  | } | 
|  | }) | 
|  | } | 
|  | } | 
|  |  | 
|  | // Ensure that "SSL_CERT_DIR" when used as the environment | 
|  | // variable delimited by colons, allows loadSystemRoots to | 
|  | // load all the roots from the respective directories. | 
|  | // See https://golang.org/issue/35325. | 
|  | func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) { | 
|  | origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv) | 
|  | origCertFiles := certFiles[:] | 
|  |  | 
|  | // To prevent any other certs from being loaded in | 
|  | // through "SSL_CERT_FILE" or from known "certFiles", | 
|  | // clear them all, and they'll be reverting on defer. | 
|  | certFiles = certFiles[:0] | 
|  | os.Setenv(certFileEnv, "") | 
|  |  | 
|  | defer func() { | 
|  | certFiles = origCertFiles[:] | 
|  | os.Setenv(certDirEnv, origDir) | 
|  | os.Setenv(certFileEnv, origFile) | 
|  | }() | 
|  |  | 
|  | tmpDir, err := os.MkdirTemp(os.TempDir(), "x509-issue35325") | 
|  | if err != nil { | 
|  | t.Fatalf("Failed to create temporary directory: %v", err) | 
|  | } | 
|  | defer os.RemoveAll(tmpDir) | 
|  |  | 
|  | rootPEMs := []string{ | 
|  | geoTrustRoot, | 
|  | googleLeaf, | 
|  | startComRoot, | 
|  | } | 
|  |  | 
|  | var certDirs []string | 
|  | for i, certPEM := range rootPEMs { | 
|  | certDir := filepath.Join(tmpDir, fmt.Sprintf("cert-%d", i)) | 
|  | if err := os.MkdirAll(certDir, 0755); err != nil { | 
|  | t.Fatalf("Failed to create certificate dir: %v", err) | 
|  | } | 
|  | certOutFile := filepath.Join(certDir, "cert.crt") | 
|  | if err := os.WriteFile(certOutFile, []byte(certPEM), 0655); err != nil { | 
|  | t.Fatalf("Failed to write certificate to file: %v", err) | 
|  | } | 
|  | certDirs = append(certDirs, certDir) | 
|  | } | 
|  |  | 
|  | // Sanity check: the number of certDirs should be equal to the number of roots. | 
|  | if g, w := len(certDirs), len(rootPEMs); g != w { | 
|  | t.Fatalf("Failed sanity check: len(certsDir)=%d is not equal to len(rootsPEMS)=%d", g, w) | 
|  | } | 
|  |  | 
|  | // Now finally concatenate them with a colon. | 
|  | colonConcatCertDirs := strings.Join(certDirs, ":") | 
|  | os.Setenv(certDirEnv, colonConcatCertDirs) | 
|  | gotPool, err := loadSystemRoots() | 
|  | if err != nil { | 
|  | t.Fatalf("Failed to load system roots: %v", err) | 
|  | } | 
|  | subjects := gotPool.Subjects() | 
|  | // We expect exactly len(rootPEMs) subjects back. | 
|  | if g, w := len(subjects), len(rootPEMs); g != w { | 
|  | t.Fatalf("Invalid number of subjects: got %d want %d", g, w) | 
|  | } | 
|  |  | 
|  | wantPool := NewCertPool() | 
|  | for _, certPEM := range rootPEMs { | 
|  | wantPool.AppendCertsFromPEM([]byte(certPEM)) | 
|  | } | 
|  | strCertPool := func(p *CertPool) string { | 
|  | return string(bytes.Join(p.Subjects(), []byte("\n"))) | 
|  | } | 
|  |  | 
|  | if !certPoolEqual(gotPool, wantPool) { | 
|  | g, w := strCertPool(gotPool), strCertPool(wantPool) | 
|  | t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w) | 
|  | } | 
|  | } | 
|  |  | 
|  | func TestReadUniqueDirectoryEntries(t *testing.T) { | 
|  | tmp := t.TempDir() | 
|  | temp := func(base string) string { return filepath.Join(tmp, base) } | 
|  | if f, err := os.Create(temp("file")); err != nil { | 
|  | t.Fatal(err) | 
|  | } else { | 
|  | f.Close() | 
|  | } | 
|  | if err := os.Symlink("target-in", temp("link-in")); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if err := os.Symlink("../target-out", temp("link-out")); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | got, err := readUniqueDirectoryEntries(tmp) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | gotNames := []string{} | 
|  | for _, fi := range got { | 
|  | gotNames = append(gotNames, fi.Name()) | 
|  | } | 
|  | wantNames := []string{"file", "link-out"} | 
|  | if !reflect.DeepEqual(gotNames, wantNames) { | 
|  | t.Errorf("got %q; want %q", gotNames, wantNames) | 
|  | } | 
|  | } |