1
2
3
4
5
6
7
8 package x509
9
10 import (
11 "fmt"
12 "os"
13 "testing"
14 )
15
// Fixture locations and the certificate common names TestEnvVars expects to
// load from them.
const (
	// testFile is a certificate file fixture; TestEnvVars expects the cert
	// loaded from it to have common name testFileCN.
	testFile   = "test-file.crt"
	testFileCN = "test-file"

	// testDir is a certificate directory fixture; TestEnvVars expects the
	// cert loaded from it to have common name testDirCN.
	testDir   = "testdata"
	testDirCN = "test-dir"

	// testMissing names a location from which TestEnvVars expects no
	// certificates to be loaded (presumably it does not exist on disk).
	testMissing = "missing"
)
23
24 func TestEnvVars(t *testing.T) {
25 testCases := []struct {
26 name string
27 fileEnv string
28 dirEnv string
29 files []string
30 dirs []string
31 cns []string
32 }{
33 {
34
35 name: "override-defaults",
36 fileEnv: testMissing,
37 dirEnv: testMissing,
38 files: []string{testFile},
39 dirs: []string{testDir},
40 cns: nil,
41 },
42 {
43
44 name: "file",
45 fileEnv: testFile,
46 dirEnv: "",
47 files: nil,
48 dirs: nil,
49 cns: []string{testFileCN},
50 },
51 {
52
53 name: "dir",
54 fileEnv: "",
55 dirEnv: testDir,
56 files: nil,
57 dirs: nil,
58 cns: []string{testDirCN},
59 },
60 {
61
62 name: "file+dir",
63 fileEnv: testFile,
64 dirEnv: testDir,
65 files: nil,
66 dirs: nil,
67 cns: []string{testFileCN, testDirCN},
68 },
69 {
70
71 name: "empty-fall-through",
72 fileEnv: "",
73 dirEnv: "",
74 files: []string{testFile},
75 dirs: []string{testDir},
76 cns: []string{testFileCN, testDirCN},
77 },
78 }
79
80
81 origCertFiles, origCertDirectories := certFiles, certDirectories
82 origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
83 defer func() {
84 certFiles = origCertFiles
85 certDirectories = origCertDirectories
86 os.Setenv(certFileEnv, origFile)
87 os.Setenv(certDirEnv, origDir)
88 }()
89
90 for _, tc := range testCases {
91 t.Run(tc.name, func(t *testing.T) {
92 if err := os.Setenv(certFileEnv, tc.fileEnv); err != nil {
93 t.Fatalf("setenv %q failed: %v", certFileEnv, err)
94 }
95 if err := os.Setenv(certDirEnv, tc.dirEnv); err != nil {
96 t.Fatalf("setenv %q failed: %v", certDirEnv, err)
97 }
98
99 certFiles, certDirectories = tc.files, tc.dirs
100
101 r, err := loadSystemRoots()
102 if err != nil {
103 t.Fatal("unexpected failure:", err)
104 }
105
106 if r == nil {
107 t.Fatal("nil roots")
108 }
109
110
111 for i, cn := range tc.cns {
112 if i >= len(r.certs) {
113 t.Errorf("missing cert %v @ %v", cn, i)
114 } else if r.certs[i].Subject.CommonName != cn {
115 fmt.Printf("%#v\n", r.certs[0].Subject)
116 t.Errorf("unexpected cert common name %q, want %q", r.certs[i].Subject.CommonName, cn)
117 }
118 }
119 if len(r.certs) > len(tc.cns) {
120 t.Errorf("got %v certs, which is more than %v wanted", len(r.certs), len(tc.cns))
121 }
122 })
123 }
124 }
125
View as plain text