1 package jwt_test
2
3
4
5
6 import (
7 "bytes"
8 "crypto/rsa"
9 "fmt"
10 "io"
11 "log"
12 "net"
13 "net/http"
14 "net/url"
15 "os"
16 "strings"
17 "time"
18
19 "github.com/golang-jwt/jwt/v4"
20 "github.com/golang-jwt/jwt/v4/request"
21 )
22
23
// privKeyPath is the PEM-encoded RSA private key used to sign example tokens.
const privKeyPath = "test/sample_key"

// pubKeyPath is the PEM-encoded RSA public key used to verify example tokens.
const pubKeyPath = "test/sample_key.pub"
28
// verifyKey is the RSA public key that init loads from pubKeyPath; it is
// handed out by the keyfuncs that verify token signatures.
var verifyKey *rsa.PublicKey

// signKey is the RSA private key that init loads from privKeyPath; it signs
// every token issued by createToken.
var signKey *rsa.PrivateKey

// serverPort is the TCP port the example HTTP server was bound to by init.
var serverPort int
34
35
36 func init() {
37 signBytes, err := os.ReadFile(privKeyPath)
38 fatal(err)
39
40 signKey, err = jwt.ParseRSAPrivateKeyFromPEM(signBytes)
41 fatal(err)
42
43 verifyBytes, err := os.ReadFile(pubKeyPath)
44 fatal(err)
45
46 verifyKey, err = jwt.ParseRSAPublicKeyFromPEM(verifyBytes)
47 fatal(err)
48
49 http.HandleFunc("/authenticate", authHandler)
50 http.HandleFunc("/restricted", restrictedHandler)
51
52
53 listener, err := net.ListenTCP("tcp", &net.TCPAddr{})
54 fatal(err)
55 serverPort = listener.Addr().(*net.TCPAddr).Port
56
57 log.Println("Listening...")
58 go func() {
59 fatal(http.Serve(listener, nil))
60 }()
61 }
62
// fatal terminates the process through log.Fatal when err is non-nil and
// does nothing otherwise. It keeps the setup code and examples free of
// repetitive error branches.
func fatal(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
68
69
// CustomerInfo is the example-specific customer payload embedded in the
// token claims: the customer's name and what kind of customer it is.
type CustomerInfo struct {
	Name, Kind string
}
74
// CustomClaimsExample is the claim set carried by the example tokens. It
// embeds jwt.RegisteredClaims for the standard registered fields (the
// examples only set ExpiresAt) alongside example-specific data.
// Field order is significant to any positional (unkeyed) composite literal
// of this type, so do not reorder the fields.
type CustomClaimsExample struct {
	jwt.RegisteredClaims
	// TokenType is an application-level label; createToken sets "level1".
	TokenType string
	// CustomerInfo is embedded, so its Name and Kind are promoted fields.
	CustomerInfo
}
80
81 func Example_getTokenViaHTTP() {
82
83 res, err := http.PostForm(fmt.Sprintf("http://localhost:%v/authenticate", serverPort), url.Values{
84 "user": {"test"},
85 "pass": {"known"},
86 })
87 if err != nil {
88 fatal(err)
89 }
90
91 if res.StatusCode != 200 {
92 fmt.Println("Unexpected status code", res.StatusCode)
93 }
94
95
96 buf := new(bytes.Buffer)
97 io.Copy(buf, res.Body)
98 res.Body.Close()
99 tokenString := strings.TrimSpace(buf.String())
100
101
102 token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token) (interface{}, error) {
103
104
105 return verifyKey, nil
106 })
107 fatal(err)
108
109 claims := token.Claims.(*CustomClaimsExample)
110 fmt.Println(claims.CustomerInfo.Name)
111
112
113 }
114
115 func Example_useTokenViaHTTP() {
116
117
118
119
120 token, err := createToken("foo")
121 fatal(err)
122
123
124 req, err := http.NewRequest("GET", fmt.Sprintf("http://localhost:%v/restricted", serverPort), nil)
125 fatal(err)
126 req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token))
127 res, err := http.DefaultClient.Do(req)
128 fatal(err)
129
130
131 buf := new(bytes.Buffer)
132 io.Copy(buf, res.Body)
133 res.Body.Close()
134 fmt.Println(buf.String())
135
136
137 }
138
139 func createToken(user string) (string, error) {
140
141 t := jwt.New(jwt.GetSigningMethod("RS256"))
142
143
144 t.Claims = &CustomClaimsExample{
145 jwt.RegisteredClaims{
146
147
148 ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
149 },
150 "level1",
151 CustomerInfo{user, "human"},
152 }
153
154
155 return t.SignedString(signKey)
156 }
157
158
159 func authHandler(w http.ResponseWriter, r *http.Request) {
160
161 if r.Method != "POST" {
162 w.WriteHeader(http.StatusBadRequest)
163 fmt.Fprintln(w, "No POST", r.Method)
164 return
165 }
166
167 user := r.FormValue("user")
168 pass := r.FormValue("pass")
169
170 log.Printf("Authenticate: user[%s] pass[%s]\n", user, pass)
171
172
173 if user != "test" || pass != "known" {
174 w.WriteHeader(http.StatusForbidden)
175 fmt.Fprintln(w, "Wrong info")
176 return
177 }
178
179 tokenString, err := createToken(user)
180 if err != nil {
181 w.WriteHeader(http.StatusInternalServerError)
182 fmt.Fprintln(w, "Sorry, error while Signing Token!")
183 log.Printf("Token Signing error: %v\n", err)
184 return
185 }
186
187 w.Header().Set("Content-Type", "application/jwt")
188 w.WriteHeader(http.StatusOK)
189 fmt.Fprintln(w, tokenString)
190 }
191
192
193 func restrictedHandler(w http.ResponseWriter, r *http.Request) {
194
195 token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) {
196
197
198 return verifyKey, nil
199 }, request.WithClaims(&CustomClaimsExample{}))
200
201
202 if err != nil {
203 w.WriteHeader(http.StatusUnauthorized)
204 fmt.Fprintln(w, "Invalid token:", err)
205 return
206 }
207
208
209 fmt.Fprintln(w, "Welcome,", token.Claims.(*CustomClaimsExample).Name)
210 }
211
View as plain text