...

Source file src/cloud.google.com/go/auth/credentials/idtoken/validate_test.go

Documentation: cloud.google.com/go/auth/credentials/idtoken

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package idtoken
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto"
    21  	"crypto/ecdsa"
    22  	"crypto/elliptic"
    23  	"crypto/rand"
    24  	"crypto/rsa"
    25  	"crypto/sha256"
    26  	"encoding/base64"
    27  	"encoding/json"
    28  	"fmt"
    29  	"io"
    30  	"math/big"
    31  	"net/http"
    32  	"testing"
    33  	"time"
    34  
    35  	"cloud.google.com/go/auth/internal/jwt"
    36  )
    37  
    38  const (
    39  	keyID              = "1234"
    40  	testAudience       = "test-audience"
    41  	expiry       int64 = 233431200
    42  )
    43  
    44  var (
    45  	beforeExp = func() time.Time { return time.Unix(expiry-1, 0) }
    46  	afterExp  = func() time.Time { return time.Unix(expiry+1, 0) }
    47  )
    48  
    49  func TestValidateRS256(t *testing.T) {
    50  	idToken, pk := createRS256JWT(t)
    51  	tests := []struct {
    52  		name    string
    53  		keyID   string
    54  		n       *big.Int
    55  		e       int
    56  		nowFunc func() time.Time
    57  		wantErr bool
    58  	}{
    59  		{
    60  			name:    "works",
    61  			keyID:   keyID,
    62  			n:       pk.N,
    63  			e:       pk.E,
    64  			nowFunc: beforeExp,
    65  			wantErr: false,
    66  		},
    67  		{
    68  			name:    "no matching key",
    69  			keyID:   "5678",
    70  			n:       pk.N,
    71  			e:       pk.E,
    72  			nowFunc: beforeExp,
    73  			wantErr: true,
    74  		},
    75  		{
    76  			name:    "sig does not match",
    77  			keyID:   keyID,
    78  			n:       new(big.Int).SetBytes([]byte("42")),
    79  			e:       42,
    80  			nowFunc: beforeExp,
    81  			wantErr: true,
    82  		},
    83  		{
    84  			name:    "token expired",
    85  			keyID:   keyID,
    86  			n:       pk.N,
    87  			e:       pk.E,
    88  			nowFunc: afterExp,
    89  			wantErr: true,
    90  		},
    91  	}
    92  
    93  	for _, tt := range tests {
    94  		t.Run(tt.name, func(t *testing.T) {
    95  			client := &http.Client{
    96  				Transport: RoundTripFn(func(req *http.Request) *http.Response {
    97  					cr := certResponse{
    98  						Keys: []jwk{
    99  							{
   100  								Kid: tt.keyID,
   101  								N:   base64.RawURLEncoding.EncodeToString(tt.n.Bytes()),
   102  								E:   base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(tt.e)).Bytes()),
   103  							},
   104  						},
   105  					}
   106  					b, err := json.Marshal(&cr)
   107  					if err != nil {
   108  						t.Fatalf("unable to marshal response: %v", err)
   109  					}
   110  					return &http.Response{
   111  						StatusCode: 200,
   112  						Body:       io.NopCloser(bytes.NewReader(b)),
   113  						Header:     make(http.Header),
   114  					}
   115  				}),
   116  			}
   117  			oldNow := now
   118  			defer func() { now = oldNow }()
   119  			now = tt.nowFunc
   120  
   121  			v, err := NewValidator(&ValidatorOptions{
   122  				Client: client,
   123  			})
   124  			if err != nil {
   125  				t.Fatalf("NewValidator(...) = %q, want nil", err)
   126  			}
   127  			payload, err := v.Validate(context.Background(), idToken, testAudience)
   128  			if tt.wantErr && err != nil {
   129  				// Got the error we wanted.
   130  				return
   131  			}
   132  			if !tt.wantErr && err != nil {
   133  				t.Fatalf("Validate(ctx, %s, %s): got err %q, want nil", idToken, testAudience, err)
   134  			}
   135  			if tt.wantErr && err == nil {
   136  				t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
   137  			}
   138  			if payload == nil {
   139  				t.Fatalf("Got nil payload, err: %v", err)
   140  			}
   141  			if payload.Audience != testAudience {
   142  				t.Fatalf("Validate(ctx, %s, %s): got %v, want %v", idToken, testAudience, payload.Audience, testAudience)
   143  			}
   144  			if len(payload.Claims) == 0 {
   145  				t.Fatalf("Validate(ctx, %s, %s): missing Claims map. payload.Claims = %+v", idToken, testAudience, payload.Claims)
   146  			}
   147  			if got, ok := payload.Claims["aud"]; !ok {
   148  				t.Fatalf("Validate(ctx, %s, %s): missing aud claim. payload.Claims = %+v", idToken, testAudience, payload.Claims)
   149  			} else {
   150  				got, ok := got.(string)
   151  				if !ok {
   152  					t.Fatalf("Validate(ctx, %s, %s): aud wasn't a string. payload.Claims = %+v", idToken, testAudience, payload.Claims)
   153  				}
   154  				if got != testAudience {
   155  					t.Fatalf("Validate(ctx, %s, %s): Payload[aud] want %v got %v", idToken, testAudience, testAudience, got)
   156  				}
   157  			}
   158  		})
   159  	}
   160  }
   161  
   162  func TestValidateES256(t *testing.T) {
   163  	idToken, pk := createES256JWT(t)
   164  	tests := []struct {
   165  		name    string
   166  		keyID   string
   167  		x       *big.Int
   168  		y       *big.Int
   169  		nowFunc func() time.Time
   170  		wantErr bool
   171  	}{
   172  		{
   173  			name:    "works",
   174  			keyID:   keyID,
   175  			x:       pk.X,
   176  			y:       pk.Y,
   177  			nowFunc: beforeExp,
   178  			wantErr: false,
   179  		},
   180  		{
   181  			name:    "no matching key",
   182  			keyID:   "5678",
   183  			x:       pk.X,
   184  			y:       pk.Y,
   185  			nowFunc: beforeExp,
   186  			wantErr: true,
   187  		},
   188  		{
   189  			name:    "sig does not match",
   190  			keyID:   keyID,
   191  			x:       new(big.Int),
   192  			y:       new(big.Int),
   193  			nowFunc: beforeExp,
   194  			wantErr: true,
   195  		},
   196  		{
   197  			name:    "token expired",
   198  			keyID:   keyID,
   199  			x:       pk.X,
   200  			y:       pk.Y,
   201  			nowFunc: afterExp,
   202  			wantErr: true,
   203  		},
   204  	}
   205  	for _, tt := range tests {
   206  		t.Run(tt.name, func(t *testing.T) {
   207  			client := &http.Client{
   208  				Transport: RoundTripFn(func(req *http.Request) *http.Response {
   209  					cr := certResponse{
   210  						Keys: []jwk{
   211  							{
   212  								Kid: tt.keyID,
   213  								X:   base64.RawURLEncoding.EncodeToString(tt.x.Bytes()),
   214  								Y:   base64.RawURLEncoding.EncodeToString(tt.y.Bytes()),
   215  							},
   216  						},
   217  					}
   218  					b, err := json.Marshal(&cr)
   219  					if err != nil {
   220  						t.Fatalf("unable to marshal response: %v", err)
   221  					}
   222  					return &http.Response{
   223  						StatusCode: 200,
   224  						Body:       io.NopCloser(bytes.NewReader(b)),
   225  						Header:     make(http.Header),
   226  					}
   227  				}),
   228  			}
   229  			oldNow := now
   230  			defer func() { now = oldNow }()
   231  			now = tt.nowFunc
   232  
   233  			v, err := NewValidator(&ValidatorOptions{
   234  				Client: client,
   235  			})
   236  			if err != nil {
   237  				t.Fatalf("NewValidator(...) = %q, want nil", err)
   238  			}
   239  			payload, err := v.Validate(context.Background(), idToken, testAudience)
   240  			if !tt.wantErr && err != nil {
   241  				t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
   242  			}
   243  			if !tt.wantErr && payload.Audience != testAudience {
   244  				t.Fatalf("got %v, want %v", payload.Audience, testAudience)
   245  			}
   246  		})
   247  	}
   248  }
   249  
   250  func TestParsePayload(t *testing.T) {
   251  	idToken, _ := createRS256JWT(t)
   252  	tests := []struct {
   253  		name                string
   254  		token               string
   255  		wantPayloadAudience string
   256  		wantErr             bool
   257  	}{{
   258  		name:                "valid token",
   259  		token:               idToken,
   260  		wantPayloadAudience: testAudience,
   261  	}, {
   262  		name:    "unparseable token",
   263  		token:   "aaa.bbb.ccc",
   264  		wantErr: true,
   265  	}}
   266  
   267  	for _, tt := range tests {
   268  		t.Run(tt.name, func(t *testing.T) {
   269  			payload, err := ParsePayload(tt.token)
   270  			gotErr := err != nil
   271  			if gotErr != tt.wantErr {
   272  				t.Errorf("ParsePayload(%q) got error %v, wantErr = %v", tt.token, err, tt.wantErr)
   273  			}
   274  			if tt.wantPayloadAudience != "" {
   275  				if payload == nil || payload.Audience != tt.wantPayloadAudience {
   276  					t.Errorf("ParsePayload(%q) got payload %+v, want payload with audience = %q", tt.token, payload, tt.wantPayloadAudience)
   277  				}
   278  			}
   279  		})
   280  	}
   281  }
   282  
   283  func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) {
   284  	t.Helper()
   285  	header, claims := commonToken(t, "ES256")
   286  	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   287  	if err != nil {
   288  		t.Fatalf("unable to generate key: %v", err)
   289  	}
   290  	signedContent := header + "." + claims
   291  	hashed := sha256.Sum256([]byte(signedContent))
   292  	hash := hashed[:]
   293  	r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash)
   294  	if err != nil {
   295  		t.Fatalf("unable to sign content: %v", err)
   296  	}
   297  	rb := r.Bytes()
   298  	lPadded := make([]byte, es256KeySize)
   299  	copy(lPadded[es256KeySize-len(rb):], rb)
   300  	var sig []byte
   301  	sig = append(sig, lPadded...)
   302  	sig = append(sig, s.Bytes()...)
   303  	signature := base64.RawURLEncoding.EncodeToString(sig)
   304  	return fmt.Sprintf("%s.%s.%s", header, claims, signature), privateKey.PublicKey
   305  }
   306  
   307  func createRS256JWT(t *testing.T) (string, rsa.PublicKey) {
   308  	t.Helper()
   309  	header, claims := commonToken(t, jwt.HeaderAlgRSA256)
   310  	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
   311  	if err != nil {
   312  		t.Fatalf("unable to generate key: %v", err)
   313  	}
   314  	signedContent := header + "." + claims
   315  	hashed := sha256.Sum256([]byte(signedContent))
   316  	hash := hashed[:]
   317  	sig, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, hash)
   318  	if err != nil {
   319  		t.Fatalf("unable to sign content: %v", err)
   320  	}
   321  	signature := base64.RawURLEncoding.EncodeToString(sig)
   322  	return fmt.Sprintf("%s.%s.%s", header, claims, signature), privateKey.PublicKey
   323  }
   324  
   325  // returns header and claims
   326  func commonToken(t *testing.T, alg string) (string, string) {
   327  	t.Helper()
   328  	header := jwt.Header{
   329  		KeyID:     keyID,
   330  		Algorithm: alg,
   331  		Type:      jwt.HeaderType,
   332  	}
   333  	payload := Payload{
   334  		Issuer:   "example.com",
   335  		Audience: testAudience,
   336  		Expires:  expiry,
   337  	}
   338  
   339  	hb, err := json.Marshal(&header)
   340  	if err != nil {
   341  		t.Fatalf("unable to marshall header: %v", err)
   342  	}
   343  	pb, err := json.Marshal(&payload)
   344  	if err != nil {
   345  		t.Fatalf("unable to marshall payload: %v", err)
   346  	}
   347  	eb := base64.RawURLEncoding.EncodeToString(hb)
   348  	ep := base64.RawURLEncoding.EncodeToString(pb)
   349  	return eb, ep
   350  }
   351  
   352  type RoundTripFn func(req *http.Request) *http.Response
   353  
   354  func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil }
   355  

View as plain text