1
2
3
4
5
6
7
8
9
10
11
12
13
14 package config
15
16 import (
17 "bytes"
18 "crypto/tls"
19 "encoding/json"
20 "fmt"
21 "os"
22 "path/filepath"
23 "reflect"
24 "strings"
25 "testing"
26
27 "gopkg.in/yaml.v2"
28 )
29
30
31 func LoadTLSConfig(filename string) (*tls.Config, error) {
32 content, err := os.ReadFile(filename)
33 if err != nil {
34 return nil, err
35 }
36 cfg := TLSConfig{}
37 switch filepath.Ext(filename) {
38 case ".yml":
39 if err = yaml.UnmarshalStrict(content, &cfg); err != nil {
40 return nil, err
41 }
42 case ".json":
43 decoder := json.NewDecoder(bytes.NewReader(content))
44 decoder.DisallowUnknownFields()
45 if err = decoder.Decode(&cfg); err != nil {
46 return nil, err
47 }
48 default:
49 return nil, fmt.Errorf("Unknown extension: %s", filepath.Ext(filename))
50 }
51 return NewTLSConfig(&cfg)
52 }
53
// expectedTLSConfigs pairs each valid fixture under testdata/ with the
// *tls.Config that LoadTLSConfig is expected to build from it. The YAML
// fixtures mirror the JSON ones and additionally cover combined min/max
// version settings.
var expectedTLSConfigs = []struct {
	filename string
	config   *tls.Config
}{
	// JSON fixtures.
	{filename: "tls_config.empty.good.json", config: &tls.Config{}},
	{filename: "tls_config.insecure.good.json", config: &tls.Config{InsecureSkipVerify: true}},
	{filename: "tls_config.tlsversion.good.json", config: &tls.Config{MinVersion: tls.VersionTLS11}},
	{filename: "tls_config.max_version.good.json", config: &tls.Config{MaxVersion: tls.VersionTLS12}},

	// YAML fixtures.
	{filename: "tls_config.empty.good.yml", config: &tls.Config{}},
	{filename: "tls_config.insecure.good.yml", config: &tls.Config{InsecureSkipVerify: true}},
	{filename: "tls_config.tlsversion.good.yml", config: &tls.Config{MinVersion: tls.VersionTLS11}},
	{filename: "tls_config.max_version.good.yml", config: &tls.Config{MaxVersion: tls.VersionTLS12}},
	{filename: "tls_config.max_and_min_version.good.yml", config: &tls.Config{MaxVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS11}},
	{filename: "tls_config.max_and_min_version_same.good.yml", config: &tls.Config{MaxVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12}},
}
99
100 func TestValidTLSConfig(t *testing.T) {
101 for _, cfg := range expectedTLSConfigs {
102 got, err := LoadTLSConfig("testdata/" + cfg.filename)
103 if err != nil {
104 t.Fatalf("Error parsing %s: %s", cfg.filename, err)
105 }
106
107 got.GetClientCertificate = nil
108 if !reflect.DeepEqual(got, cfg.config) {
109 t.Fatalf("%v: unexpected config result: \n\n%v\n expected\n\n%v", cfg.filename, got, cfg.config)
110 }
111 }
112 }
113
// invalidTLSConfigs lists fixtures under testdata/ that LoadTLSConfig must
// reject, each paired with a substring the returned error must contain.
var invalidTLSConfigs = []struct {
	filename, errMsg string
}{
	// max_version below min_version.
	{filename: "tls_config.max_and_min_version.bad.yml", errMsg: "tls_config.max_version must be greater than or equal to tls_config.min_version if both are specified"},
}
123
124 func TestInvalidTLSConfig(t *testing.T) {
125 for _, ee := range invalidTLSConfigs {
126 _, err := LoadTLSConfig("testdata/" + ee.filename)
127 if err == nil {
128 t.Error("Expected error with config but got none")
129 continue
130 }
131 if !strings.Contains(err.Error(), ee.errMsg) {
132 t.Errorf("Expected error for invalid HTTP client configuration to contain %q but got: %s", ee.errMsg, err)
133 }
134 }
135 }
136
137 func TestTLSVersionStringer(t *testing.T) {
138 if s := (TLSVersion)(tls.VersionTLS13); s.String() != "TLS13" {
139 t.Fatalf("tls.VersionTLS13 string should be TLS13, got %s", s.String())
140 }
141 }
142
143 func TestTLSVersionMarshalYAML(t *testing.T) {
144 tests := []struct {
145 input TLSVersion
146 expected string
147 err error
148 }{
149 {
150 input: TLSVersions["TLS13"],
151 expected: "TLS13\n",
152 err: nil,
153 },
154 {
155 input: TLSVersions["TLS10"],
156 expected: "TLS10\n",
157 err: nil,
158 },
159 {
160 input: TLSVersion(999),
161 expected: "",
162 err: fmt.Errorf("unknown TLS version: 999"),
163 },
164 }
165
166 for _, test := range tests {
167 t.Run(fmt.Sprintf("MarshalYAML(%d)", test.input), func(t *testing.T) {
168 actualBytes, err := yaml.Marshal(&test.input)
169 if err != nil {
170 if test.err == nil || err.Error() != test.err.Error() {
171 t.Fatalf("error %v, expected %v", err, test.err)
172 }
173 return
174 }
175 actual := string(actualBytes)
176 if actual != test.expected {
177 t.Fatalf("returned %s, expected %s", actual, test.expected)
178 }
179 })
180 }
181 }
182
183 func TestTLSVersionMarshalJSON(t *testing.T) {
184 tests := []struct {
185 input TLSVersion
186 expected string
187 err error
188 }{
189 {
190 input: TLSVersions["TLS13"],
191 expected: `"TLS13"`,
192 err: nil,
193 },
194 {
195 input: TLSVersions["TLS10"],
196 expected: `"TLS10"`,
197 err: nil,
198 },
199 {
200 input: TLSVersion(999),
201 expected: "",
202 err: fmt.Errorf("unknown TLS version: 999"),
203 },
204 }
205
206 for _, test := range tests {
207 t.Run(fmt.Sprintf("MarshalJSON(%d)", test.input), func(t *testing.T) {
208 actualBytes, err := json.Marshal(&test.input)
209 if err != nil {
210 if test.err == nil || !strings.HasSuffix(err.Error(), test.err.Error()) {
211 t.Fatalf("error %v, expected %v", err, test.err)
212 }
213 return
214 }
215 actual := string(actualBytes)
216 if actual != test.expected {
217 t.Fatalf("returned %s, expected %s", actual, test.expected)
218 }
219 })
220 }
221 }
222
View as plain text