package jws import ( "bytes" "context" "github.com/lestrrat-go/jwx/internal/base64" "github.com/lestrrat-go/jwx/internal/json" "github.com/lestrrat-go/jwx/internal/pool" "github.com/lestrrat-go/jwx/jwk" "github.com/pkg/errors" ) type collectRawCtx struct{} func (collectRawCtx) CollectRaw() bool { return true } func NewSignature() *Signature { return &Signature{} } func (s *Signature) DecodeCtx() DecodeCtx { return s.dc } func (s *Signature) SetDecodeCtx(dc DecodeCtx) { s.dc = dc } func (s Signature) PublicHeaders() Headers { return s.headers } func (s *Signature) SetPublicHeaders(v Headers) *Signature { s.headers = v return s } func (s Signature) ProtectedHeaders() Headers { return s.protected } func (s *Signature) SetProtectedHeaders(v Headers) *Signature { s.protected = v return s } func (s Signature) Signature() []byte { return s.signature } func (s *Signature) SetSignature(v []byte) *Signature { s.signature = v return s } type signatureUnmarshalProbe struct { Header Headers `json:"header,omitempty"` Protected *string `json:"protected,omitempty"` Signature *string `json:"signature,omitempty"` } func (s *Signature) UnmarshalJSON(data []byte) error { var sup signatureUnmarshalProbe sup.Header = NewHeaders() if err := json.Unmarshal(data, &sup); err != nil { return errors.Wrap(err, `failed to unmarshal signature into temporary struct`) } s.headers = sup.Header if buf := sup.Protected; buf != nil { src := []byte(*buf) if !bytes.HasPrefix(src, []byte{'{'}) { decoded, err := base64.Decode(src) if err != nil { return errors.Wrap(err, `failed to base64 decode protected headers`) } src = decoded } prt := NewHeaders() //nolint:forcetypeassert prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx()) if err := json.Unmarshal(src, prt); err != nil { return errors.Wrap(err, `failed to unmarshal protected headers`) } //nolint:forcetypeassert prt.(*stdHeaders).SetDecodeCtx(nil) s.protected = prt } decoded, err := base64.DecodeString(*sup.Signature) if err != nil { return errors.Wrap(err, `failed to base decode signature`) } s.signature = decoded return nil } // Sign populates the signature field, with a signature generated by // given the signer object and payload. // // The first return value is the raw signature in binary format. // The second return value s the full three-segment signature // (e.g. "eyXXXX.XXXXX.XXXX") func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() hdrs, err := mergeHeaders(ctx, s.headers, s.protected) if err != nil { return nil, nil, errors.Wrap(err, `failed to merge headers`) } if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil { return nil, nil, errors.Wrap(err, `failed to set "alg"`) } // If the key is a jwk.Key instance, obtain the raw key if jwkKey, ok := key.(jwk.Key); ok { // If we have a key ID specified by this jwk.Key, use that in the header if kid := jwkKey.KeyID(); kid != "" { if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil { return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`) } } } hdrbuf, err := json.Marshal(hdrs) if err != nil { return nil, nil, errors.Wrap(err, `failed to marshal headers`) } buf := pool.GetBytesBuffer() defer pool.ReleaseBytesBuffer(buf) buf.WriteString(base64.EncodeToString(hdrbuf)) buf.WriteByte('.') var plen int b64 := getB64Value(hdrs) if b64 { encoded := base64.EncodeToString(payload) plen = len(encoded) buf.WriteString(encoded) } else { if !s.detached { if bytes.Contains(payload, []byte{'.'}) { return nil, nil, errors.New(`payload must not contain a "."`) } } plen = len(payload) buf.Write(payload) } signature, err := signer.Sign(buf.Bytes(), key) if err != nil { return nil, nil, errors.Wrap(err, `failed to sign payload`) } s.signature = signature // Detached payload, this should be removed from the end result if s.detached { buf.Truncate(buf.Len() - plen) } buf.WriteByte('.') buf.WriteString(base64.EncodeToString(signature)) ret := make([]byte, buf.Len()) copy(ret, buf.Bytes()) return signature, ret, nil } func NewMessage() *Message { return &Message{} } // Clears the internal raw buffer that was accumulated during // the verify phase func (m *Message) clearRaw() { for _, sig := range m.signatures { if protected := sig.protected; protected != nil { if cr, ok := protected.(*stdHeaders); ok { cr.raw = nil } } } } func (m *Message) SetDecodeCtx(dc DecodeCtx) { m.dc = dc } func (m *Message) DecodeCtx() DecodeCtx { return m.dc } // Payload returns the decoded payload func (m Message) Payload() []byte { return m.payload } func (m *Message) SetPayload(v []byte) *Message { m.payload = v return m } func (m Message) Signatures() []*Signature { return m.signatures } func (m *Message) AppendSignature(v *Signature) *Message { m.signatures = append(m.signatures, v) return m } func (m *Message) ClearSignatures() *Message { m.signatures = nil return m } // LookupSignature looks up a particular signature entry using // the `kid` value func (m Message) LookupSignature(kid string) []*Signature { var sigs []*Signature for _, sig := range m.signatures { if hdr := sig.PublicHeaders(); hdr != nil { hdrKeyID := hdr.KeyID() if hdrKeyID == kid { sigs = append(sigs, sig) continue } } if hdr := sig.ProtectedHeaders(); hdr != nil { hdrKeyID := hdr.KeyID() if hdrKeyID == kid { sigs = append(sigs, sig) continue } } } return sigs } // This struct is used to first probe for the structure of the // incoming JSON object. We then decide how to parse it // from the fields that are populated. type messageUnmarshalProbe struct { Payload *string `json:"payload"` Signatures []json.RawMessage `json:"signatures,omitempty"` Header Headers `json:"header,omitempty"` Protected *string `json:"protected,omitempty"` Signature *string `json:"signature,omitempty"` } func (m *Message) UnmarshalJSON(buf []byte) error { m.payload = nil m.signatures = nil m.b64 = true var mup messageUnmarshalProbe mup.Header = NewHeaders() if err := json.Unmarshal(buf, &mup); err != nil { return errors.Wrap(err, `failed to unmarshal into temporary structure`) } b64 := true if mup.Signature == nil { // flattened signature is NOT present if len(mup.Signatures) == 0 { return errors.New(`required field "signatures" not present`) } m.signatures = make([]*Signature, 0, len(mup.Signatures)) for i, rawsig := range mup.Signatures { var sig Signature sig.SetDecodeCtx(m.DecodeCtx()) if err := json.Unmarshal(rawsig, &sig); err != nil { return errors.Wrapf(err, `failed to unmarshal signature #%d`, i+1) } sig.SetDecodeCtx(nil) if i == 0 { if !getB64Value(sig.protected) { b64 = false } } else { if b64 != getB64Value(sig.protected) { return errors.Errorf(`b64 value must be the same for all signatures`) } } m.signatures = append(m.signatures, &sig) } } else { // .signature is present, it's a flattened structure if len(mup.Signatures) != 0 { return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`) } var sig Signature sig.headers = mup.Header if src := mup.Protected; src != nil { decoded, err := base64.DecodeString(*src) if err != nil { return errors.Wrap(err, `failed to base64 decode flattened protected headers`) } prt := NewHeaders() //nolint:forcetypeassert prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx()) if err := json.Unmarshal(decoded, prt); err != nil { return errors.Wrap(err, `failed to unmarshal flattened protected headers`) } //nolint:forcetypeassert prt.(*stdHeaders).SetDecodeCtx(nil) sig.protected = prt } decoded, err := base64.DecodeString(*mup.Signature) if err != nil { return errors.Wrap(err, `failed to base64 decode flattened signature`) } sig.signature = decoded m.signatures = []*Signature{&sig} b64 = getB64Value(sig.protected) } if mup.Payload != nil { if !b64 { // NOT base64 encoded m.payload = []byte(*mup.Payload) } else { decoded, err := base64.DecodeString(*mup.Payload) if err != nil { return errors.Wrap(err, `failed to base64 decode payload`) } m.payload = decoded } } m.b64 = b64 return nil } func (m Message) MarshalJSON() ([]byte, error) { if len(m.signatures) == 1 { return m.marshalFlattened() } return m.marshalFull() } func (m Message) marshalFlattened() ([]byte, error) { buf := pool.GetBytesBuffer() defer pool.ReleaseBytesBuffer(buf) sig := m.signatures[0] buf.WriteRune('{') var wrote bool if hdr := sig.headers; hdr != nil { hdrjs, err := hdr.MarshalJSON() if err != nil { return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`) } buf.WriteString(`"header":`) buf.Write(hdrjs) wrote = true } if wrote { buf.WriteRune(',') } buf.WriteString(`"payload":"`) buf.WriteString(base64.EncodeToString(m.payload)) buf.WriteRune('"') if protected := sig.protected; protected != nil { protectedbuf, err := protected.MarshalJSON() if err != nil { return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`) } buf.WriteString(`,"protected":"`) buf.WriteString(base64.EncodeToString(protectedbuf)) buf.WriteRune('"') } buf.WriteString(`,"signature":"`) buf.WriteString(base64.EncodeToString(sig.signature)) buf.WriteRune('"') buf.WriteRune('}') ret := make([]byte, buf.Len()) copy(ret, buf.Bytes()) return ret, nil } func (m Message) marshalFull() ([]byte, error) { buf := pool.GetBytesBuffer() defer pool.ReleaseBytesBuffer(buf) buf.WriteString(`{"payload":"`) buf.WriteString(base64.EncodeToString(m.payload)) buf.WriteString(`","signatures":[`) for i, sig := range m.signatures { if i > 0 { buf.WriteRune(',') } buf.WriteRune('{') var wrote bool if hdr := sig.headers; hdr != nil { hdrbuf, err := hdr.MarshalJSON() if err != nil { return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, i+1) } buf.WriteString(`"header":`) buf.Write(hdrbuf) wrote = true } if protected := sig.protected; protected != nil { protectedbuf, err := protected.MarshalJSON() if err != nil { return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1) } if wrote { buf.WriteRune(',') } buf.WriteString(`"protected":"`) buf.WriteString(base64.EncodeToString(protectedbuf)) buf.WriteRune('"') wrote = true } if wrote { buf.WriteRune(',') } buf.WriteString(`"signature":"`) buf.WriteString(base64.EncodeToString(sig.signature)) buf.WriteString(`"}`) } buf.WriteString(`]}`) ret := make([]byte, buf.Len()) copy(ret, buf.Bytes()) return ret, nil }