...
1
2
3
4 package grpc_auth
5
6 import (
7 "context"
8
9 "github.com/grpc-ecosystem/go-grpc-middleware"
10 "google.golang.org/grpc"
11 )
12
13
14
15
16
17
18
19
20
21
22
23
// AuthFunc is the pluggable per-call authentication function accepted by
// UnaryServerInterceptor and StreamServerInterceptor.
//
// It receives the context of the incoming call. When it returns a non-nil
// error, the interceptors return that error without invoking the handler.
// Otherwise the returned context is the one handed on to the handler, so
// values attached to it here are visible downstream.
type AuthFunc func(ctx context.Context) (context.Context, error)
25
26
27
28
29
// ServiceAuthFuncOverride lets a service implementation supply its own
// authentication logic. When the server object seen by the interceptors in
// this package satisfies this interface, they call AuthFuncOverride instead
// of the AuthFunc they were constructed with.
type ServiceAuthFuncOverride interface {
	// AuthFuncOverride authenticates a single call. ctx is the context of the
	// incoming call and fullMethodName is the FullMethod value taken from the
	// interceptor's call info. The result is treated exactly like the result
	// of an AuthFunc: a non-nil error rejects the call, otherwise the returned
	// context is passed to the handler.
	AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error)
}
33
34
35 func UnaryServerInterceptor(authFunc AuthFunc) grpc.UnaryServerInterceptor {
36 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
37 var newCtx context.Context
38 var err error
39 if overrideSrv, ok := info.Server.(ServiceAuthFuncOverride); ok {
40 newCtx, err = overrideSrv.AuthFuncOverride(ctx, info.FullMethod)
41 } else {
42 newCtx, err = authFunc(ctx)
43 }
44 if err != nil {
45 return nil, err
46 }
47 return handler(newCtx, req)
48 }
49 }
50
51
52 func StreamServerInterceptor(authFunc AuthFunc) grpc.StreamServerInterceptor {
53 return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
54 var newCtx context.Context
55 var err error
56 if overrideSrv, ok := srv.(ServiceAuthFuncOverride); ok {
57 newCtx, err = overrideSrv.AuthFuncOverride(stream.Context(), info.FullMethod)
58 } else {
59 newCtx, err = authFunc(stream.Context())
60 }
61 if err != nil {
62 return err
63 }
64 wrapped := grpc_middleware.WrapServerStream(stream)
65 wrapped.WrappedContext = newCtx
66 return handler(srv, wrapped)
67 }
68 }
69
View as plain text