...

Source file src/edge-infra.dev/pkg/edge/auth-proxy/interceptor/interceptor.go

Documentation: edge-infra.dev/pkg/edge/auth-proxy/interceptor

     1  package interceptor
     2  
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"net/http"
	"regexp"
	"slices"

	"github.com/99designs/gqlgen/graphql"
	"github.com/gin-contrib/sessions"
	"github.com/vektah/gqlparser/v2/formatter"

	"edge-infra.dev/pkg/edge/api/graphqlhelpers"
	"edge-infra.dev/pkg/edge/audit"
	"edge-infra.dev/pkg/edge/auth-proxy/types"
	"edge-infra.dev/pkg/edge/bsl"
)
    22  
    23  type Factory struct {
    24  	operations    map[string]Operation
    25  	auditor       *audit.Sink
    26  	session       sessions.Session
    27  	correlationID string
    28  	ip            string
    29  }
    30  
    31  type Operation struct {
    32  	name       string
    33  	identifier op
    34  	hdlr       func(*http.Request, []byte) (*http.Request, []byte, error)
    35  	matcher    string
    36  	cb         bool
    37  }
    38  
    39  type RequestQuery struct {
    40  	Query     string                 `json:"query"`
    41  	Variables map[string]interface{} `json:"variables"`
    42  }
    43  
    44  var (
    45  	ErrInterceptorNotImplemented = errors.New("interceptor not implemented")
    46  )
    47  
    48  func New(ip, correlationID string, sess sessions.Session) *Factory {
    49  	return &Factory{
    50  		operations:    make(map[string]Operation, 0),
    51  		auditor:       audit.New("auth-proxy"),
    52  		session:       sess,
    53  		correlationID: correlationID,
    54  		ip:            ip,
    55  	}
    56  }
    57  
    58  func (f *Factory) Query(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), matcher string, cb bool) {
    59  	f.operations[op] = Operation{
    60  		name:       op,
    61  		identifier: Query,
    62  		hdlr:       hdlr,
    63  		matcher:    matcher,
    64  		cb:         cb,
    65  	}
    66  }
    67  
    68  func (f *Factory) Mutation(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), matcher string, cb bool) {
    69  	f.operations[op] = Operation{
    70  		name:       op,
    71  		identifier: Query,
    72  		hdlr:       hdlr,
    73  		matcher:    matcher,
    74  		cb:         cb,
    75  	}
    76  }
    77  
    78  func (f *Factory) Subscription(_ string, _ func(*http.Request, []byte) (*http.Request, []byte, error), _ string, _ bool) error {
    79  	return ErrInterceptorNotImplemented
    80  }
    81  
    82  func (f *Factory) Default(hdlr func(*http.Request, []byte) (*http.Request, []byte, error), cb bool) {
    83  	f.operations["default"] = Operation{
    84  		identifier: Default,
    85  		hdlr:       hdlr,
    86  		cb:         cb,
    87  	}
    88  }
    89  
    90  // Method to create an operation that looks for a match with the request URL before handling
    91  func (f *Factory) Path(op string, hdlr func(*http.Request, []byte) (*http.Request, []byte, error), path string, cb bool) {
    92  	f.operations[op] = Operation{
    93  		name:       op,
    94  		identifier: Path,
    95  		hdlr:       hdlr,
    96  		matcher:    path,
    97  		cb:         cb,
    98  	}
    99  }
   100  
   101  func (f *Factory) Intercept(req *http.Request, body []byte, auditLog bool) (rq *http.Request, res []byte, err error) {
   102  	ops := make(map[string]Operation, 0)
   103  	maps.Copy(ops, f.operations)
   104  	delete(ops, "default")
   105  	username := ""
   106  	authProvider := ""
   107  	organization := ""
   108  	if f.session != nil {
   109  		usernameField := f.session.Get(types.SessionUsernameField)
   110  		if usr, ok := usernameField.(string); ok {
   111  			username = usr
   112  		}
   113  		authTypeField := f.session.Get(types.SessionAuthTypeField)
   114  		if authType, ok := authTypeField.(string); ok {
   115  			authProvider = authType
   116  		}
   117  		organizationTypeField := f.session.Get(types.SessionOrganizationField)
   118  		if organizationType, ok := organizationTypeField.(string); ok {
   119  			organization = organizationType
   120  		}
   121  	}
   122  	var opts []audit.Option
   123  	if req != nil && req.Body != nil {
   124  		requestBody, _ := io.ReadAll(req.Body)
   125  		req.Body = io.NopCloser(bytes.NewBuffer(requestBody))
   126  		requestQuery := &RequestQuery{}
   127  		_ = json.Unmarshal(requestBody, requestQuery)
   128  		schema, _ := graphqlhelpers.ParseQuery(requestQuery.Query)
   129  		graphqlhelpers.SanitizeDocument(schema)
   130  		opctx := &graphql.OperationContext{Variables: requestQuery.Variables}
   131  		variables := graphqlhelpers.GetVariables(opctx)
   132  		graphqlhelpers.UpdateQueryWithVariables(schema, variables)
   133  		buf := bytes.NewBuffer(nil)
   134  		formatter.NewFormatter(buf).FormatQueryDocument(schema)
   135  		gResponse := &graphql.Response{}
   136  		_ = json.Unmarshal(body, gResponse)
   137  		opts = []audit.Option{
   138  			audit.WithStatus(graphqlhelpers.GetResponseStatus(gResponse)),
   139  			audit.WithUserIP(f.ip),
   140  			audit.WithUserAgent(req.UserAgent()),
   141  			audit.WithRequestURL(req.URL.String()),
   142  			audit.WithMethod(req.Method),
   143  			audit.WithIdentifier(f.correlationID),
   144  			audit.WithOperationName(graphqlhelpers.GetOperations(schema)),
   145  			audit.WithActor(username),
   146  			audit.WithAuthProvider(authProvider),
   147  			audit.WithTenant(bsl.GetOrgShortName(organization)),
   148  			audit.WithInput(buf.String()),
   149  			audit.WithParameters(graphqlhelpers.GetParams(opctx, schema)),
   150  		}
   151  	}
   152  	defer func() {
   153  		if auditLog {
   154  			f.auditor.Log(opts...)
   155  		}
   156  	}()
   157  	for _, operation := range ops {
   158  		re := regexp.MustCompile(operation.matcher)
   159  		matchStr := string(body)
   160  		if operation.identifier == Path {
   161  			matchStr = req.URL.Path
   162  		}
   163  		if re.MatchString(matchStr) && operation.cb {
   164  			rq, res, err = operation.hdlr(req, body)
   165  			return
   166  		}
   167  	}
   168  	rq, res, err = f.operations["default"].hdlr(req, body)
   169  	return
   170  }
   171  

View as plain text