package oauth2

import (
	"context"
	"html"

	"github.com/ory/fosite"
	"github.com/ory/fosite/compose"
	oauth "github.com/ory/fosite/handler/oauth2"
	"github.com/ory/fosite/token/hmac"
	"github.com/ory/x/errorsx"

	"edge-infra.dev/pkg/edge/iam/config"
	iamConfig "edge-infra.dev/pkg/edge/iam/config"
	"edge-infra.dev/pkg/edge/iam/device"
	"edge-infra.dev/pkg/edge/iam/errors"
	"edge-infra.dev/pkg/edge/iam/identity"
	"edge-infra.dev/pkg/edge/iam/profile"
	"edge-infra.dev/pkg/edge/iam/session"
	"edge-infra.dev/pkg/edge/iam/util"
)

// RefreshTokenVerifyHandler is a token endpoint handler for the refresh_token
// grant that re-checks the user behind the refresh token (device account or
// identity profile) and, when a fresh login is needed, stores a login hint and
// rejects the request.
type RefreshTokenVerifyHandler struct {
	LoginHintStrategy      *hmac.HMACStrategy
	LoginSessionStorage    session.LoginSessionStorage
	TokenRevocationStorage oauth.TokenRevocationStorage
	ProfileStorage         profile.Storage
	RefreshTokenStrategy   oauth.RefreshTokenStrategy
	deviceStorage          device.Storage
}

// RefreshTokenVerifyFactory builds a RefreshTokenVerifyHandler for fosite's
// compose package. The compose config is unused. storage must implement
// profile.Storage, session.LoginSessionStorage, oauth.TokenRevocationStorage
// and device.Storage, and strategy must implement oauth.RefreshTokenStrategy;
// the unchecked type assertions panic otherwise.
func RefreshTokenVerifyFactory(_ *compose.Config, storage interface{}, strategy interface{}) interface{} {
	return &RefreshTokenVerifyHandler{
		ProfileStorage:         storage.(profile.Storage),
		LoginHintStrategy:      iamConfig.HMACStrategy(),
		LoginSessionStorage:    storage.(session.LoginSessionStorage),
		TokenRevocationStorage: storage.(oauth.TokenRevocationStorage),
		RefreshTokenStrategy:   strategy.(oauth.RefreshTokenStrategy),
		deviceStorage:          storage.(device.Storage),
	}
}

// HandleTokenEndpointRequest re-validates the user behind a refresh_token
// grant. It loads the session stored for the presented refresh token, looks up
// the identity profile for that session's "sub" claim, and then either checks
// the user's device account (when device login is enabled) or checks whether
// the profile requires re-verification. A non-nil error rejects the request.
func (h *RefreshTokenVerifyHandler) HandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) error {
	s := session.FromRequester(requester)

	// The refresh token is needed to load the original request, which tells us
	// the subject the token was issued to.
	refreshToken := requester.GetRequestForm().Get("refresh_token")
	signature := h.RefreshTokenStrategy.RefreshTokenSignature(refreshToken)

	originalRequest, err := h.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, requester.GetSession())
	if err != nil || originalRequest == nil {
		// The base refresh token flow is expected to reject unknown or inactive
		// tokens before this handler runs; this guard keeps a different handler
		// order from turning a failed lookup into a nil-pointer panic below.
		return errorsx.WithStack(fosite.ErrInvalidGrant.
			WithWrap(err).
			WithDebug("refresh token session could not be loaded"))
	}

	// The subject won't exist in the current session; it must be retrieved from
	// the JWT claims associated with the refresh token session.
	jwtSession, ok := originalRequest.GetSession().(oauth.JWTSessionContainer)
	if !ok {
		return errorsx.WithStack(fosite.ErrServerError.
			WithDebug("refresh token session does not carry JWT claims"))
	}
	subject, ok := jwtSession.GetJWTClaims().ToMapClaims()["sub"].(string)
	if !ok {
		return errorsx.WithStack(fosite.ErrServerError.
			WithDebug("refresh token session has no subject"))
	}

	userProfile, err := h.ProfileStorage.GetIdentityProfile(ctx, subject)
	if err != nil {
		return errorsx.WithStack(fosite.ErrServerError.
			WithWrap(err).
			WithDebug("storage error"))
	}
	if userProfile == nil {
		return errorsx.WithStack(fosite.ErrServerError.
			WithDebug("invalid subject"))
	}

	clientID := requester.GetClient().GetID()
	if config.DeviceLoginEnabled() {
		// Treat a barcode auth without a valid cloud refresh token as an expired barcode.
		return h.verifyDeviceAuth(ctx, s, userProfile, clientID)
	}
	return h.verifyProfile(s, userProfile, clientID)
}

