...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package authdb
16
17 import (
18 "bytes"
19 "context"
20 "crypto/hmac"
21 "crypto/sha1"
22 "encoding/base64"
23 "encoding/json"
24 "errors"
25 "fmt"
26 "strconv"
27 "time"
28
29 "golang.org/x/crypto/pbkdf2"
30 )
31
32
// UserStore is the interface implemented by back ends that can look up
// users and check their credentials.
type UserStore interface {
	// Validate takes a username and password and returns the matching
	// UserContext, or a non-nil error.
	// NOTE(review): presumably a wrong password is reported through err;
	// the exact error values are defined by implementations — confirm.
	Validate(ctx context.Context, username, password string) (user *UserContext, err error)

	// UserCtx returns the UserContext for the named user without taking a
	// password.
	// NOTE(review): behaviour for an unknown username is defined by
	// implementations — confirm.
	UserCtx(ctx context.Context, username string) (user *UserContext, err error)
}
42
43
44
// Constants describing the PBKDF2 password scheme used by this package.
const (
	// PBKDF2KeyLength is the length, in bytes, of the key derived by
	// ValidatePBKDF2.
	PBKDF2KeyLength = 20

	// SchemePBKDF2 is the identifier of the PBKDF2 password scheme.
	SchemePBKDF2 = "pbkdf2"
)
49
50
51
// UserContext describes a user: the database the record belongs to, the
// user's name, and the roles assigned to the user.
//
// The struct tags give the JSON field names; note that *UserContext also
// defines MarshalJSON, which controls the encoded form.
type UserContext struct {
	// Database is encoded as "db" and omitted from JSON when empty.
	Database string `json:"db,omitempty"`
	// Name is the user name, encoded as "name".
	Name string `json:"name"`
	// Roles lists the user's roles, encoded as "roles".
	Roles []string `json:"roles"`

	// Salt is never included in JSON output (tag "-").
	// NOTE(review): presumably the per-user salt passed to
	// CreateAuthToken / ValidatePBKDF2 — confirm against callers.
	Salt string `json:"-"`
}
59
60
61 func ValidatePBKDF2(password, salt, derivedKey string, iterations int) bool {
62 hash := fmt.Sprintf("%x", pbkdf2.Key([]byte(password), []byte(salt), iterations, PBKDF2KeyLength, sha1.New))
63 return hash == derivedKey
64 }
65
66
67
// CreateAuthToken builds a signed session token for name.
//
// The token is the unpadded URL-safe base64 encoding of
// "<name>:<time as uppercase hex>:<mac>", where mac is the raw HMAC-SHA1 of
// "<name>:<time as uppercase hex>" keyed with secret+salt.
//
// It panics if secret or salt is empty.
func CreateAuthToken(name, salt, secret string, time int64) string {
	switch {
	case secret == "":
		panic("secret must be set")
	case salt == "":
		panic("salt must be set")
	}

	session := []byte(fmt.Sprintf("%s:%X", name, time))

	mac := hmac.New(sha1.New, []byte(secret+salt))
	_, _ = mac.Write(session) // hash.Hash.Write never returns an error.

	// Assemble "<session>:<mac>"; Sum appends the digest to its argument.
	token := append(session, ':')
	token = mac.Sum(token)
	return base64.RawURLEncoding.EncodeToString(token)
}
81
82
83 func (c *UserContext) MarshalJSON() ([]byte, error) {
84 roles := c.Roles
85 if roles == nil {
86 roles = []string{}
87 }
88 output := map[string]interface{}{
89 "roles": roles,
90 }
91 if c.Database != "" {
92 output["db"] = c.Database
93 }
94 if c.Name != "" {
95 output["name"] = c.Name
96 } else {
97 output["name"] = nil
98 }
99 return json.Marshal(output)
100 }
101
102
103
104
// DecodeAuthToken extracts the username and creation time from a token of
// the form produced by CreateAuthToken: unpadded URL-safe base64 of
// "<name>:<hex seconds>:<mac>".
//
// It only decodes the token; the MAC portion is not verified here.
//
// An error is returned when token is not valid base64, when the payload has
// fewer than three colon-separated parts, or when the timestamp is not a
// valid hexadecimal int64. On error, username and created are zero values.
func DecodeAuthToken(token string) (username string, created time.Time, err error) {
	raw, err := base64.RawURLEncoding.DecodeString(token)
	if err != nil {
		return "", time.Time{}, err
	}

	// Split into at most three fields so that colon bytes inside the
	// binary MAC do not produce extra parts.
	fields := bytes.SplitN(raw, []byte{':'}, 3)
	if len(fields) != 3 {
		return "", time.Time{}, errors.New("invalid payload")
	}

	name, stamp := string(fields[0]), string(fields[1])
	secs, err := strconv.ParseInt(stamp, 16, 64)
	if err != nil {
		return "", time.Time{}, fmt.Errorf("invalid timestamp '%s'", stamp)
	}
	return name, time.Unix(secs, 0), nil
}
121
View as plain text