...

Source file src/edge-infra.dev/pkg/edge/edgeencrypt/kms.go

Documentation: edge-infra.dev/pkg/edge/edgeencrypt

     1  package edgeencrypt
     2  
     3  import (
     4  	"context"
     5  	"crypto"
     6  	"crypto/rsa"
     7  	"crypto/sha256"
     8  	"fmt"
     9  	"hash/crc32"
    10  
    11  	"github.com/golang-jwt/jwt"
    12  	"google.golang.org/protobuf/types/known/wrapperspb"
    13  
    14  	kms "cloud.google.com/go/kms/apiv1"
    15  	"cloud.google.com/go/kms/apiv1/kmspb"
    16  )
    17  
    18  type KmsKey struct {
    19  	ProjectID string // PROJECT_ID
    20  	Location  string // LOCATION gcp_region
    21  }
    22  
    23  // KeyPath formats the kms key path for encryption
    24  // projects/PROJECT_ID/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY_NAME/cryptoKeyVersions/KEY_VERSION
    25  // from bearer token: bannerEdgeId -> KEY_RING, channelID -> KEY_NAME, KeyVersion -> KEY_VERSION
    26  func (c KmsKey) KeyPath(bannerEdgeID, channelID, keyVersion string) string {
    27  	return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s",
    28  		c.ProjectID,
    29  		c.Location,
    30  		bannerEdgeID,
    31  		channelID,
    32  		keyVersion)
    33  }
    34  
    35  func (c KmsKey) RingParent() string {
    36  	return fmt.Sprintf("projects/%s/locations/%s", c.ProjectID, c.Location)
    37  }
    38  
    39  func (c KmsKey) Ring(ring string) string {
    40  	return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s", c.ProjectID, c.Location, ring)
    41  }
    42  
    43  func (c KmsKey) Key(ring, key string) string {
    44  	return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.ProjectID, c.Location, ring, key)
    45  }
    46  
    47  var (
    48  	_ jwt.SigningMethod = (*SigningMethodKMS)(nil)
    49  )
    50  
    51  func init() {
    52  	jwt.RegisterSigningMethod("KMS_RS256", func() jwt.SigningMethod {
    53  		return &SigningMethodKMS{}
    54  	})
    55  }
    56  
    57  type SigningMethodKMS struct {
    58  	Client *kms.KeyManagementClient
    59  }
    60  
    61  func NewSigningMethodKMS(client *kms.KeyManagementClient) *SigningMethodKMS {
    62  	return &SigningMethodKMS{Client: client}
    63  }
    64  
    65  func (k SigningMethodKMS) Alg() string {
    66  	return "KMS_RS256"
    67  }
    68  
    69  // Sign use during token creation, use KMS to sign the token
    70  func (k SigningMethodKMS) Sign(signingString string, key interface{}) (string, error) {
    71  	keyPath, ok := key.(string)
    72  	if !ok {
    73  		return "", fmt.Errorf("kms key must be a valid string: invalid key type: %T", key)
    74  	}
    75  	return k.signWithKMS(context.Background(), signingString, keyPath)
    76  }
    77  
    78  // Verify use during token validation, it uses a public key, no need for KMS. used by `jwt.ParseWithClaims`
    79  func (k SigningMethodKMS) Verify(signingString, signature string, key interface{}) error {
    80  	var pk *rsa.PublicKey
    81  	var ok bool
    82  	if pk, ok = key.(*rsa.PublicKey); !ok {
    83  		return fmt.Errorf("validation key must rsa.PublicKey: invalid key type: %T", key)
    84  	}
    85  	sig, err := jwt.DecodeSegment(signature)
    86  	if err != nil {
    87  		return fmt.Errorf("failed to decode signature: %w", err)
    88  	}
    89  	digest := sha256.Sum256([]byte(signingString))
    90  	if err := rsa.VerifyPSS(pk, crypto.SHA256, digest[:], sig, &rsa.PSSOptions{
    91  		SaltLength: len(digest),
    92  		Hash:       crypto.SHA256,
    93  	}); err != nil {
    94  		return fmt.Errorf("failed to verify signature: %w", err)
    95  	}
    96  	return nil
    97  }
    98  
    99  func (k SigningMethodKMS) signWithKMS(ctx context.Context, signingString string, key string) (string, error) {
   100  	plaintext := []byte(signingString)
   101  	digest := sha256.New()
   102  	if _, err := digest.Write(plaintext); err != nil {
   103  		return "", fmt.Errorf("failed to create digest: %w", err)
   104  	}
   105  
   106  	digestCRC32C := crc32c(digest.Sum(nil))
   107  
   108  	req := &kmspb.AsymmetricSignRequest{
   109  		Name: key,
   110  		Digest: &kmspb.Digest{
   111  			Digest: &kmspb.Digest_Sha256{
   112  				Sha256: digest.Sum(nil),
   113  			},
   114  		},
   115  		DigestCrc32C: wrapperspb.Int64(int64(digestCRC32C)),
   116  	}
   117  
   118  	result, err := k.Client.AsymmetricSign(ctx, req)
   119  	if err != nil {
   120  		return "", fmt.Errorf("failed to sign data: %w", err)
   121  	}
   122  
   123  	if !result.VerifiedDigestCrc32C {
   124  		return "", fmt.Errorf("AsymmetricSign: request corrupted in-transit")
   125  	}
   126  	if result.Name != req.Name {
   127  		return "", fmt.Errorf("AsymmetricSign: request corrupted in-transit")
   128  	}
   129  	if int64(crc32c(result.Signature)) != result.SignatureCrc32C.Value {
   130  		return "", fmt.Errorf("AsymmetricSign: response corrupted in-transit")
   131  	}
   132  	return jwt.EncodeSegment(result.Signature), nil
   133  }
   134  
   135  // Compute digest's CRC32C.
   136  func crc32c(data []byte) uint32 {
   137  	t := crc32.MakeTable(crc32.Castagnoli)
   138  	return crc32.Checksum(data, t)
   139  }
   140  

View as plain text