// PopulateTokenEndpointResponse is a no-op: this handler only validates the
// request and never adds anything to the token response.
func (h *RefreshTokenVerifyHandler) PopulateTokenEndpointResponse(_ context.Context, _ fosite.AccessRequester, _ fosite.AccessResponder) (err error) {
	return nil
}

// CanSkipClientAuth always returns false; client authentication is required.
func (h *RefreshTokenVerifyHandler) CanSkipClientAuth(_ fosite.AccessRequester) bool {
	return false
}

// CanHandleTokenEndpointRequest reports whether the request's grant types are
// exactly "refresh_token".
func (h *RefreshTokenVerifyHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool {
	return requester.GetGrantTypes().ExactOne("refresh_token")
}

// verifyDeviceAuth checks the device account linked to userProfile.DeviceLogin.
// A storage failure is reported as a server error. A missing account, or one
// whose refresh token fails hasValidRefreshToken, produces a login hint
// (reason identity.RefreshTokenInvalid, empty subject) and an unauthorized error.
func (h *RefreshTokenVerifyHandler) verifyDeviceAuth(ctx context.Context, s session.Session, userProfile *profile.Profile, clientID string) error {
	deviceLogin := userProfile.DeviceLogin
	acc, err := h.deviceStorage.GetDeviceAccount(ctx, deviceLogin)
	if err != nil {
		return errorsx.WithStack(fosite.ErrServerError.
			WithWrap(err).
			WithDebug("error fetching device account from storage"))
	}
	if acc == nil {
		return errorsx.WithStack(fosite.ErrRequestUnauthorized.
			WithWrap(h.generateLoginHint(s, clientID, errors.ErrLoginRequired, "", identity.RefreshTokenInvalid)).
			WithDebug("no such device account"))
	}

	// todo: inject into struct, may be via cloud svc struct, so we can use the same signature for BSL/device/okta etc.
	if err := hasValidRefreshToken(acc.RefreshToken); err != nil {
		return errorsx.WithStack(fosite.ErrRequestUnauthorized.
			WithWrap(h.generateLoginHint(s, clientID, errors.ErrLoginRequired, "", identity.RefreshTokenInvalid)))
	}
	return nil
}

// verifyProfile asks the profile whether it needs re-verification, passing the
// profile storage's offline state. An error from that check is reported as a
// server error; a positive answer produces a login hint for the profile's
// subject (reason identity.ProfileDataExpired) and an unauthorized error.
func (h *RefreshTokenVerifyHandler) verifyProfile(s session.Session, userProfile *profile.Profile, clientID string) error {
	verify, err := userProfile.RequireVerification(h.ProfileStorage.IsOffline())
	if err != nil {
		return errorsx.WithStack(fosite.ErrServerError.
			WithWrap(err).
			WithDebug("network error"))
	}
	if verify {
		return h.generateLoginHint(s, clientID, errors.ErrLoginRequired, userProfile.Subject, identity.ProfileDataExpired)
	}
	return nil
}

// generateLoginHint creates an HMAC challenge, records it on the session,
// and persists an active LoginSession (keyed by the challenge signature) with
// the given subject, reason and client ID. If displayErr is non-nil, its
// HTML-escaped message is stored as the login error message.
//
// It always returns a non-nil error: a server error if generating or storing
// the hint fails, otherwise fosite.ErrRequestUnauthorized wrapping displayErr.
func (h *RefreshTokenVerifyHandler) generateLoginHint(s session.Session, clientID string, displayErr error, subject string, reason int) error {
	challenge, signature, err := h.LoginHintStrategy.Generate()
	if err != nil {
		return errorsx.WithStack(fosite.ErrServerError.WithHint(err.Error()))
	}
	s.SetChallenge(challenge)

	loginSession := &session.LoginSession{
		Subject:  subject,
		Reason:   reason,
		Active:   true,
		ClientID: clientID,
	}
	if displayErr != nil {
		// Escaped because the message is presumably rendered in the login UI.
		loginSession.LoginOptions.ErrorMessage = html.EscapeString(displayErr.Error())
	}

	if err := h.LoginSessionStorage.SetLoginSession(signature, loginSession); err != nil {
		return errorsx.WithStack(fosite.ErrServerError.WithHint(err.Error()))
	}
	return errorsx.WithStack(fosite.ErrRequestUnauthorized.WithWrap(displayErr))
}

// hasValidRefreshToken validates a device account's refresh token. When
// util.IsDeviceLoginAvailable() reports true it delegates to
// device.ExchangeRefreshToken, otherwise to device.IsRefreshTokenValid.
// NOTE(review): the exact semantics (e.g. whether the first path contacts a
// remote service) live in package device and are not visible here.
func hasValidRefreshToken(token string) error {
	if util.IsDeviceLoginAvailable() {
		return device.ExchangeRefreshToken(token)
	}
	return device.IsRefreshTokenValid(token)
}