...

Source file src/github.com/99designs/gqlgen/graphql/context_response.go

Documentation: github.com/99designs/gqlgen/graphql

     1  package graphql
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/vektah/gqlparser/v2/gqlerror"
     9  )
    10  
    11  type responseContext struct {
    12  	errorPresenter ErrorPresenterFunc
    13  	recover        RecoverFunc
    14  
    15  	errors   gqlerror.List
    16  	errorsMu sync.Mutex
    17  
    18  	extensions   map[string]interface{}
    19  	extensionsMu sync.Mutex
    20  }
    21  
// resultCtx is the context key under which the *responseContext is stored
// (by WithResponseContext / WithFreshResponseContext) and looked up
// (by getResponseContext).
const resultCtx key = "result_context"
    23  
    24  func getResponseContext(ctx context.Context) *responseContext {
    25  	val, ok := ctx.Value(resultCtx).(*responseContext)
    26  	if !ok {
    27  		panic("missing response context")
    28  	}
    29  	return val
    30  }
    31  
    32  func WithResponseContext(ctx context.Context, presenterFunc ErrorPresenterFunc, recoverFunc RecoverFunc) context.Context {
    33  	return context.WithValue(ctx, resultCtx, &responseContext{
    34  		errorPresenter: presenterFunc,
    35  		recover:        recoverFunc,
    36  	})
    37  }
    38  
    39  func WithFreshResponseContext(ctx context.Context) context.Context {
    40  	e := getResponseContext(ctx)
    41  	return context.WithValue(ctx, resultCtx, &responseContext{
    42  		errorPresenter: e.errorPresenter,
    43  		recover:        e.recover,
    44  	})
    45  }
    46  
    47  // AddErrorf writes a formatted error to the client, first passing it through the error presenter.
    48  func AddErrorf(ctx context.Context, format string, args ...interface{}) {
    49  	AddError(ctx, fmt.Errorf(format, args...))
    50  }
    51  
    52  // AddError sends an error to the client, first passing it through the error presenter.
    53  func AddError(ctx context.Context, err error) {
    54  	c := getResponseContext(ctx)
    55  
    56  	presentedError := c.errorPresenter(ctx, ErrorOnPath(ctx, err))
    57  
    58  	c.errorsMu.Lock()
    59  	defer c.errorsMu.Unlock()
    60  	c.errors = append(c.errors, presentedError)
    61  }
    62  
    63  func Recover(ctx context.Context, err interface{}) (userMessage error) {
    64  	c := getResponseContext(ctx)
    65  	return ErrorOnPath(ctx, c.recover(ctx, err))
    66  }
    67  
    68  // HasFieldError returns true if the given field has already errored
    69  func HasFieldError(ctx context.Context, rctx *FieldContext) bool {
    70  	c := getResponseContext(ctx)
    71  
    72  	c.errorsMu.Lock()
    73  	defer c.errorsMu.Unlock()
    74  
    75  	if len(c.errors) == 0 {
    76  		return false
    77  	}
    78  
    79  	path := rctx.Path()
    80  	for _, err := range c.errors {
    81  		if equalPath(err.Path, path) {
    82  			return true
    83  		}
    84  	}
    85  	return false
    86  }
    87  
    88  // GetFieldErrors returns a list of errors that occurred in the given field
    89  func GetFieldErrors(ctx context.Context, rctx *FieldContext) gqlerror.List {
    90  	c := getResponseContext(ctx)
    91  
    92  	c.errorsMu.Lock()
    93  	defer c.errorsMu.Unlock()
    94  
    95  	if len(c.errors) == 0 {
    96  		return nil
    97  	}
    98  
    99  	path := rctx.Path()
   100  	var errs gqlerror.List
   101  	for _, err := range c.errors {
   102  		if equalPath(err.Path, path) {
   103  			errs = append(errs, err)
   104  		}
   105  	}
   106  	return errs
   107  }
   108  
   109  func GetErrors(ctx context.Context) gqlerror.List {
   110  	resCtx := getResponseContext(ctx)
   111  	resCtx.errorsMu.Lock()
   112  	defer resCtx.errorsMu.Unlock()
   113  
   114  	if len(resCtx.errors) == 0 {
   115  		return nil
   116  	}
   117  
   118  	errs := resCtx.errors
   119  	cpy := make(gqlerror.List, len(errs))
   120  	for i := range errs {
   121  		errCpy := *errs[i]
   122  		cpy[i] = &errCpy
   123  	}
   124  	return cpy
   125  }
   126  
   127  // RegisterExtension allows you to add a new extension into the graphql response
   128  func RegisterExtension(ctx context.Context, key string, value interface{}) {
   129  	c := getResponseContext(ctx)
   130  	c.extensionsMu.Lock()
   131  	defer c.extensionsMu.Unlock()
   132  
   133  	if c.extensions == nil {
   134  		c.extensions = make(map[string]interface{})
   135  	}
   136  
   137  	if _, ok := c.extensions[key]; ok {
   138  		panic(fmt.Errorf("extension already registered for key %s", key))
   139  	}
   140  
   141  	c.extensions[key] = value
   142  }
   143  
   144  // GetExtensions returns any extensions registered in the current result context
   145  func GetExtensions(ctx context.Context) map[string]interface{} {
   146  	ext := getResponseContext(ctx).extensions
   147  	if ext == nil {
   148  		return map[string]interface{}{}
   149  	}
   150  
   151  	return ext
   152  }
   153  
   154  func GetExtension(ctx context.Context, name string) interface{} {
   155  	ext := getResponseContext(ctx).extensions
   156  	if ext == nil {
   157  		return nil
   158  	}
   159  
   160  	return ext[name]
   161  }
   162  

View as plain text