package edgeencrypt import ( "context" "crypto" "crypto/rsa" "crypto/sha256" "fmt" "hash/crc32" "github.com/golang-jwt/jwt" "google.golang.org/protobuf/types/known/wrapperspb" kms "cloud.google.com/go/kms/apiv1" "cloud.google.com/go/kms/apiv1/kmspb" ) type KmsKey struct { ProjectID string // PROJECT_ID Location string // LOCATION gcp_region } // KeyPath formats the kms key path for encryption // projects/PROJECT_ID/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY_NAME/cryptoKeyVersions/KEY_VERSION // from bearer token: bannerEdgeId -> KEY_RING, channelID -> KEY_NAME, KeyVersion -> KEY_VERSION func (c KmsKey) KeyPath(bannerEdgeID, channelID, keyVersion string) string { return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s", c.ProjectID, c.Location, bannerEdgeID, channelID, keyVersion) } func (c KmsKey) RingParent() string { return fmt.Sprintf("projects/%s/locations/%s", c.ProjectID, c.Location) } func (c KmsKey) Ring(ring string) string { return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s", c.ProjectID, c.Location, ring) } func (c KmsKey) Key(ring, key string) string { return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.ProjectID, c.Location, ring, key) } var ( _ jwt.SigningMethod = (*SigningMethodKMS)(nil) ) func init() { jwt.RegisterSigningMethod("KMS_RS256", func() jwt.SigningMethod { return &SigningMethodKMS{} }) } type SigningMethodKMS struct { Client *kms.KeyManagementClient } func NewSigningMethodKMS(client *kms.KeyManagementClient) *SigningMethodKMS { return &SigningMethodKMS{Client: client} } func (k SigningMethodKMS) Alg() string { return "KMS_RS256" } // Sign use during token creation, use KMS to sign the token func (k SigningMethodKMS) Sign(signingString string, key interface{}) (string, error) { keyPath, ok := key.(string) if !ok { return "", fmt.Errorf("kms key must be a valid string: invalid key type: %T", key) } return k.signWithKMS(context.Background(), signingString, keyPath) } // Verify use during token validation, it uses a public key, no need for KMS. used by `jwt.ParseWithClaims` func (k SigningMethodKMS) Verify(signingString, signature string, key interface{}) error { var pk *rsa.PublicKey var ok bool if pk, ok = key.(*rsa.PublicKey); !ok { return fmt.Errorf("validation key must rsa.PublicKey: invalid key type: %T", key) } sig, err := jwt.DecodeSegment(signature) if err != nil { return fmt.Errorf("failed to decode signature: %w", err) } digest := sha256.Sum256([]byte(signingString)) if err := rsa.VerifyPSS(pk, crypto.SHA256, digest[:], sig, &rsa.PSSOptions{ SaltLength: len(digest), Hash: crypto.SHA256, }); err != nil { return fmt.Errorf("failed to verify signature: %w", err) } return nil } func (k SigningMethodKMS) signWithKMS(ctx context.Context, signingString string, key string) (string, error) { plaintext := []byte(signingString) digest := sha256.New() if _, err := digest.Write(plaintext); err != nil { return "", fmt.Errorf("failed to create digest: %w", err) } digestCRC32C := crc32c(digest.Sum(nil)) req := &kmspb.AsymmetricSignRequest{ Name: key, Digest: &kmspb.Digest{ Digest: &kmspb.Digest_Sha256{ Sha256: digest.Sum(nil), }, }, DigestCrc32C: wrapperspb.Int64(int64(digestCRC32C)), } result, err := k.Client.AsymmetricSign(ctx, req) if err != nil { return "", fmt.Errorf("failed to sign data: %w", err) } if !result.VerifiedDigestCrc32C { return "", fmt.Errorf("AsymmetricSign: request corrupted in-transit") } if result.Name != req.Name { return "", fmt.Errorf("AsymmetricSign: request corrupted in-transit") } if int64(crc32c(result.Signature)) != result.SignatureCrc32C.Value { return "", fmt.Errorf("AsymmetricSign: response corrupted in-transit") } return jwt.EncodeSegment(result.Signature), nil } // Compute digest's CRC32C. func crc32c(data []byte) uint32 { t := crc32.MakeTable(crc32.Castagnoli) return crc32.Checksum(data, t) }