1 package device
2
3 import (
4 "encoding/json"
5 "fmt"
6 "io"
7 "net/http"
8 "strings"
9 "time"
10
11 keyfunc "github.com/MicahParks/keyfunc/v2"
12 "github.com/gin-gonic/gin"
13 "github.com/golang-jwt/jwt/v5"
14 "github.com/google/uuid"
15 "github.com/gorilla/sessions"
16 "golang.org/x/crypto/bcrypt"
17
18 "edge-infra.dev/pkg/edge/iam/apperror"
19 "edge-infra.dev/pkg/edge/iam/config"
20 "edge-infra.dev/pkg/edge/iam/log"
21 "edge-infra.dev/pkg/edge/iam/profile"
22 "edge-infra.dev/pkg/edge/iam/prometheus"
23 "edge-infra.dev/pkg/edge/iam/util"
24 )
25
26
27
// Label values handed to the prometheus metrics helpers by the device PIN
// handlers: reqPIN for IncHTTPRequestsTotal and signInPin for
// IncSignInRequestsTotal in login. signUpPin is presumably the sign-up
// counterpart — confirm against its callers.
var (
	reqPIN, signInPin, signUpPin = "pin", "sign_in_pin", "sign_up_pin"
)
33
34
// loginForm is the device-login request payload bound by login through gin's
// ShouldBind; the `form` tags give the submitted field names.
type loginForm struct {
	// Username is the device login name; it is stored lower-cased on the
	// device account and verbatim as the profile's DeviceLogin.
	Username string `form:"username"`
	// Password is the device password; only its bcrypt hash is persisted.
	Password string `form:"password"`
	// Reason is passed through to localLogin or stored in the session as
	// "reason" after a successful login.
	Reason string `form:"reason"`
}
40
// AuthMethod serves the device-login HTTP handlers that NewAuthMethod
// registers on a gin router.
type AuthMethod struct {
	// service authenticates device credentials (Login) and holds the cached
	// JWKS and key IDs used by ValidateToken.
	service *CloudService
	// storage persists device accounts (SaveDeviceAccount).
	storage Storage
	// profileStorage persists identity profiles and aliases. NewAuthMethod
	// fills it from the same object as storage.
	profileStorage profile.Storage
	// sessionStore provides the "oauth2" cookie session used by login.
	sessionStore sessions.Store
	// metrics receives request and sign-in counters.
	metrics *prometheus.Metrics
}
48
49 func NewAuthMethod(
50 router *gin.Engine,
51 service *CloudService,
52 sessionStore sessions.Store,
53 storage interface{},
54 metrics *prometheus.Metrics,
55 ) *AuthMethod {
56 am := &AuthMethod{
57 storage: storage.(Storage),
58 profileStorage: storage.(profile.Storage),
59 sessionStore: sessionStore,
60 metrics: metrics,
61 service: service,
62 }
63 router.POST("/idp/set/device", util.MakeHandlerFunc(am.selfService))
64 router.POST("/idp/login/device", util.MakeHandlerFunc(am.login))
65 if !config.IsProduction() {
66 router.Any("/disrupt/device", util.MakeHandlerFunc(am.setDisrupt))
67 }
68 return am
69 }
70
71 func (am *AuthMethod) login(c *gin.Context) error {
72 am.metrics.IncHTTPRequestsTotal(reqPIN)
73 logger := log.Get(c.Request.Context()).WithName("device-login")
74
75
76 var form loginForm
77 if err := c.ShouldBind(&form); err != nil {
78 am.metrics.IncSignInRequestsTotal(signInPin, util.Failed)
79 return apperror.NewAbortError(
80 fmt.Errorf("failed to bind login based on method and content-type: %w", err),
81 http.StatusBadRequest)
82 }
83
84 tokenSet, err := am.service.Login(form.Username, form.Password)
85
86 if err != nil && err == ErrDeviceLoginDenied {
87 return apperror.NewStatusError(err, http.StatusUnauthorized)
88 }
89
90 if err != nil && err == ErrDeviceLoginForbidden {
91 return apperror.NewStatusError(err, http.StatusForbidden)
92 }
93
94 var accessClaims, idClaims map[string]interface{}
95 if len(tokenSet.AccessToken) > 0 {
96 var validationErr error
97 accessClaims, validationErr = am.ValidateToken(tokenSet.AccessToken)
98 if validationErr != nil {
99 return apperror.NewStatusError(fmt.Errorf("invalid access token from site security: %w", validationErr), http.StatusInternalServerError)
100 }
101 }
102 if len(tokenSet.IDToken) > 0 {
103 var validationErr error
104 idClaims, validationErr = am.ValidateToken(tokenSet.IDToken)
105 if validationErr != nil {
106 return apperror.NewStatusError(fmt.Errorf("invalid id token from site security: %w", validationErr), http.StatusInternalServerError)
107 }
108 }
109
110
111 if err != nil {
112 logger.Error(err, "falling back to local authentication for device login")
113 return am.localLogin(c, form.Username, form.Password, form.Reason)
114 }
115
116 session, _ := am.sessionStore.Get(c.Request, "oauth2")
117 session.Values["device_token"] = tokenSet.AccessToken
118 if err = session.Save(c.Request, c.Writer); err != nil {
119 return apperror.NewAbortError(
120 fmt.Errorf("failed to save cookie session: %w", err),
121 http.StatusInternalServerError)
122 }
123
124 if tokenSet.Warnings.PasswordExpired {
125 return apperror.NewJSONError(err,
126 http.StatusUnauthorized,
127 "expired device password",
128 gin.H{"error": "expired_device_password"},
129 )
130 }
131 if tokenSet.Warnings.PasswordMustChange {
132 return apperror.NewJSONError(err,
133 http.StatusUnauthorized,
134 "device password must change",
135 gin.H{"error": "device_password_must_change"},
136 )
137 }
138 alias, _ := util.RandomStringGenerator(8)
139 existingProfile, err := am.profileStorage.GetIdentityProfile(c, accessClaims["sub"].(string))
140 if err != nil {
141 logger.Error(err, "error fetching identity profile, creating new alias for profile")
142 } else {
143 if existingProfile != nil && existingProfile.Alias != "" {
144 alias = existingProfile.Alias
145 }
146 }
147
148
149 userProfile := profile.Profile{
150 Subject: accessClaims["sub"].(string),
151 Organization: accessClaims["org"].(string),
152 Roles: accessClaims["rls"].(string),
153 GivenName: idClaims["given_name"].(string),
154 FamilyName: idClaims["family_name"].(string),
155 FullName: idClaims["name"].(string),
156 Email: idClaims["email"].(string),
157 DeviceLogin: form.Username,
158 Alias: alias,
159 }
160
161
162 tokenSet.IDToken = ""
163
164 setAgeInProfile(idClaims, &userProfile, time.Now, config.TimeZone())
165
166
167 if addressMap, exists := idClaims["address"]; exists {
168 addressJSON, _ := json.Marshal(addressMap)
169
170 addressClaim := profile.AddressClaim{}
171 err := json.Unmarshal(addressJSON, &addressClaim)
172 if err == nil {
173 userProfile.Address = &addressClaim
174 }
175 }
176
177
178 if err = am.profileStorage.CreateIdentityProfile(c, userProfile); err != nil && !am.profileStorage.IsOffline() {
179 return apperror.NewAbortError(
180 fmt.Errorf("failed to store the identity: %w", err),
181 http.StatusInternalServerError)
182 }
183
184 hash, err := bcrypt.GenerateFromPassword([]byte(form.Password), config.BcryptCost())
185 if err != nil {
186 return apperror.NewAbortError(fmt.Errorf("failed to hash password: %w", err), http.StatusInternalServerError)
187 }
188 account := Account{
189 TokenSet: tokenSet,
190 Username: strings.ToLower(form.Username),
191 Subject: userProfile.Subject,
192 Hash: string(hash),
193 LastUpdated: time.Now().Unix(),
194 NumOfWrongAttempts: 0,
195 }
196 if err = am.storage.SaveDeviceAccount(c, account); err != nil && !am.profileStorage.IsOffline() {
197 return apperror.NewAbortError(
198 fmt.Errorf("failed to save device account: %w", err),
199 http.StatusInternalServerError)
200 }
201
202 if err = am.profileStorage.CreateAlias(c, alias, accessClaims["sub"].(string)); err != nil {
203
204
205
206 logger.Error(err, "failed to create alias")
207 }
208
209 session.Values["method"] = "device"
210 session.Values["reason"] = form.Reason
211 continuation := uuid.New().String()
212 session.Values["continuation"] = continuation
213
214 am.setProfileOnSession(session, &userProfile)
215
216 if err = session.Save(c.Request, c.Writer); err != nil {
217 return apperror.NewAbortError(
218 fmt.Errorf("failed to save cookie session: %w", err),
219 http.StatusInternalServerError)
220 }
221
222 return util.WriteJSON(c.Writer, http.StatusOK, gin.H{
223 "challenge": continuation,
224 })
225 }
226
227 func setAgeInProfile(idClaims map[string]interface{}, userProfile *profile.Profile, nowFunc func() time.Time, timezone string) {
228
229 _, dobExists := idClaims["dob"]
230 _, birthDateExists := idClaims["birthdate"]
231 if dobExists {
232
233 if age, err := util.CalculateAge(idClaims["dob"].(string), nowFunc, timezone); err == nil {
234 userProfile.Age = age
235 }
236 } else if birthDateExists {
237
238 if age, err := util.CalculateAge(idClaims["birthdate"].(string), nowFunc, timezone); err == nil {
239 userProfile.Age = age
240 }
241 }
242 }
243
244 func (am *AuthMethod) jwks() ([]byte, error) {
245 client := &http.Client{}
246 req, err := http.NewRequest("GET", config.DeviceBaseURL()+"/jwks", nil)
247 if err != nil {
248 return nil, err
249 }
250 res, err := client.Do(req)
251 if err != nil {
252 return nil, err
253 }
254 defer res.Body.Close()
255
256 body, err := io.ReadAll(res.Body)
257 if err != nil {
258 return nil, err
259 }
260 am.service.jwks = body
261 var decodedJWKS map[string]interface{}
262 err = json.Unmarshal(am.service.jwks, &decodedJWKS)
263 if err != nil {
264 return nil, fmt.Errorf("invalid jwks")
265 }
266 am.service.keyIDs = nil
267 keys := decodedJWKS["keys"].([]interface{})
268 for i := 0; i < len(keys); i++ {
269 key := keys[i].(map[string]interface{})
270 keyID := key["kid"].(string)
271 am.service.keyIDs = append(am.service.keyIDs, keyID)
272 }
273 return body, nil
274 }
275
276 func (am *AuthMethod) ValidateToken(token string) (map[string]interface{}, error) {
277
278 decodedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
279 if err != nil {
280 return nil, err
281 }
282 header := decodedToken.Header
283
284 _, ok := header["kid"]
285 if !ok {
286 return nil, jwt.ErrTokenUnverifiable
287 }
288
289 keyIDInHeader, ok := header["kid"].(string)
290 if !ok {
291 return nil, jwt.ErrTokenUnverifiable
292 }
293 if !util.IsElementExist(am.service.keyIDs, keyIDInHeader) {
294
295 am.service.jwks, err = am.jwks()
296 if err != nil {
297 return nil, err
298 }
299 }
300
301 jwk, err := keyfunc.NewJSON(am.service.jwks)
302 if err != nil {
303 return nil, err
304 }
305 res, err := jwt.Parse(token, jwk.Keyfunc, jwt.WithLeeway(config.GetLeeWayForDeviceTokenValidation()))
306 if err != nil {
307 return nil, err
308 }
309
310 claims, ok := res.Claims.(jwt.MapClaims)
311 if !ok {
312
313 return nil, jwt.ErrTokenInvalidClaims
314 }
315 return claims, nil
316 }
317 func (*AuthMethod) setProfileOnSession(session *sessions.Session, userProfile *profile.Profile) {
318 session.Values["alias"] = userProfile.Alias
319 session.Values["sub"] = userProfile.Subject
320 session.Values["org"] = userProfile.Organization
321 session.Values["rls"] = userProfile.Roles
322 session.Values["gn"] = userProfile.GivenName
323 session.Values["fn"] = userProfile.FamilyName
324 session.Values["n"] = userProfile.FullName
325 session.Values["age"] = userProfile.Age
326 session.Values["device_login"] = userProfile.DeviceLogin
327 session.Values["email"] = userProfile.Email
328 if userProfile.Address != nil {
329 addressClaimJSON, err := json.Marshal(userProfile.Address)
330 if err == nil {
331 session.Values["address"] = string(addressClaimJSON)
332 }
333 }
334 }
335
View as plain text