...
1 package oauth2
2
3 import (
4 "errors"
5 "fmt"
6 "net/http"
7
8 "edge-infra.dev/pkg/edge/iam/apperror"
9 iamErrors "edge-infra.dev/pkg/edge/iam/errors"
10 "edge-infra.dev/pkg/edge/iam/log"
11 "edge-infra.dev/pkg/edge/iam/session"
12
13 "github.com/gin-gonic/gin"
14 "github.com/ory/fosite"
15 )
16
17 func (oauth2 *OAuth2) token(ctx *gin.Context) error {
18 log := log.Get(ctx.Request.Context())
19 newSession := session.NewSession("")
20
21 requester, err := oauth2.fosite.NewAccessRequest(ctx, ctx.Request, newSession)
22 if err != nil {
23 return oauth2.handleTokenError(ctx, requester, err)
24 }
25
26 responder, err := oauth2.fosite.NewAccessResponse(ctx, requester)
27 if err != nil {
28 log.Error(err, "failed to create access response")
29 oauth2.fosite.WriteAccessError(ctx.Writer, requester, err)
30 return nil
31 }
32
33 oauth2.fosite.WriteAccessResponse(ctx.Writer, requester, responder)
34 return nil
35 }
36
37 func (oauth2 *OAuth2) handleTokenError(ctx *gin.Context, requester fosite.AccessRequester, err error) error {
38 log := log.Get(ctx.Request.Context())
39 rfcErr, ok := err.(*fosite.RFC6749Error)
40 if ok {
41 if unWrapErr := rfcErr.Unwrap(); requiresLoginChallenge(unWrapErr) {
42 resultantSession, ok := requester.GetSession().(*session.DefaultSession)
43 if !ok {
44 return apperror.NewStatusError(
45 fmt.Errorf("request requires login. missing or invalid request session: %w", err),
46 http.StatusInternalServerError)
47 }
48
49 ctx.Header("WWW-Authenticate", fmt.Sprintf(`Login %s`, resultantSession.GetChallenge()))
50
51 return apperror.NewStatusError(
52 fmt.Errorf("request requires login. returning a login challenge: %w", unWrapErr),
53 http.StatusUnauthorized)
54 }
55 log.Error(err, fmt.Sprintf("failed to create access request, %v", rfcErr.Unwrap().Error()))
56 } else {
57 log.Error(err, "failed to create access request")
58 }
59 oauth2.fosite.WriteAccessError(ctx.Writer, requester, err)
60 return nil
61 }
62
63 func requiresLoginChallenge(err error) bool {
64 expectedErrors := []error{iamErrors.ErrLoginRequired, iamErrors.ErrExpiredBarcode, iamErrors.ErrUnrecognisedBarcode, iamErrors.ErrInvalidEBCUsage}
65 for _, expectedErr := range expectedErrors {
66 if errors.Is(err, expectedErr) {
67 return true
68 }
69 }
70 return false
71 }
72
View as plain text