package oauth2
import (
"context"
"html"
"github.com/ory/fosite"
"github.com/ory/fosite/compose"
oauth "github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/token/hmac"
"github.com/ory/x/errorsx"
"edge-infra.dev/pkg/edge/iam/config"
iamConfig "edge-infra.dev/pkg/edge/iam/config"
"edge-infra.dev/pkg/edge/iam/device"
"edge-infra.dev/pkg/edge/iam/errors"
"edge-infra.dev/pkg/edge/iam/identity"
"edge-infra.dev/pkg/edge/iam/profile"
"edge-infra.dev/pkg/edge/iam/session"
"edge-infra.dev/pkg/edge/iam/util"
)
// RefreshTokenVerifyHandler is a fosite token endpoint handler for the refresh_token
// grant. Before new tokens are issued, it re-checks the identity profile (and, when
// device login is enabled, the device account) of the user the refresh token belongs to.
type RefreshTokenVerifyHandler struct {
	// LoginHintStrategy generates the challenge/signature pairs used for login hints.
	LoginHintStrategy *hmac.HMACStrategy
	// LoginSessionStorage stores login sessions keyed by the login hint signature.
	LoginSessionStorage session.LoginSessionStorage
	// TokenRevocationStorage is used to load the session stored with a refresh token.
	TokenRevocationStorage oauth.TokenRevocationStorage
	// ProfileStorage provides identity profiles and the offline flag used for re-verification.
	ProfileStorage profile.Storage
	// RefreshTokenStrategy derives the storage signature of a presented refresh token.
	RefreshTokenStrategy oauth.RefreshTokenStrategy
	// deviceStorage looks up device accounts; only consulted when device login is enabled.
	deviceStorage device.Storage
}
// RefreshTokenVerifyFactory builds a RefreshTokenVerifyHandler in the shape expected by
// fosite's compose package. The compose config is ignored.
//
// storage must implement profile.Storage, session.LoginSessionStorage,
// oauth.TokenRevocationStorage and device.Storage. strategy must implement
// oauth.RefreshTokenStrategy. A missing implementation makes the corresponding type
// assertion panic. The login hint strategy always comes from iamConfig.HMACStrategy().
func RefreshTokenVerifyFactory(_ *compose.Config, storage interface{}, strategy interface{}) interface{} {
	h := &RefreshTokenVerifyHandler{}
	// Keep the original evaluation order so any misconfiguration panics at the same point.
	h.ProfileStorage = storage.(profile.Storage)
	h.LoginHintStrategy = iamConfig.HMACStrategy()
	h.LoginSessionStorage = storage.(session.LoginSessionStorage)
	h.TokenRevocationStorage = storage.(oauth.TokenRevocationStorage)
	h.RefreshTokenStrategy = strategy.(oauth.RefreshTokenStrategy)
	h.deviceStorage = storage.(device.Storage)
	return h
}
// HandleTokenEndpointRequest runs for refresh_token grants and decides whether the user
// behind the presented refresh token may still receive new tokens.
//
// The subject is not available in the new session, so it is read from the JWT claims of
// the session stored with the original refresh token (see refreshTokenSubject). The
// matching identity profile is then checked:
//   - when device login is enabled, with verifyDeviceAuth;
//   - otherwise, with verifyProfile.
//
// Profile storage failures, or a subject without a profile, are reported as
// fosite.ErrServerError. A failed check may store a login hint and set its challenge on
// the request's session before an error is returned.
func (h *RefreshTokenVerifyHandler) HandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) error {
	s := session.FromRequester(requester)

	subject, err := h.refreshTokenSubject(ctx, requester)
	if err != nil {
		return err
	}

	userProfile, err := h.ProfileStorage.GetIdentityProfile(ctx, subject)
	if err != nil {
		return errorsx.WithStack(fosite.ErrServerError.
			WithWrap(err).
			WithDebug("storage error"))
	}
	if userProfile == nil {
		// err is nil here, so there is nothing to wrap.
		return errorsx.WithStack(fosite.ErrServerError.
			WithDebug("invalid subject"))
	}

	clientID := requester.GetClient().GetID()
	if config.DeviceLoginEnabled() {
		// Treat a barcode auth without a valid cloud refresh token as an expired barcode.
		return h.verifyDeviceAuth(ctx, s, userProfile, clientID)
	}
	return h.verifyProfile(s, userProfile, clientID)
}

