1
2
3
4
5
6
7 package tlsconfig
8
9 import (
10 "crypto/tls"
11 "crypto/x509"
12 "encoding/pem"
13 "errors"
14 "fmt"
15 "os"
16 )
17
18
// Options represents the information needed to create client and server TLS
// configurations with Client and Server.
type Options struct {
	// CAFile is the path to a PEM file of CA certificates. Client uses it to
	// build RootCAs (unless InsecureSkipVerify is set); Server uses it to build
	// ClientCAs when ClientAuth verifies client certificates. Empty means no
	// custom CA pool is built.
	CAFile string

	// CertFile and KeyFile are the paths to the PEM-encoded certificate and
	// private key. Client skips loading a certificate only when both are
	// empty. Server always loads both, so both are required there.
	CertFile string
	KeyFile  string

	// InsecureSkipVerify is a client-only option copied into the client
	// config; when it is set, Client ignores CAFile.
	InsecureSkipVerify bool

	// ClientAuth is a server-only option copied into the server config.
	ClientAuth tls.ClientAuthType

	// ExclusiveRootPools only matters when CAFile is set: if true, the pool
	// contains exclusively the certificates from CAFile; otherwise they are
	// added on top of the system certificate pool.
	ExclusiveRootPools bool

	// MinVersion, when non-zero, overrides the minimum TLS version. It must be
	// a known TLS version and may not be lower than the default minimum
	// (TLS 1.2) set by ClientDefault/ServerDefault.
	MinVersion uint16

	// Passphrase is used by Client to decrypt an encrypted PEM private key.
	// Server does not read it.
	//
	// NOTE(review): this relies on legacy PEM encryption (RFC 1423), which
	// crypto/x509 deprecates as insecure by design.
	Passphrase string
}
46
47
48 var acceptedCBCCiphers = []uint16{
49 tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
50 tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
51 tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
52 tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
53 }
54
55
56
57
58 var DefaultServerAcceptedCiphers = append(clientCipherSuites, acceptedCBCCiphers...)
59
60
61 func ServerDefault(ops ...func(*tls.Config)) *tls.Config {
62 tlsConfig := &tls.Config{
63
64 MinVersion: tls.VersionTLS12,
65 PreferServerCipherSuites: true,
66 CipherSuites: DefaultServerAcceptedCiphers,
67 }
68
69 for _, op := range ops {
70 op(tlsConfig)
71 }
72
73 return tlsConfig
74 }
75
76
77 func ClientDefault(ops ...func(*tls.Config)) *tls.Config {
78 tlsConfig := &tls.Config{
79
80 MinVersion: tls.VersionTLS12,
81 CipherSuites: clientCipherSuites,
82 }
83
84 for _, op := range ops {
85 op(tlsConfig)
86 }
87
88 return tlsConfig
89 }
90
91
92 func certPool(caFile string, exclusivePool bool) (*x509.CertPool, error) {
93
94 var (
95 certPool *x509.CertPool
96 err error
97 )
98 if exclusivePool {
99 certPool = x509.NewCertPool()
100 } else {
101 certPool, err = SystemCertPool()
102 if err != nil {
103 return nil, fmt.Errorf("failed to read system certificates: %v", err)
104 }
105 }
106 pemData, err := os.ReadFile(caFile)
107 if err != nil {
108 return nil, fmt.Errorf("could not read CA certificate %q: %v", caFile, err)
109 }
110 if !certPool.AppendCertsFromPEM(pemData) {
111 return nil, fmt.Errorf("failed to append certificates from PEM file: %q", caFile)
112 }
113 return certPool, nil
114 }
115
116
117
// allTLSVersions is the set of TLS protocol versions that may be requested as
// a minimum version through Options.MinVersion.
var allTLSVersions = map[uint16]struct{}{
	tls.VersionTLS10: {},
	tls.VersionTLS11: {},
	tls.VersionTLS12: {},
	tls.VersionTLS13: {},
}

