1
15
16
17 package gmtls
18
19 import (
20 "crypto"
21 "crypto/ecdsa"
22 "crypto/rsa"
23 "crypto/x509"
24 "encoding/pem"
25 "errors"
26 "fmt"
27 "io/ioutil"
28 "net"
29 "strings"
30 "time"
31
32 "github.com/tjfoc/gmsm/sm2"
33 X "github.com/tjfoc/gmsm/x509"
34 )
35
36
37
38
39
40 func Server(conn net.Conn, config *Config) *Conn {
41 return &Conn{conn: conn, config: config}
42 }
43
44
45
46
47
48 func Client(conn net.Conn, config *Config) *Conn {
49 return &Conn{conn: conn, config: config, isClient: true}
50 }
51
52
53 type listener struct {
54 net.Listener
55 config *Config
56 }
57
58
59
60 func (l *listener) Accept() (net.Conn, error) {
61 c, err := l.Listener.Accept()
62 if err != nil {
63 return nil, err
64 }
65 return Server(c, l.config), nil
66 }
67
68
69
70
71
72 func NewListener(inner net.Listener, config *Config) net.Listener {
73 l := new(listener)
74 l.Listener = inner
75 l.config = config
76 return l
77 }
78
79
80
81
82
83 func Listen(network, laddr string, config *Config) (net.Listener, error) {
84 if config == nil || (len(config.Certificates) == 0 && config.GetCertificate == nil) {
85 return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
86 }
87 l, err := net.Listen(network, laddr)
88 if err != nil {
89 return nil, err
90 }
91 return NewListener(l, config), nil
92 }
93
// timeoutError is the error DialWithDialer reports when its timer expires
// before the TLS handshake completes. It provides Error, Timeout and
// Temporary, the method set of the net.Error interface.
type timeoutError struct{}

// Error implements the error interface.
func (timeoutError) Error() string {
	return "tls: DialWithDialer timed out"
}

// Timeout always reports true: this error only ever represents a timeout.
func (timeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (timeoutError) Temporary() bool {
	return true
}
99
100
101
102
103
104
105
106
107 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
108
109
110
111 timeout := dialer.Timeout
112
113 if !dialer.Deadline.IsZero() {
114
115 deadlineTimeout := dialer.Deadline.Sub(time.Now())
116 if timeout == 0 || deadlineTimeout < timeout {
117 timeout = deadlineTimeout
118 }
119 }
120
121 var errChannel chan error
122
123 if timeout != 0 {
124 errChannel = make(chan error, 2)
125 time.AfterFunc(timeout, func() {
126 errChannel <- timeoutError{}
127 })
128 }
129
130 rawConn, err := dialer.Dial(network, addr)
131 if err != nil {
132 return nil, err
133 }
134
135 colonPos := strings.LastIndex(addr, ":")
136 if colonPos == -1 {
137 colonPos = len(addr)
138 }
139 hostname := addr[:colonPos]
140
141 if config == nil {
142 config = defaultConfig()
143 }
144
145
146 if config.ServerName == "" {
147
148 c := config.Clone()
149 c.ServerName = hostname
150 config = c
151 }
152
153 conn := Client(rawConn, config)
154
155 if timeout == 0 {
156 err = conn.Handshake()
157 } else {
158 go func() {
159 errChannel <- conn.Handshake()
160 }()
161
162 err = <-errChannel
163 }
164
165 if err != nil {
166 rawConn.Close()
167 return nil, err
168 }
169
170 return conn, nil
171 }
172
173
174
175
176
177
178
179 func Dial(network, addr string, config *Config) (*Conn, error) {
180 return DialWithDialer(new(net.Dialer), network, addr, config)
181 }
182
183
184
185
186
187
188 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
189 certPEMBlock, err := ioutil.ReadFile(certFile)
190 if err != nil {
191 return Certificate{}, err
192 }
193 keyPEMBlock, err := ioutil.ReadFile(keyFile)
194 if err != nil {
195 return Certificate{}, err
196 }
197 return X509KeyPair(certPEMBlock, keyPEMBlock)
198 }
199
200
201
202
203 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
204 fail := func(err error) (Certificate, error) { return Certificate{}, err }
205
206 var cert Certificate
207 var skippedBlockTypes []string
208 for {
209 var certDERBlock *pem.Block
210 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
211 if certDERBlock == nil {
212 break
213 }
214 if certDERBlock.Type == "CERTIFICATE" {
215 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
216 } else {
217 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
218 }
219 }
220
221 if len(cert.Certificate) == 0 {
222 if len(skippedBlockTypes) == 0 {
223 return fail(errors.New("tls: failed to find any PEM data in certificate input"))
224 }
225 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
226 return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
227 }
228 return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
229 }
230
231 skippedBlockTypes = skippedBlockTypes[:0]
232 var keyDERBlock *pem.Block
233 for {
234 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
235 if keyDERBlock == nil {
236 if len(skippedBlockTypes) == 0 {
237 return fail(errors.New("tls: failed to find any PEM data in key input"))
238 }
239 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
240 return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
241 }
242 return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
243 }
244 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
245 break
246 }
247 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
248 }
249
250 var err error
251 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
252 if err != nil {
253 return fail(err)
254 }
255
256
257
258 x509Cert, err := X.ParseCertificate(cert.Certificate[0])
259 if err != nil {
260 return fail(err)
261 }
262
263 switch pub := x509Cert.PublicKey.(type) {
264 case *rsa.PublicKey:
265 priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
266 if !ok {
267 return fail(errors.New("tls: private key type does not match public key type"))
268 }
269 if pub.N.Cmp(priv.N) != 0 {
270 return fail(errors.New("tls: private key does not match public key"))
271 }
272 case *ecdsa.PublicKey:
273 pub, _ = x509Cert.PublicKey.(*ecdsa.PublicKey)
274 switch pub.Curve {
275 case sm2.P256Sm2():
276 priv, ok := cert.PrivateKey.(*sm2.PrivateKey)
277 if !ok {
278 return fail(errors.New("tls: sm2 private key type does not match public key type"))
279 }
280 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
281 return fail(errors.New("tls: sm2 private key does not match public key"))
282 }
283 default:
284 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
285 if !ok {
286 return fail(errors.New("tls: private key type does not match public key type"))
287 }
288 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
289 return fail(errors.New("tls: private key does not match public key"))
290 }
291 }
292 default:
293 return fail(errors.New("tls: unknown public key algorithm"))
294 }
295 return cert, nil
296 }
297
298 func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
299 if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
300 return key, nil
301 }
302 if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
303 switch key := key.(type) {
304 case *rsa.PrivateKey, *ecdsa.PrivateKey:
305 return key, nil
306 default:
307 return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
308 }
309 }
310 if key, err := X.ParsePKCS8UnecryptedPrivateKey(der); err == nil {
311 return key, nil
312 }
313 return nil, errors.New("tls: failed to parse private key")
314 }
315
View as plain text