1 package database
2
3 import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "strings"
8 "time"
9
10 "edge-infra.dev/pkg/edge/iam/config"
11 "edge-infra.dev/pkg/edge/iam/pin"
12
13 "github.com/pkg/errors"
14
15 "golang.org/x/crypto/bcrypt"
16
17 iamErrors "edge-infra.dev/pkg/edge/iam/errors"
18 )
19
20 func (s *Store) SavePIN(ctx context.Context, userID string, pincode string) error {
21 key := keyFrom(KeyPrefixPIN, userID)
22
23 var doc *Doc
24 var err error
25 if doc, err = s.getDoc(ctx, key); err != nil {
26 return err
27 }
28
29 var pinData pin.Data
30 var previousPins []string
31
32
33
34 if doc != nil {
35 jsonErr := json.Unmarshal(doc.Value, &pinData)
36 if jsonErr != nil {
37 return errors.WithMessage(jsonErr, "invalid user pin schema detected")
38 }
39
40
41 previousPins = pinData.PreviousPins
42 previousPins = append(previousPins, pinData.Hash)
43
44 n := len(previousPins)
45
46 for i := n - 1; i >= n-int(config.PINHistoryLength()) && i >= 0; i-- {
47 hash := previousPins[i]
48 if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pincode)); err == nil {
49 return iamErrors.ErrPINPreviouslyUsed
50 }
51 }
52 }
53
54 hash, _ := bcrypt.GenerateFromPassword([]byte(pincode), config.BcryptCost())
55
56
57 if len(previousPins) > 5 {
58 previousPins = previousPins[1:]
59 }
60
61 userPIN := &pin.Data{
62 Subject: userID,
63 Hash: string(hash),
64 LastUpdated: time.Now().Unix(),
65 NumOfWrongAttempts: 0,
66 PreviousPins: previousPins,
67 }
68
69 payload, err := json.Marshal(userPIN)
70 if err != nil {
71 return errors.WithStack(err)
72 }
73
74 if err := s.updateDoc(ctx, key, payload, WithExpiration(config.GetPINTTL())); err != nil {
75 return errors.WithStack(err)
76 }
77
78 return nil
79 }
80
81 func (s *Store) LoginWithPIN(ctx context.Context, userID string, pincode string) (*pin.Data, error) {
82 key := keyFrom(KeyPrefixPIN, userID)
83
84 var doc *Doc
85 var err error
86 if doc, err = s.getDoc(ctx, key); err != nil {
87 return nil, err
88 }
89
90 if doc == nil {
91
92 fqn := ToFullyQualified(userID)
93 if doc, err = s.getDoc(ctx, fqn); err != nil {
94 return nil, err
95 }
96 if doc == nil {
97 return nil, iamErrors.ErrUserNotFound
98 }
99
100 if setErr := s.copyDoc(ctx, key, doc.Value, doc.Expiration); setErr != nil {
101
102 if !s.isOffline {
103 return nil, setErr
104 }
105 }
106
107
108 if err := s.deleteDoc(ctx, fqn); err != nil {
109 if !s.isOffline {
110 return nil, err
111 }
112 }
113 }
114
115 var userPIN pin.Data
116 res := doc.Value
117
118 jsonErr := json.Unmarshal(res, &userPIN)
119 if jsonErr != nil {
120 return nil, errors.WithMessage(jsonErr, "invalid user pin schema detected")
121 }
122
123
124 if strings.HasPrefix(userPIN.Subject, "acct:") {
125 return nil, iamErrors.ErrPINExpired
126 }
127
128
129 pinDuration := time.Since(time.Unix(userPIN.LastUpdated, 0))
130
131 if pinDuration > config.GetPINLifeSpan() {
132 return nil, iamErrors.ErrPINExpired
133 }
134
135
136 if userPIN.NumOfWrongAttempts > config.GetPINRetryThreshold()-1 {
137 return nil, iamErrors.ErrPINThresholdReached
138 }
139
140 compareErr := bcrypt.CompareHashAndPassword([]byte(userPIN.Hash), []byte(pincode))
141
142 if compareErr == nil {
143 if s.IsOffline() {
144
145 return &userPIN, nil
146 }
147
148
149 userPIN.NumOfWrongAttempts = 0
150
151 payload, marshalErr := json.Marshal(&userPIN)
152 if marshalErr != nil {
153 return nil, errors.WithStack(marshalErr)
154 }
155
156 if setErr := s.updateDoc(ctx, key, payload); setErr != nil {
157 return nil, setErr
158 }
159
160 return &userPIN, nil
161 }
162
163
164 if s.IsOffline() {
165
166 s.Log.Error(iamErrors.ErrIncorrectPIN, "offline detected. skipping the update of number of wrong login attempts ")
167 return nil, iamErrors.ErrIncorrectPIN
168 }
169
170
171 userPIN.NumOfWrongAttempts++
172 payload, marshalErr := json.Marshal(&userPIN)
173 if marshalErr != nil {
174 return nil, errors.WithStack(marshalErr)
175 }
176 if setErr := s.updateDoc(ctx, key, payload); setErr != nil {
177 return nil, setErr
178 }
179
180
181 if userPIN.NumOfWrongAttempts >= config.GetPINRetryThreshold() {
182 return nil, iamErrors.ErrPINThresholdReached
183 }
184
185 return nil, iamErrors.ErrIncorrectPIN
186 }
187
188
189 func ToFullyQualified(userID string) string {
190 fqn := userID
191 if !strings.Contains(userID, "pin:acct:") {
192 fqn = fmt.Sprintf("pin:acct:%v@%v", config.OrganizationName(), userID)
193 }
194
195 return fqn
196 }
197
View as plain text