1 package interceptor
2
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"net/http"
	"regexp"
	"slices"

	"github.com/99designs/gqlgen/graphql"
	"github.com/gin-contrib/sessions"
	"github.com/vektah/gqlparser/v2/formatter"

	"edge-infra.dev/pkg/edge/api/graphqlhelpers"
	"edge-infra.dev/pkg/edge/audit"
	"edge-infra.dev/pkg/edge/auth-proxy/types"
	"edge-infra.dev/pkg/edge/bsl"
)
22
23 type Factory struct {
24 operations map[string]Operation
25 auditor *audit.Sink
26 session sessions.Session
27 correlationID string
28 ip string
29 }
30
31 type Operation struct {
32 name string
33 identifier op
34 hdlr func(*http.Request, []byte) (*http.Request, []byte, error)
35 matcher string
36 cb bool
37 }
38
// RequestQuery is the JSON envelope of a GraphQL HTTP request body.
type RequestQuery struct {
	// Query is the raw GraphQL document text.
	Query string `json:"query"`
	// Variables holds the request's variable values keyed by variable name.
	Variables map[string]any `json:"variables"`
}
43
44 var (
45 ErrInterceptorNotImplemented = errors.New("interceptor not implemented")
46 )
47
48 func New(ip, correlationID string, sess sessions.Session) *Factory {
49 return &Factory{
50 operations: make(map[string]Operation, 0),
51 auditor: audit.New("auth-proxy"),
52 session: sess,
53 correlationID: correlationID,
54 ip: ip,
55 }
56 }
57
58 func (f *Factory) Query(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), matcher string, cb bool) {
59 f.operations[op] = Operation{
60 name: op,
61 identifier: Query,
62 hdlr: hdlr,
63 matcher: matcher,
64 cb: cb,
65 }
66 }
67
// Mutation registers hdlr under the name op as a GraphQL mutation interceptor.
// Intercept selects it when matcher (a regular expression) matches the body
// passed to Intercept and cb is true. Registering op again replaces the
// previous entry, including one registered by another method under that name.
func (f *Factory) Mutation(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), matcher string, cb bool) {
	f.operations[op] = Operation{
		name: op,
		// Previously tagged as Query (copy-paste from Query), which made
		// mutations indistinguishable from queries by identifier.
		identifier: Mutation,
		hdlr:       hdlr,
		matcher:    matcher,
		cb:         cb,
	}
}
77
// Subscription is a placeholder for GraphQL subscription interceptors. It
// registers nothing and always returns ErrInterceptorNotImplemented.
func (*Factory) Subscription(_ string, _ func(*http.Request, []byte) (*http.Request, []byte, error), _ string, _ bool) error {
	return ErrInterceptorNotImplemented
}
81
// Default registers hdlr as the fallback interceptor under the reserved key
// "default". Intercept calls it when no other eligible operation matches.
// cb is stored, but Intercept does not consult it for the fallback.
func (f *Factory) Default(hdlr func(*http.Request, []byte) (*http.Request, []byte, error), cb bool) {
	fallback := Operation{hdlr: hdlr, cb: cb}
	fallback.identifier = Default
	f.operations["default"] = fallback
}
89
90
// Path registers hdlr under the name op as a URL-path interceptor. Intercept
// applies the path regular expression to req.URL.Path, not to the body. As
// with the other registrations, cb must be true for the interceptor to be
// considered, and re-registering op replaces the earlier entry.
func (f *Factory) Path(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), path string, cb bool) {
	entry := Operation{identifier: Path}
	entry.name, entry.matcher = op, path
	entry.hdlr, entry.cb = hdlr, cb
	f.operations[op] = entry
}
100
101 func (f *Factory) Intercept(req *http.Request, body []byte, auditLog bool) (rq *http.Request, res []byte, err error) {
102 ops := make(map[string]Operation, 0)
103 maps.Copy(ops, f.operations)
104 delete(ops, "default")
105 username := ""
106 authProvider := ""
107 organization := ""
108 if f.session != nil {
109 usernameField := f.session.Get(types.SessionUsernameField)
110 if usr, ok := usernameField.(string); ok {
111 username = usr
112 }
113 authTypeField := f.session.Get(types.SessionAuthTypeField)
114 if authType, ok := authTypeField.(string); ok {
115 authProvider = authType
116 }
117 organizationTypeField := f.session.Get(types.SessionOrganizationField)
118 if organizationType, ok := organizationTypeField.(string); ok {
119 organization = organizationType
120 }
121 }
122 var opts []audit.Option
123 if req != nil && req.Body != nil {
124 requestBody, _ := io.ReadAll(req.Body)
125 req.Body = io.NopCloser(bytes.NewBuffer(requestBody))
126 requestQuery := &RequestQuery{}
127 _ = json.Unmarshal(requestBody, requestQuery)
128 schema, _ := graphqlhelpers.ParseQuery(requestQuery.Query)
129 graphqlhelpers.SanitizeDocument(schema)
130 opctx := &graphql.OperationContext{Variables: requestQuery.Variables}
131 variables := graphqlhelpers.GetVariables(opctx)
132 graphqlhelpers.UpdateQueryWithVariables(schema, variables)
133 buf := bytes.NewBuffer(nil)
134 formatter.NewFormatter(buf).FormatQueryDocument(schema)
135 gResponse := &graphql.Response{}
136 _ = json.Unmarshal(body, gResponse)
137 opts = []audit.Option{
138 audit.WithStatus(graphqlhelpers.GetResponseStatus(gResponse)),
139 audit.WithUserIP(f.ip),
140 audit.WithUserAgent(req.UserAgent()),
141 audit.WithRequestURL(req.URL.String()),
142 audit.WithMethod(req.Method),
143 audit.WithIdentifier(f.correlationID),
144 audit.WithOperationName(graphqlhelpers.GetOperations(schema)),
145 audit.WithActor(username),
146 audit.WithAuthProvider(authProvider),
147 audit.WithTenant(bsl.GetOrgShortName(organization)),
148 audit.WithInput(buf.String()),
149 audit.WithParameters(graphqlhelpers.GetParams(opctx, schema)),
150 }
151 }
152 defer func() {
153 if auditLog {
154 f.auditor.Log(opts...)
155 }
156 }()
157 for _, operation := range ops {
158 re := regexp.MustCompile(operation.matcher)
159 matchStr := string(body)
160 if operation.identifier == Path {
161 matchStr = req.URL.Path
162 }
163 if re.MatchString(matchStr) && operation.cb {
164 rq, res, err = operation.hdlr(req, body)
165 return
166 }
167 }
168 rq, res, err = f.operations["default"].hdlr(req, body)
169 return
170 }
171
View as plain text