package interceptor import ( "bytes" "encoding/json" "errors" "io" "net/http" "regexp" "maps" "github.com/99designs/gqlgen/graphql" "github.com/gin-contrib/sessions" "github.com/vektah/gqlparser/v2/formatter" "edge-infra.dev/pkg/edge/api/graphqlhelpers" "edge-infra.dev/pkg/edge/audit" "edge-infra.dev/pkg/edge/auth-proxy/types" "edge-infra.dev/pkg/edge/bsl" ) type Factory struct { operations map[string]Operation auditor *audit.Sink session sessions.Session correlationID string ip string } type Operation struct { name string identifier op hdlr func(*http.Request, []byte) (*http.Request, []byte, error) matcher string cb bool } type RequestQuery struct { Query string `json:"query"` Variables map[string]interface{} `json:"variables"` } var ( ErrInterceptorNotImplemented = errors.New("interceptor not implemented") ) func New(ip, correlationID string, sess sessions.Session) *Factory { return &Factory{ operations: make(map[string]Operation, 0), auditor: audit.New("auth-proxy"), session: sess, correlationID: correlationID, ip: ip, } } func (f *Factory) Query(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), matcher string, cb bool) { f.operations[op] = Operation{ name: op, identifier: Query, hdlr: hdlr, matcher: matcher, cb: cb, } } func (f *Factory) Mutation(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), matcher string, cb bool) { f.operations[op] = Operation{ name: op, identifier: Query, hdlr: hdlr, matcher: matcher, cb: cb, } } func (f *Factory) Subscription(_ string, _ func(*http.Request, []byte) (*http.Request, []byte, error), _ string, _ bool) error { return ErrInterceptorNotImplemented } func (f *Factory) Default(hdlr func(*http.Request, []byte) (*http.Request, []byte, error), cb bool) { f.operations["default"] = Operation{ identifier: Default, hdlr: hdlr, cb: cb, } } // Method to create an operation that looks for a match with the request URL before handling func (f *Factory) Path(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), path string, cb bool) { f.operations[op] = Operation{ name: op, identifier: Path, hdlr: hdlr, matcher: path, cb: cb, } } func (f *Factory) Intercept(req *http.Request, body []byte, auditLog bool) (rq *http.Request, res []byte, err error) { ops := make(map[string]Operation, 0) maps.Copy(ops, f.operations) delete(ops, "default") username := "" authProvider := "" organization := "" if f.session != nil { usernameField := f.session.Get(types.SessionUsernameField) if usr, ok := usernameField.(string); ok { username = usr } authTypeField := f.session.Get(types.SessionAuthTypeField) if authType, ok := authTypeField.(string); ok { authProvider = authType } organizationTypeField := f.session.Get(types.SessionOrganizationField) if organizationType, ok := organizationTypeField.(string); ok { organization = organizationType } } var opts []audit.Option if req != nil && req.Body != nil { requestBody, _ := io.ReadAll(req.Body) req.Body = io.NopCloser(bytes.NewBuffer(requestBody)) requestQuery := &RequestQuery{} _ = json.Unmarshal(requestBody, requestQuery) schema, _ := graphqlhelpers.ParseQuery(requestQuery.Query) graphqlhelpers.SanitizeDocument(schema) opctx := &graphql.OperationContext{Variables: requestQuery.Variables} variables := graphqlhelpers.GetVariables(opctx) graphqlhelpers.UpdateQueryWithVariables(schema, variables) buf := bytes.NewBuffer(nil) formatter.NewFormatter(buf).FormatQueryDocument(schema) gResponse := &graphql.Response{} _ = json.Unmarshal(body, gResponse) opts = []audit.Option{ audit.WithStatus(graphqlhelpers.GetResponseStatus(gResponse)), audit.WithUserIP(f.ip), audit.WithUserAgent(req.UserAgent()), audit.WithRequestURL(req.URL.String()), audit.WithMethod(req.Method), audit.WithIdentifier(f.correlationID), audit.WithOperationName(graphqlhelpers.GetOperations(schema)), audit.WithActor(username), audit.WithAuthProvider(authProvider), audit.WithTenant(bsl.GetOrgShortName(organization)), audit.WithInput(buf.String()), audit.WithParameters(graphqlhelpers.GetParams(opctx, schema)), } } defer func() { if auditLog { f.auditor.Log(opts...) } }() for _, operation := range ops { re := regexp.MustCompile(operation.matcher) matchStr := string(body) if operation.identifier == Path { matchStr = req.URL.Path } if re.MatchString(matchStr) && operation.cb { rq, res, err = operation.hdlr(req, body) return } } rq, res, err = f.operations["default"].hdlr(req, body) return }