1 package edgeencrypt
2
3 import (
4 "context"
5 "crypto"
6 "crypto/rsa"
7 "crypto/sha256"
8 "fmt"
9 "hash/crc32"
10
11 "github.com/golang-jwt/jwt"
12 "google.golang.org/protobuf/types/known/wrapperspb"
13
14 kms "cloud.google.com/go/kms/apiv1"
15 "cloud.google.com/go/kms/apiv1/kmspb"
16 )
17
// KmsKey identifies the project and location under which KMS resource names
// are built. Its methods only format resource-name strings; they do not
// contact any service or validate their arguments.
type KmsKey struct {
	// ProjectID fills the "projects/<id>" segment of every generated name.
	ProjectID string
	// Location fills the "locations/<location>" segment of every generated name.
	Location string
}

// KeyPath returns the full crypto key version name. bannerEdgeID is placed in
// the keyRings segment, channelID in the cryptoKeys segment and keyVersion in
// the cryptoKeyVersions segment.
func (c KmsKey) KeyPath(bannerEdgeID, channelID, keyVersion string) string {
	return fmt.Sprintf("%s/cryptoKeyVersions/%s", c.Key(bannerEdgeID, channelID), keyVersion)
}

// RingParent returns the "projects/<ProjectID>/locations/<Location>" prefix
// shared by every other name produced by this type.
func (c KmsKey) RingParent() string {
	return "projects/" + c.ProjectID + "/locations/" + c.Location
}

// Ring returns the name of the key ring called ring under RingParent.
func (c KmsKey) Ring(ring string) string {
	return fmt.Sprintf("%s/keyRings/%s", c.RingParent(), ring)
}

// Key returns the name of the crypto key called key inside the key ring
// called ring.
func (c KmsKey) Key(ring, key string) string {
	return fmt.Sprintf("%s/cryptoKeys/%s", c.Ring(ring), key)
}
46
47 var (
48 _ jwt.SigningMethod = (*SigningMethodKMS)(nil)
49 )
50
51 func init() {
52 jwt.RegisterSigningMethod("KMS_RS256", func() jwt.SigningMethod {
53 return &SigningMethodKMS{}
54 })
55 }
56
57 type SigningMethodKMS struct {
58 Client *kms.KeyManagementClient
59 }
60
61 func NewSigningMethodKMS(client *kms.KeyManagementClient) *SigningMethodKMS {
62 return &SigningMethodKMS{Client: client}
63 }
64
65 func (k SigningMethodKMS) Alg() string {
66 return "KMS_RS256"
67 }
68
69
70 func (k SigningMethodKMS) Sign(signingString string, key interface{}) (string, error) {
71 keyPath, ok := key.(string)
72 if !ok {
73 return "", fmt.Errorf("kms key must be a valid string: invalid key type: %T", key)
74 }
75 return k.signWithKMS(context.Background(), signingString, keyPath)
76 }
77
78
79 func (k SigningMethodKMS) Verify(signingString, signature string, key interface{}) error {
80 var pk *rsa.PublicKey
81 var ok bool
82 if pk, ok = key.(*rsa.PublicKey); !ok {
83 return fmt.Errorf("validation key must rsa.PublicKey: invalid key type: %T", key)
84 }
85 sig, err := jwt.DecodeSegment(signature)
86 if err != nil {
87 return fmt.Errorf("failed to decode signature: %w", err)
88 }
89 digest := sha256.Sum256([]byte(signingString))
90 if err := rsa.VerifyPSS(pk, crypto.SHA256, digest[:], sig, &rsa.PSSOptions{
91 SaltLength: len(digest),
92 Hash: crypto.SHA256,
93 }); err != nil {
94 return fmt.Errorf("failed to verify signature: %w", err)
95 }
96 return nil
97 }
98
99 func (k SigningMethodKMS) signWithKMS(ctx context.Context, signingString string, key string) (string, error) {
100 plaintext := []byte(signingString)
101 digest := sha256.New()
102 if _, err := digest.Write(plaintext); err != nil {
103 return "", fmt.Errorf("failed to create digest: %w", err)
104 }
105
106 digestCRC32C := crc32c(digest.Sum(nil))
107
108 req := &kmspb.AsymmetricSignRequest{
109 Name: key,
110 Digest: &kmspb.Digest{
111 Digest: &kmspb.Digest_Sha256{
112 Sha256: digest.Sum(nil),
113 },
114 },
115 DigestCrc32C: wrapperspb.Int64(int64(digestCRC32C)),
116 }
117
118 result, err := k.Client.AsymmetricSign(ctx, req)
119 if err != nil {
120 return "", fmt.Errorf("failed to sign data: %w", err)
121 }
122
123 if !result.VerifiedDigestCrc32C {
124 return "", fmt.Errorf("AsymmetricSign: request corrupted in-transit")
125 }
126 if result.Name != req.Name {
127 return "", fmt.Errorf("AsymmetricSign: request corrupted in-transit")
128 }
129 if int64(crc32c(result.Signature)) != result.SignatureCrc32C.Value {
130 return "", fmt.Errorf("AsymmetricSign: response corrupted in-transit")
131 }
132 return jwt.EncodeSegment(result.Signature), nil
133 }
134
135
// crc32c returns the CRC-32 checksum of data computed with the Castagnoli
// polynomial.
func crc32c(data []byte) uint32 {
	castagnoli := crc32.MakeTable(crc32.Castagnoli)
	// Updating a zero CRC is what crc32.Checksum does internally.
	return crc32.Update(0, castagnoli, data)
}
140
View as plain text