1 package oauth2
2
import (
	"context"
	stderrors "errors"
	"html"

	"github.com/ory/fosite"
	"github.com/ory/fosite/compose"
	oauth "github.com/ory/fosite/handler/oauth2"
	"github.com/ory/fosite/token/hmac"
	"github.com/ory/x/errorsx"

	"edge-infra.dev/pkg/edge/iam/config"
	iamConfig "edge-infra.dev/pkg/edge/iam/config"
	"edge-infra.dev/pkg/edge/iam/device"
	"edge-infra.dev/pkg/edge/iam/errors"
	"edge-infra.dev/pkg/edge/iam/identity"
	"edge-infra.dev/pkg/edge/iam/profile"
	"edge-infra.dev/pkg/edge/iam/session"
	"edge-infra.dev/pkg/edge/iam/util"
)
22
// RefreshTokenVerifyHandler is a fosite token endpoint handler for
// refresh_token grants. While the token request is handled it re-checks the
// subject stored with the presented refresh token. Depending on
// config.DeviceLoginEnabled it checks either the subject's device account or
// whether the subject's profile needs re-verification. When a check fails it
// records a login hint for the client.
type RefreshTokenVerifyHandler struct {
	// LoginHintStrategy generates the challenge/signature pair for a login hint.
	LoginHintStrategy *hmac.HMACStrategy
	// LoginSessionStorage stores the login session under the hint signature.
	LoginSessionStorage session.LoginSessionStorage
	// TokenRevocationStorage loads the request stored for a refresh token signature.
	TokenRevocationStorage oauth.TokenRevocationStorage
	// ProfileStorage resolves the token subject to an identity profile and
	// reports whether the profile store is offline.
	ProfileStorage profile.Storage
	// RefreshTokenStrategy derives the storage signature from the raw refresh token.
	RefreshTokenStrategy oauth.RefreshTokenStrategy
	// deviceStorage loads device accounts; it is used only when device login is enabled.
	deviceStorage device.Storage
}
31
32 func RefreshTokenVerifyFactory(_ *compose.Config, storage interface{}, strategy interface{}) interface{} {
33 return &RefreshTokenVerifyHandler{
34 ProfileStorage: storage.(profile.Storage),
35 LoginHintStrategy: iamConfig.HMACStrategy(),
36 LoginSessionStorage: storage.(session.LoginSessionStorage),
37 TokenRevocationStorage: storage.(oauth.TokenRevocationStorage),
38 RefreshTokenStrategy: strategy.(oauth.RefreshTokenStrategy),
39 deviceStorage: storage.(device.Storage),
40 }
41 }
42
43 func (h *RefreshTokenVerifyHandler) HandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) error {
44 s := session.FromRequester(requester)
45
46 refreshToken := requester.GetRequestForm().Get("refresh_token")
47 signature := h.RefreshTokenStrategy.RefreshTokenSignature(refreshToken)
48
49 originalRequest, _ := h.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, requester.GetSession())
50
51 jwtSession := originalRequest.GetSession().(oauth.JWTSessionContainer)
52
53 userProfile, err := h.ProfileStorage.GetIdentityProfile(ctx, jwtSession.GetJWTClaims().ToMapClaims()["sub"].(string))
54 if err != nil {
55 return errorsx.WithStack(fosite.ErrServerError.
56 WithWrap(err).
57 WithDebug("storage error"))
58 }
59 if userProfile == nil {
60 return errorsx.WithStack(fosite.ErrServerError.
61 WithWrap(err).
62 WithDebug("invalid subject"))
63 }
64 clientID := requester.GetClient().GetID()
65 if config.DeviceLoginEnabled() {
66 if err := h.verifyDeviceAuth(ctx, s, userProfile, clientID); err != nil {
67
68 return err
69 }
70 } else {
71 if err := h.verifyProfile(s, userProfile, clientID); err != nil {
72 return err
73 }
74 }
75 return nil
76 }
77
78 func (h *RefreshTokenVerifyHandler) PopulateTokenEndpointResponse(_ context.Context, _ fosite.AccessRequester, _ fosite.AccessResponder) (err error) {
79 return nil
80 }
81
82 func (h *RefreshTokenVerifyHandler) CanSkipClientAuth(_ fosite.AccessRequester) bool {
83 return false
84 }
85
86 func (h *RefreshTokenVerifyHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool {
87 return requester.GetGrantTypes().ExactOne("refresh_token")
88 }
89
90 func (h *RefreshTokenVerifyHandler) verifyDeviceAuth(ctx context.Context, s session.Session, userProfile *profile.Profile, clientID string) error {
91 deviceLogin := userProfile.DeviceLogin
92 acc, err := h.deviceStorage.GetDeviceAccount(ctx, deviceLogin)
93 if err != nil {
94 return errorsx.WithStack(fosite.ErrServerError.
95 WithWrap(err).
96 WithDebug("error fetching device account from storage"))
97 }
98 if acc == nil {
99 return errorsx.WithStack(fosite.ErrRequestUnauthorized.
100 WithWrap(h.generateLoginHint(s, clientID, errors.ErrLoginRequired, "", identity.RefreshTokenInvalid)).
101 WithDebug("no such device account"))
102 }
103
104 err = hasValidRefreshToken(acc.RefreshToken)
105 if err != nil {
106 return errorsx.WithStack(fosite.ErrRequestUnauthorized.
107 WithWrap(h.generateLoginHint(s, clientID, errors.ErrLoginRequired, "", identity.RefreshTokenInvalid)))
108 }
109 return nil
110 }
111
112 func (h *RefreshTokenVerifyHandler) verifyProfile(s session.Session, userProfile *profile.Profile, clientID string) error {
113 verify, err := userProfile.RequireVerification(h.ProfileStorage.IsOffline())
114 if err != nil {
115 return errorsx.WithStack(fosite.ErrServerError.
116 WithWrap(err).
117 WithDebug("network error"))
118 }
119
120 if verify {
121 return h.generateLoginHint(s, clientID, errors.ErrLoginRequired, userProfile.Subject, identity.ProfileDataExpired)
122 }
123 return nil
124 }
125
126 func (h *RefreshTokenVerifyHandler) generateLoginHint(s session.Session, clientID string, displayErr error, subject string, reason int) error {
127 challenge, signature, err := h.LoginHintStrategy.Generate()
128 if err != nil {
129 return errorsx.WithStack(fosite.ErrServerError.WithHint(err.Error()))
130 }
131
132 s.SetChallenge(challenge)
133
134 loginSession := &session.LoginSession{
135 Subject: subject,
136 Reason: reason,
137 Active: true,
138 ClientID: clientID,
139 }
140
141 if displayErr != nil {
142 loginSession.LoginOptions.ErrorMessage = html.EscapeString(displayErr.Error())
143 }
144
145 err = h.LoginSessionStorage.SetLoginSession(signature, loginSession)
146
147 if err != nil {
148 return errorsx.WithStack(fosite.ErrServerError.WithHint(err.Error()))
149 }
150
151 return errorsx.WithStack(fosite.ErrRequestUnauthorized.WithWrap(displayErr))
152 }
153
154 func hasValidRefreshToken(token string) error {
155 if util.IsDeviceLoginAvailable() {
156 return device.ExchangeRefreshToken(token)
157 }
158 return device.IsRefreshTokenValid(token)
159 }
160
View as plain text