...

Source file src/github.com/lestrrat-go/jwx/jwt/serialize.go

Documentation: github.com/lestrrat-go/jwx/jwt

     1  package jwt
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/lestrrat-go/jwx/internal/json"
     7  	"github.com/lestrrat-go/jwx/jwa"
     8  	"github.com/lestrrat-go/jwx/jwe"
     9  	"github.com/lestrrat-go/jwx/jws"
    10  	"github.com/pkg/errors"
    11  )
    12  
    13  type SerializeCtx interface {
    14  	Step() int
    15  	Nested() bool
    16  }
    17  
    18  type serializeCtx struct {
    19  	step   int
    20  	nested bool
    21  }
    22  
    23  func (ctx *serializeCtx) Step() int {
    24  	return ctx.step
    25  }
    26  
    27  func (ctx *serializeCtx) Nested() bool {
    28  	return ctx.nested
    29  }
    30  
    31  type SerializeStep interface {
    32  	Serialize(SerializeCtx, interface{}) (interface{}, error)
    33  }
    34  
    35  // Serializer is a generic serializer for JWTs. Whereas other conveinience
    36  // functions can only do one thing (such as generate a JWS signed JWT),
    37  // Using this construct you can serialize the token however you want.
    38  //
    39  // By default the serializer only marshals the token into a JSON payload.
    40  // You must set up the rest of the steps that should be taken by the
    41  // serializer.
    42  //
    43  // For example, to marshal the token into JSON, then apply JWS and JWE
    44  // in that order, you would do:
    45  //
    46  //   serialized, err := jwt.NewSerialer().
    47  //      Sign(jwa.RS256, key).
    48  //      Encrypt(jwa.RSA_OAEP, key.PublicKey).
    49  //      Serialize(token)
    50  //
    51  // The `jwt.Sign()` function is equivalent to
    52  //
    53  //   serialized, err := jwt.NewSerializer().
    54  //      Sign(...args...).
    55  //      Serialize(token)
    56  type Serializer struct {
    57  	steps []SerializeStep
    58  }
    59  
    60  // NewSerializer creates a new empty serializer.
    61  func NewSerializer() *Serializer {
    62  	return &Serializer{}
    63  }
    64  
    65  // Reset clears all of the registered steps.
    66  func (s *Serializer) Reset() *Serializer {
    67  	s.steps = nil
    68  	return s
    69  }
    70  
    71  // Step adds a new Step to the serialization process
    72  func (s *Serializer) Step(step SerializeStep) *Serializer {
    73  	s.steps = append(s.steps, step)
    74  	return s
    75  }
    76  
    77  type jsonSerializer struct{}
    78  
    79  func (jsonSerializer) Serialize(_ SerializeCtx, v interface{}) (interface{}, error) {
    80  	token, ok := v.(Token)
    81  	if !ok {
    82  		return nil, errors.Errorf(`invalid input: expected jwt.Token`)
    83  	}
    84  
    85  	buf, err := json.Marshal(token)
    86  	if err != nil {
    87  		return nil, errors.Errorf(`failed to serialize as JSON`)
    88  	}
    89  	return buf, nil
    90  }
    91  
    92  type genericHeader interface {
    93  	Get(string) (interface{}, bool)
    94  	Set(string, interface{}) error
    95  }
    96  
    97  func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error {
    98  	// cty and typ are common between JWE/JWS, so we don't use
    99  	// the constants in jws/jwe package here
   100  	const typKey = `typ`
   101  	const ctyKey = `cty`
   102  
   103  	if ctx.Step() == 1 {
   104  		// We are executed immediately after json marshaling
   105  		if _, ok := hdrs.Get(typKey); !ok {
   106  			if err := hdrs.Set(typKey, `JWT`); err != nil {
   107  				return errors.Wrapf(err, `failed to set %s key to "JWT"`, typKey)
   108  			}
   109  		}
   110  	} else {
   111  		if ctx.Nested() {
   112  			// If this is part of a nested sequence, we should set cty = 'JWT'
   113  			// https://datatracker.ietf.org/doc/html/rfc7519#section-5.2
   114  			if err := hdrs.Set(ctyKey, `JWT`); err != nil {
   115  				return errors.Wrapf(err, `failed to set %s key to "JWT"`, ctyKey)
   116  			}
   117  		}
   118  	}
   119  	return nil
   120  }
   121  
   122  type jwsSerializer struct {
   123  	alg     jwa.SignatureAlgorithm
   124  	key     interface{}
   125  	options []SignOption
   126  }
   127  
   128  func (s *jwsSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
   129  	payload, ok := v.([]byte)
   130  	if !ok {
   131  		return nil, errors.New(`expected []byte as input`)
   132  	}
   133  
   134  	var hdrs jws.Headers
   135  	//nolint:forcetypeassert
   136  	for _, option := range s.options {
   137  		switch option.Ident() {
   138  		case identJwsHeaders{}:
   139  			hdrs = option.Value().(jws.Headers)
   140  		}
   141  	}
   142  
   143  	if hdrs == nil {
   144  		hdrs = jws.NewHeaders()
   145  	}
   146  
   147  	if err := setTypeOrCty(ctx, hdrs); err != nil {
   148  		return nil, err // this is already wrapped
   149  	}
   150  
   151  	// JWTs MUST NOT use b64 = false
   152  	// https://datatracker.ietf.org/doc/html/rfc7797#section-7
   153  	if v, ok := hdrs.Get("b64"); ok {
   154  		if bval, bok := v.(bool); bok {
   155  			if !bval { // b64 = false
   156  				return nil, errors.New(`b64 cannot be false for JWTs`)
   157  			}
   158  		}
   159  	}
   160  	return jws.Sign(payload, s.alg, s.key, jws.WithHeaders(hdrs))
   161  }
   162  
   163  func (s *Serializer) Sign(alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) *Serializer {
   164  	return s.Step(&jwsSerializer{
   165  		alg:     alg,
   166  		key:     key,
   167  		options: options,
   168  	})
   169  }
   170  
   171  type jweSerializer struct {
   172  	keyalg      jwa.KeyEncryptionAlgorithm
   173  	key         interface{}
   174  	contentalg  jwa.ContentEncryptionAlgorithm
   175  	compressalg jwa.CompressionAlgorithm
   176  	options     []EncryptOption
   177  }
   178  
   179  func (s *jweSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
   180  	payload, ok := v.([]byte)
   181  	if !ok {
   182  		return nil, fmt.Errorf(`expected []byte as input`)
   183  	}
   184  
   185  	var hdrs jwe.Headers
   186  	//nolint:forcetypeassert
   187  	for _, option := range s.options {
   188  		switch option.Ident() {
   189  		case identJweHeaders{}:
   190  			hdrs = option.Value().(jwe.Headers)
   191  		}
   192  	}
   193  
   194  	if hdrs == nil {
   195  		hdrs = jwe.NewHeaders()
   196  	}
   197  
   198  	if err := setTypeOrCty(ctx, hdrs); err != nil {
   199  		return nil, err // this is already wrapped
   200  	}
   201  	return jwe.Encrypt(payload, s.keyalg, s.key, s.contentalg, s.compressalg, jwe.WithProtectedHeaders(hdrs))
   202  }
   203  
   204  func (s *Serializer) Encrypt(keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) *Serializer {
   205  	return s.Step(&jweSerializer{
   206  		keyalg:      keyalg,
   207  		key:         key,
   208  		contentalg:  contentalg,
   209  		compressalg: compressalg,
   210  		options:     options,
   211  	})
   212  }
   213  
   214  func (s *Serializer) Serialize(t Token) ([]byte, error) {
   215  	steps := make([]SerializeStep, len(s.steps)+1)
   216  	steps[0] = jsonSerializer{}
   217  	for i, step := range s.steps {
   218  		steps[i+1] = step
   219  	}
   220  
   221  	var ctx serializeCtx
   222  	ctx.nested = len(s.steps) > 1
   223  	var payload interface{} = t
   224  	for i, step := range steps {
   225  		ctx.step = i
   226  		v, err := step.Serialize(&ctx, payload)
   227  		if err != nil {
   228  			return nil, errors.Wrapf(err, `failed to serialize token at step #%d`, i+1)
   229  		}
   230  		payload = v
   231  	}
   232  
   233  	res, ok := payload.([]byte)
   234  	if !ok {
   235  		return nil, errors.New(`invalid serialization produced`)
   236  	}
   237  
   238  	return res, nil
   239  }
   240  

View as plain text