...
1
21
22 package oauth2
23
24 import (
25 "context"
26
27 "github.com/ory/x/errorsx"
28
29 "github.com/pkg/errors"
30
31 "github.com/ory/fosite"
32 )
33
// TokenRevocationHandler handles token revocation requests; see its
// RevokeToken method.
type TokenRevocationHandler struct {
	// TokenRevocationStorage is used to look up refresh/access token sessions
	// by signature and to revoke refresh/access tokens by request ID.
	TokenRevocationStorage TokenRevocationStorage
	// RefreshTokenStrategy computes the signature under which a refresh
	// token's session is looked up in TokenRevocationStorage.
	RefreshTokenStrategy RefreshTokenStrategy
	// AccessTokenStrategy computes the signature under which an access
	// token's session is looked up in TokenRevocationStorage.
	AccessTokenStrategy AccessTokenStrategy
}
39
40
41
42 func (r *TokenRevocationHandler) RevokeToken(ctx context.Context, token string, tokenType fosite.TokenType, client fosite.Client) error {
43 discoveryFuncs := []func() (request fosite.Requester, err error){
44 func() (request fosite.Requester, err error) {
45
46 signature := r.RefreshTokenStrategy.RefreshTokenSignature(token)
47 return r.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
48 },
49 func() (request fosite.Requester, err error) {
50
51 signature := r.AccessTokenStrategy.AccessTokenSignature(token)
52 return r.TokenRevocationStorage.GetAccessTokenSession(ctx, signature, nil)
53 },
54 }
55
56
57 if tokenType == fosite.AccessToken {
58 discoveryFuncs[0], discoveryFuncs[1] = discoveryFuncs[1], discoveryFuncs[0]
59 }
60
61 var ar fosite.Requester
62 var err1, err2 error
63 if ar, err1 = discoveryFuncs[0](); err1 != nil {
64 ar, err2 = discoveryFuncs[1]()
65 }
66
67 if err2 != nil {
68 return storeErrorsToRevocationError(err1, err2)
69 }
70
71 if ar.GetClient().GetID() != client.GetID() {
72 return errorsx.WithStack(fosite.ErrUnauthorizedClient)
73 }
74
75 requestID := ar.GetID()
76 err1 = r.TokenRevocationStorage.RevokeRefreshToken(ctx, requestID)
77 err2 = r.TokenRevocationStorage.RevokeAccessToken(ctx, requestID)
78
79 return storeErrorsToRevocationError(err1, err2)
80 }
81
82 func storeErrorsToRevocationError(err1, err2 error) error {
83
84 if (errors.Is(err1, fosite.ErrNotFound) || err1 == nil) && (errors.Is(err2, fosite.ErrNotFound) || err2 == nil) {
85 return nil
86 }
87
88
89 return errorsx.WithStack(fosite.ErrTemporarilyUnavailable)
90 }
91
View as plain text