// refreshTokenSubject loads the session stored with the presented refresh token and
// returns the "sub" claim of its JWT claims. It returns a fosite error instead of
// panicking when the stored request, its session type or the claim is unusable.
func (h *RefreshTokenVerifyHandler) refreshTokenSubject(ctx context.Context, requester fosite.AccessRequester) (string, error) {
	// The refresh token identifies the original request, which records the subject it was issued to.
	refreshToken := requester.GetRequestForm().Get("refresh_token")
	signature := h.RefreshTokenStrategy.RefreshTokenSignature(refreshToken)

	// A lookup error alone is not returned here: it is expected to be reported by the base
	// refresh token flow. Reused or inactive tokens may still come back with a request,
	// and that flow must be allowed to react to them. Without any stored request we cannot
	// learn the subject, so stop instead of dereferencing nil.
	originalRequest, err := h.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, requester.GetSession())
	if originalRequest == nil {
		return "", errorsx.WithStack(fosite.ErrInvalidGrant.
			WithWrap(err).
			WithDebug("unable to load the session of the refresh token"))
	}

	// The subject is not in the current session; it must be read from the JWT claims of
	// the session associated with the refresh token.
	jwtSession, ok := originalRequest.GetSession().(oauth.JWTSessionContainer)
	if !ok {
		return "", errorsx.WithStack(fosite.ErrServerError.
			WithDebug("refresh token session does not carry JWT claims"))
	}
	subject, ok := jwtSession.GetJWTClaims().ToMapClaims()["sub"].(string)
	if !ok || subject == "" {
		return "", errorsx.WithStack(fosite.ErrServerError.
			WithDebug("invalid subject"))
	}
	return subject, nil
}
// PopulateTokenEndpointResponse is part of the fosite token endpoint handler contract.
// This handler only validates requests in HandleTokenEndpointRequest, so it leaves the
// response untouched and always succeeds.
func (h *RefreshTokenVerifyHandler) PopulateTokenEndpointResponse(_ context.Context, _ fosite.AccessRequester, _ fosite.AccessResponder) (err error) {
	// Intentionally a no-op: nothing is added to the token response.
	return nil
}
// CanSkipClientAuth is part of the fosite token endpoint handler contract. It always
// returns false, so this handler never lets a request bypass client authentication.
func (h *RefreshTokenVerifyHandler) CanSkipClientAuth(_ fosite.AccessRequester) bool {
	const skipClientAuth = false
	return skipClientAuth
}
// CanHandleTokenEndpointRequest reports whether this handler applies to requester. It
// does so only when the requested grant types are exactly the single "refresh_token" grant.
func (h *RefreshTokenVerifyHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool {
	grants := requester.GetGrantTypes()
	return grants.ExactOne("refresh_token")
}
// verifyDeviceAuth checks that the device account linked to userProfile (looked up by its
// DeviceLogin) exists and holds a refresh token accepted by hasValidRefreshToken.
//
// A storage failure is returned as fosite.ErrServerError. A missing account or a rejected
// device refresh token triggers generateLoginHint (reason identity.RefreshTokenInvalid,
// no subject). The result is then returned wrapped in fosite.ErrRequestUnauthorized.
func (h *RefreshTokenVerifyHandler) verifyDeviceAuth(ctx context.Context, s session.Session, userProfile *profile.Profile, clientID string) error {
	acc, err := h.deviceStorage.GetDeviceAccount(ctx, userProfile.DeviceLogin)
	if err != nil {
		return errorsx.WithStack(fosite.ErrServerError.
			WithWrap(err).
			WithDebug("error fetching device account from storage"))
	}
	// NOTE(review): generateLoginHint already returns ErrRequestUnauthorized (or a server
	// error if the hint could not be stored), which is wrapped again below.
	if acc == nil {
		return errorsx.WithStack(fosite.ErrRequestUnauthorized.
			WithWrap(h.generateLoginHint(s, clientID, errors.ErrLoginRequired, "", identity.RefreshTokenInvalid)).
			WithDebug("no such device account"))
	}
	// todo: inject into struct, may be via cloud svc struct, so we can use the same signature for BSL/device/okta etc.
	if err := hasValidRefreshToken(acc.RefreshToken); err != nil {
		// Keep the reason the device refresh token was rejected; it was previously discarded.
		return errorsx.WithStack(fosite.ErrRequestUnauthorized.
			WithWrap(h.generateLoginHint(s, clientID, errors.ErrLoginRequired, "", identity.RefreshTokenInvalid)).
			WithDebug("device refresh token rejected: " + err.Error()))
	}
	return nil
}
// verifyProfile asks userProfile whether it must be re-verified. The answer is given
// ProfileStorage's offline flag.
//
// If that check fails, a fosite.ErrServerError is returned. If re-verification is
// required, the error from generateLoginHint is returned, using the profile's subject and
// reason identity.ProfileDataExpired. Otherwise it returns nil.
func (h *RefreshTokenVerifyHandler) verifyProfile(s session.Session, userProfile *profile.Profile, clientID string) error {
	offline := h.ProfileStorage.IsOffline()
	mustReverify, checkErr := userProfile.RequireVerification(offline)
	switch {
	case checkErr != nil:
		return errorsx.WithStack(fosite.ErrServerError.
			WithWrap(checkErr).
			WithDebug("network error"))
	case mustReverify:
		return h.generateLoginHint(s, clientID, errors.ErrLoginRequired, userProfile.Subject, identity.ProfileDataExpired)
	default:
		return nil
	}
}
// generateLoginHint records a pending login for clientID and sets its challenge on s.
//
// It generates a challenge/signature pair with LoginHintStrategy and stores an active
// session.LoginSession (subject, reason, clientID) under the signature. Only after that
// store succeeds does it set the challenge on s. When displayErr is non-nil, its message
// is HTML-escaped into the login session's error message.
//
// When the hint is recorded, it returns fosite.ErrRequestUnauthorized wrapping
// displayErr, so the caller rejects the request. If generation or storage fails, it
// returns fosite.ErrServerError. The underlying error goes into the wrap/debug fields,
// never into the client-facing hint.
func (h *RefreshTokenVerifyHandler) generateLoginHint(s session.Session, clientID string, displayErr error, subject string, reason int) error {
	challenge, signature, err := h.LoginHintStrategy.Generate()
	if err != nil {
		return errorsx.WithStack(fosite.ErrServerError.
			WithWrap(err).
			WithDebug(err.Error()))
	}

	loginSession := &session.LoginSession{
		Subject:  subject,
		Reason:   reason,
		Active:   true,
		ClientID: clientID,
	}
	if displayErr != nil {
		loginSession.LoginOptions.ErrorMessage = html.EscapeString(displayErr.Error())
	}

	if err := h.LoginSessionStorage.SetLoginSession(signature, loginSession); err != nil {
		return errorsx.WithStack(fosite.ErrServerError.
			WithWrap(err).
			WithDebug(err.Error()))
	}
	// Expose the challenge only once the login session it refers to has been stored.
	s.SetChallenge(challenge)

	return errorsx.WithStack(fosite.ErrRequestUnauthorized.WithWrap(displayErr))
}
// hasValidRefreshToken checks a device account's refresh token.
//
// When util.IsDeviceLoginAvailable() reports false, it returns the result of
// device.IsRefreshTokenValid. Otherwise it returns the result of
// device.ExchangeRefreshToken. A nil result means the token was accepted.
func hasValidRefreshToken(token string) error {
	if !util.IsDeviceLoginAvailable() {
		return device.IsRefreshTokenValid(token)
	}
	return device.ExchangeRefreshToken(token)
}