...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package auth
16
17 import (
18 "crypto/ecdsa"
19 "crypto/rsa"
20 "fmt"
21 "io/ioutil"
22 "time"
23
24 "github.com/golang-jwt/jwt/v4"
25 )
26
// Keys recognised in the option map handed to jwtOptions.Parse.
const (
	optSignMethod = "sign-method" // name of the JWT signing algorithm
	optPublicKey  = "pub-key"     // path of a file holding the PEM-encoded public key
	optPrivateKey = "priv-key"    // path of a file holding the PEM private key (or the HMAC secret)
	optTTL        = "ttl"         // token lifetime, in time.ParseDuration syntax
)
33
34 var knownOptions = map[string]bool{
35 optSignMethod: true,
36 optPublicKey: true,
37 optPrivateKey: true,
38 optTTL: true,
39 }
40
// DefaultTTL is the token lifetime that ParseWithDefaults applies when the
// options carry no TTL and the option map has no non-empty "ttl" entry.
var DefaultTTL = 5 * time.Minute
45
46 type jwtOptions struct {
47 SignMethod jwt.SigningMethod
48 PublicKey []byte
49 PrivateKey []byte
50 TTL time.Duration
51 }
52
53
54 func (opts *jwtOptions) ParseWithDefaults(optMap map[string]string) error {
55 if opts.TTL == 0 && optMap[optTTL] == "" {
56 opts.TTL = DefaultTTL
57 }
58
59 return opts.Parse(optMap)
60 }
61
62
63 func (opts *jwtOptions) Parse(optMap map[string]string) error {
64 var err error
65 if ttl := optMap[optTTL]; ttl != "" {
66 opts.TTL, err = time.ParseDuration(ttl)
67 if err != nil {
68 return err
69 }
70 }
71
72 if file := optMap[optPublicKey]; file != "" {
73 opts.PublicKey, err = ioutil.ReadFile(file)
74 if err != nil {
75 return err
76 }
77 }
78
79 if file := optMap[optPrivateKey]; file != "" {
80 opts.PrivateKey, err = ioutil.ReadFile(file)
81 if err != nil {
82 return err
83 }
84 }
85
86
87 method := optMap[optSignMethod]
88 opts.SignMethod = jwt.GetSigningMethod(method)
89 if opts.SignMethod == nil {
90 return ErrInvalidAuthMethod
91 }
92
93 return nil
94 }
95
96
97 func (opts *jwtOptions) Key() (interface{}, error) {
98 switch opts.SignMethod.(type) {
99 case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS:
100 return opts.rsaKey()
101 case *jwt.SigningMethodECDSA:
102 return opts.ecKey()
103 case *jwt.SigningMethodHMAC:
104 return opts.hmacKey()
105 default:
106 return nil, fmt.Errorf("unsupported signing method: %T", opts.SignMethod)
107 }
108 }
109
110 func (opts *jwtOptions) hmacKey() (interface{}, error) {
111 if len(opts.PrivateKey) == 0 {
112 return nil, ErrMissingKey
113 }
114 return opts.PrivateKey, nil
115 }
116
117 func (opts *jwtOptions) rsaKey() (interface{}, error) {
118 var (
119 priv *rsa.PrivateKey
120 pub *rsa.PublicKey
121 err error
122 )
123
124 if len(opts.PrivateKey) > 0 {
125 priv, err = jwt.ParseRSAPrivateKeyFromPEM(opts.PrivateKey)
126 if err != nil {
127 return nil, err
128 }
129 }
130
131 if len(opts.PublicKey) > 0 {
132 pub, err = jwt.ParseRSAPublicKeyFromPEM(opts.PublicKey)
133 if err != nil {
134 return nil, err
135 }
136 }
137
138 if priv == nil {
139 if pub == nil {
140
141 return nil, ErrMissingKey
142 }
143
144 return pub, nil
145 }
146
147
148 if pub != nil && pub.E != priv.E && pub.N.Cmp(priv.N) != 0 {
149 return nil, ErrKeyMismatch
150 }
151
152 return priv, nil
153 }
154
155 func (opts *jwtOptions) ecKey() (interface{}, error) {
156 var (
157 priv *ecdsa.PrivateKey
158 pub *ecdsa.PublicKey
159 err error
160 )
161
162 if len(opts.PrivateKey) > 0 {
163 priv, err = jwt.ParseECPrivateKeyFromPEM(opts.PrivateKey)
164 if err != nil {
165 return nil, err
166 }
167 }
168
169 if len(opts.PublicKey) > 0 {
170 pub, err = jwt.ParseECPublicKeyFromPEM(opts.PublicKey)
171 if err != nil {
172 return nil, err
173 }
174 }
175
176 if priv == nil {
177 if pub == nil {
178
179 return nil, ErrMissingKey
180 }
181
182 return pub, nil
183 }
184
185
186 if pub != nil && pub.Curve != priv.Curve &&
187 pub.X.Cmp(priv.X) != 0 && pub.Y.Cmp(priv.Y) != 0 {
188 return nil, ErrKeyMismatch
189 }
190
191 return priv, nil
192 }
193
View as plain text