...
1
15
16 package gmcredentials
17
18 import (
19 "errors"
20 "fmt"
21 "io/ioutil"
22 "net"
23 "strings"
24
25 "github.com/tjfoc/gmsm/gmtls"
26 "github.com/tjfoc/gmsm/x509"
27 "golang.org/x/net/context"
28 "google.golang.org/grpc/credentials"
29 )
30
// alpnProtoStr is the ALPN protocol list that NewTLS installs as NextProtos
// on every config it wraps. "h2" is the registered ALPN identifier for HTTP/2.
var alpnProtoStr = []string{"h2"}
35
36
37
// PerRPCCredentials describes credentials that supply request metadata on a
// per-call basis. NOTE(review): it appears to mirror grpc's
// credentials.PerRPCCredentials; nothing in this file references it.
type PerRPCCredentials interface {
	// RequireTransportSecurity reports whether these credentials must only
	// be used over a secure transport.
	RequireTransportSecurity() bool

	// GetRequestMetadata returns the key/value metadata for a request.
	// ctx scopes the call; uri lists the request target(s). NOTE(review):
	// exact uri semantics presumably follow grpc's interface — confirm.
	GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
}
52
53
54
// ProtocolInfo describes the protocol stack of a transport.
// NOTE(review): tlsCreds.Info in this file returns grpc's
// credentials.ProtocolInfo, not this type; this looks like a local mirror.
type ProtocolInfo struct {
	// ProtocolVersion is the version of the application protocol.
	ProtocolVersion string
	// SecurityProtocol and SecurityVersion name the security protocol in use
	// and its version (e.g. "tls" and "1.2").
	SecurityProtocol, SecurityVersion string
	// ServerName is the server name associated with the transport.
	ServerName string
}
65
66
// AuthInfo is implemented by the per-connection authentication information
// produced by a handshake; TLSInfo in this file satisfies it.
type AuthInfo interface {
	// AuthType names the authentication scheme (TLSInfo reports "tls").
	AuthType() string
}
70
// ErrConnDispatched signals that the raw connection has been handed off out
// of gRPC's control. NOTE(review): no code in this file returns it;
// presumably kept for parity with grpc's credentials package.
var ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
76
77
78
79 type TLSInfo struct {
80 State gmtls.ConnectionState
81 }
82
83
84 func (t TLSInfo) AuthType() string {
85 return "tls"
86 }
87
88
89 type tlsCreds struct {
90
91 config *gmtls.Config
92 }
93
94 func (c tlsCreds) Info() credentials.ProtocolInfo {
95 return credentials.ProtocolInfo{
96 SecurityProtocol: "tls",
97 SecurityVersion: "1.2",
98 ServerName: c.config.ServerName,
99 }
100 }
101
102 func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
103
104 cfg := cloneTLSConfig(c.config)
105 if cfg.ServerName == "" {
106 colonPos := strings.LastIndex(addr, ":")
107 if colonPos == -1 {
108 colonPos = len(addr)
109 }
110 cfg.ServerName = addr[:colonPos]
111 }
112 conn := gmtls.Client(rawConn, cfg)
113 errChannel := make(chan error, 1)
114 go func() {
115 errChannel <- conn.Handshake()
116 }()
117 select {
118 case err := <-errChannel:
119 if err != nil {
120 return nil, nil, err
121 }
122 case <-ctx.Done():
123 return nil, nil, ctx.Err()
124 }
125 return conn, TLSInfo{conn.ConnectionState()}, nil
126 }
127
128 func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
129 conn := gmtls.Server(rawConn, c.config)
130 if err := conn.Handshake(); err != nil {
131 return nil, nil, err
132 }
133 return conn, TLSInfo{conn.ConnectionState()}, nil
134 }
135
136 func (c *tlsCreds) Clone() credentials.TransportCredentials {
137 return NewTLS(c.config)
138 }
139
140 func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
141 c.config.ServerName = serverNameOverride
142 return nil
143 }
144
145
146 func NewTLS(c *gmtls.Config) credentials.TransportCredentials {
147 tc := &tlsCreds{cloneTLSConfig(c)}
148 tc.config.NextProtos = alpnProtoStr
149 return tc
150 }
151
152
153
154
155 func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) credentials.TransportCredentials {
156 return NewTLS(&gmtls.Config{GMSupport: &gmtls.GMSupport{}, ServerName: serverNameOverride, RootCAs: cp})
157 }
158
159
160
161
162 func NewClientTLSFromFile(certFile, serverNameOverride string) (credentials.TransportCredentials, error) {
163 b, err := ioutil.ReadFile(certFile)
164 if err != nil {
165 return nil, err
166 }
167 cp := x509.NewCertPool()
168 if !cp.AppendCertsFromPEM(b) {
169 return nil, fmt.Errorf("credentials: failed to append certificates")
170 }
171 return NewTLS(&gmtls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
172 }
173
174
175 func NewServerTLSFromCert(cert *gmtls.Certificate) credentials.TransportCredentials {
176 return NewTLS(&gmtls.Config{Certificates: []gmtls.Certificate{*cert}})
177 }
178
179
180
181 func NewServerTLSFromFile(certFile, keyFile string) (credentials.TransportCredentials, error) {
182 cert, err := gmtls.LoadX509KeyPair(certFile, keyFile)
183 if err != nil {
184 return nil, err
185 }
186 return NewTLS(&gmtls.Config{Certificates: []gmtls.Certificate{cert}}), nil
187 }
188
View as plain text