// isValidMinVersion reports whether version appears in allTLSVersions.
func isValidMinVersion(version uint16) bool {
	if _, known := allTLSVersions[version]; known {
		return true
	}
	return false
}
130
131
132
133 func adjustMinVersion(options Options, config *tls.Config) error {
134 if options.MinVersion > 0 {
135 if !isValidMinVersion(options.MinVersion) {
136 return fmt.Errorf("invalid minimum TLS version: %x", options.MinVersion)
137 }
138 if options.MinVersion < config.MinVersion {
139 return fmt.Errorf("requested minimum TLS version is too low. Should be at-least: %x", config.MinVersion)
140 }
141 config.MinVersion = options.MinVersion
142 }
143
144 return nil
145 }
146
147
148
149
150
151
152
153
// IsErrEncryptedKey reports whether err is, or wraps,
// x509.IncorrectPasswordError. That is the error getPrivateKey propagates
// (wrapped) when x509.DecryptPEMBlock reports a wrong passphrase for an
// encrypted private key.
//
// Note that crypto/x509 cannot always detect an incorrect password, so a
// false result does not prove the passphrase was right.
func IsErrEncryptedKey(err error) bool {
	if err == nil {
		return false
	}
	return errors.Is(err, x509.IncorrectPasswordError)
}
157
158
159
160
// getPrivateKey returns a PEM-encoded private key that tls.X509KeyPair can
// consume.
//
// Only the first PEM block in keyBytes is inspected. If that block is not
// encrypted, keyBytes is returned unchanged. If it is encrypted (legacy
// RFC 1423 PEM encryption, deprecated in crypto/x509), it is decrypted with
// passphrase and re-encoded as a new PEM block of the same type; PEM headers
// are not carried over. Decryption errors are wrapped with %w, so
// IsErrEncryptedKey can recognise a detected wrong passphrase.
func getPrivateKey(keyBytes []byte, passphrase string) ([]byte, error) {
	block, _ := pem.Decode(keyBytes)
	if block == nil {
		return nil, fmt.Errorf("no valid private key found")
	}

	if !x509.IsEncryptedPEMBlock(block) {
		return keyBytes, nil
	}

	der, decErr := x509.DecryptPEMBlock(block, []byte(passphrase))
	if decErr != nil {
		return nil, fmt.Errorf("private key is encrypted, but could not decrypt it: %w", decErr)
	}
	return pem.EncodeToMemory(&pem.Block{Type: block.Type, Bytes: der}), nil
}
179
180
181
182
183 func getCert(options Options) ([]tls.Certificate, error) {
184 if options.CertFile == "" && options.KeyFile == "" {
185 return nil, nil
186 }
187
188 cert, err := os.ReadFile(options.CertFile)
189 if err != nil {
190 return nil, err
191 }
192
193 prKeyBytes, err := os.ReadFile(options.KeyFile)
194 if err != nil {
195 return nil, err
196 }
197
198 prKeyBytes, err = getPrivateKey(prKeyBytes, options.Passphrase)
199 if err != nil {
200 return nil, err
201 }
202
203 tlsCert, err := tls.X509KeyPair(cert, prKeyBytes)
204 if err != nil {
205 return nil, err
206 }
207
208 return []tls.Certificate{tlsCert}, nil
209 }
210
211
212 func Client(options Options) (*tls.Config, error) {
213 tlsConfig := ClientDefault()
214 tlsConfig.InsecureSkipVerify = options.InsecureSkipVerify
215 if !options.InsecureSkipVerify && options.CAFile != "" {
216 CAs, err := certPool(options.CAFile, options.ExclusiveRootPools)
217 if err != nil {
218 return nil, err
219 }
220 tlsConfig.RootCAs = CAs
221 }
222
223 tlsCerts, err := getCert(options)
224 if err != nil {
225 return nil, fmt.Errorf("could not load X509 key pair: %w", err)
226 }
227 tlsConfig.Certificates = tlsCerts
228
229 if err := adjustMinVersion(options, tlsConfig); err != nil {
230 return nil, err
231 }
232
233 return tlsConfig, nil
234 }
235
236
237 func Server(options Options) (*tls.Config, error) {
238 tlsConfig := ServerDefault()
239 tlsConfig.ClientAuth = options.ClientAuth
240 tlsCert, err := tls.LoadX509KeyPair(options.CertFile, options.KeyFile)
241 if err != nil {
242 if os.IsNotExist(err) {
243 return nil, fmt.Errorf("could not load X509 key pair (cert: %q, key: %q): %v", options.CertFile, options.KeyFile, err)
244 }
245 return nil, fmt.Errorf("error reading X509 key pair - make sure the key is not encrypted (cert: %q, key: %q): %v", options.CertFile, options.KeyFile, err)
246 }
247 tlsConfig.Certificates = []tls.Certificate{tlsCert}
248 if options.ClientAuth >= tls.VerifyClientCertIfGiven && options.CAFile != "" {
249 CAs, err := certPool(options.CAFile, options.ExclusiveRootPools)
250 if err != nil {
251 return nil, err
252 }
253 tlsConfig.ClientCAs = CAs
254 }
255
256 if err := adjustMinVersion(options, tlsConfig); err != nil {
257 return nil, err
258 }
259
260 return tlsConfig, nil
261 }
262
View as plain text