...
1 package casbin
2
3 import (
4 "context"
5 "errors"
6
7 stdcasbin "github.com/casbin/casbin/v2"
8 "github.com/go-kit/kit/endpoint"
9 )
10
// contextKey is the unexported type used for every context key defined in
// this package. Using a distinct named type (rather than a bare string) keeps
// these keys from colliding with keys other packages store in the same
// context under identical string values.
type contextKey string

// Context keys read and written by the NewEnforcer middleware.
const (
	// CasbinModelContextKey is the key whose context value NewEnforcer passes
	// as the first (model) argument to stdcasbin.NewEnforcer.
	CasbinModelContextKey = contextKey("CasbinModel")

	// CasbinPolicyContextKey is the key whose context value NewEnforcer passes
	// as the second (policy) argument to stdcasbin.NewEnforcer.
	CasbinPolicyContextKey = contextKey("CasbinPolicy")

	// CasbinEnforcerContextKey is the key under which NewEnforcer stores the
	// casbin enforcer it built, in the context handed to the next endpoint.
	CasbinEnforcerContextKey = contextKey("CasbinEnforcer")
)
28
// Sentinel errors for the casbin middleware. Because each is a distinct
// value, callers can match them with == or errors.Is.
var (
	// ErrModelContextMissing signals that the request context holds no value
	// under CasbinModelContextKey.
	ErrModelContextMissing = errors.New("CasbinModel is required in context")

	// ErrPolicyContextMissing signals that the request context holds no value
	// under CasbinPolicyContextKey.
	ErrPolicyContextMissing = errors.New("CasbinPolicy is required in context")

	// ErrUnauthorized is what NewEnforcer returns when the casbin enforcer
	// denies the configured subject/object/action request.
	ErrUnauthorized = errors.New("Unauthorized Access")
)
42
43
44
45
46
47 func NewEnforcer(
48 subject string, object interface{}, action string,
49 ) endpoint.Middleware {
50 return func(next endpoint.Endpoint) endpoint.Endpoint {
51 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
52 casbinModel := ctx.Value(CasbinModelContextKey)
53 casbinPolicy := ctx.Value(CasbinPolicyContextKey)
54 enforcer, err := stdcasbin.NewEnforcer(casbinModel, casbinPolicy)
55 if err != nil {
56 return nil, err
57 }
58
59 ctx = context.WithValue(ctx, CasbinEnforcerContextKey, enforcer)
60 ok, err := enforcer.Enforce(subject, object, action)
61 if err != nil {
62 return nil, err
63 }
64 if !ok {
65 return nil, ErrUnauthorized
66 }
67
68 return next(ctx, request)
69 }
70 }
71 }
72
View as plain text