...

Source file src/google.golang.org/api/idtoken/validate_test.go

Documentation: google.golang.org/api/idtoken

     1  // Copyright 2020 Google LLC.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package idtoken
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto"
    11  	"crypto/ecdsa"
    12  	"crypto/elliptic"
    13  	"crypto/rand"
    14  	"crypto/rsa"
    15  	"encoding/base64"
    16  	"encoding/json"
    17  	"io"
    18  	"math/big"
    19  	"net/http"
    20  	"testing"
    21  	"time"
    22  
    23  	"google.golang.org/api/option"
    24  )
    25  
    26  const (
    27  	keyID              = "1234"
    28  	testAudience       = "test-audience"
    29  	expiry       int64 = 233431200
    30  )
    31  
    32  var (
    33  	beforeExp = func() time.Time { return time.Unix(expiry-1, 0) }
    34  	afterExp  = func() time.Time { return time.Unix(expiry+1, 0) }
    35  )
    36  
    37  func TestValidateRS256(t *testing.T) {
    38  	idToken, pk := createRS256JWT(t)
    39  	tests := []struct {
    40  		name    string
    41  		keyID   string
    42  		n       *big.Int
    43  		e       int
    44  		nowFunc func() time.Time
    45  		wantErr bool
    46  	}{
    47  		{
    48  			name:    "works",
    49  			keyID:   keyID,
    50  			n:       pk.N,
    51  			e:       pk.E,
    52  			nowFunc: beforeExp,
    53  			wantErr: false,
    54  		},
    55  		{
    56  			name:    "no matching key",
    57  			keyID:   "5678",
    58  			n:       pk.N,
    59  			e:       pk.E,
    60  			nowFunc: beforeExp,
    61  			wantErr: true,
    62  		},
    63  		{
    64  			name:    "sig does not match",
    65  			keyID:   keyID,
    66  			n:       new(big.Int).SetBytes([]byte("42")),
    67  			e:       42,
    68  			nowFunc: beforeExp,
    69  			wantErr: true,
    70  		},
    71  		{
    72  			name:    "token expired",
    73  			keyID:   keyID,
    74  			n:       pk.N,
    75  			e:       pk.E,
    76  			nowFunc: afterExp,
    77  			wantErr: true,
    78  		},
    79  	}
    80  
    81  	for _, tt := range tests {
    82  		t.Run(tt.name, func(t *testing.T) {
    83  			client := &http.Client{
    84  				Transport: RoundTripFn(func(req *http.Request) *http.Response {
    85  					cr := certResponse{
    86  						Keys: []jwk{
    87  							{
    88  								Kid: tt.keyID,
    89  								N:   base64.RawURLEncoding.EncodeToString(tt.n.Bytes()),
    90  								E:   base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(tt.e)).Bytes()),
    91  							},
    92  						},
    93  					}
    94  					b, err := json.Marshal(&cr)
    95  					if err != nil {
    96  						t.Fatalf("unable to marshal response: %v", err)
    97  					}
    98  					return &http.Response{
    99  						StatusCode: 200,
   100  						Body:       io.NopCloser(bytes.NewReader(b)),
   101  						Header:     make(http.Header),
   102  					}
   103  				}),
   104  			}
   105  			oldNow := now
   106  			defer func() { now = oldNow }()
   107  			now = tt.nowFunc
   108  
   109  			v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
   110  			if err != nil {
   111  				t.Fatalf("NewValidator(...) = %q, want nil", err)
   112  			}
   113  			payload, err := v.Validate(context.Background(), idToken, testAudience)
   114  			if tt.wantErr && err != nil {
   115  				// Got the error we wanted.
   116  				return
   117  			}
   118  			if !tt.wantErr && err != nil {
   119  				t.Fatalf("Validate(ctx, %s, %s): got err %q, want nil", idToken, testAudience, err)
   120  			}
   121  			if tt.wantErr && err == nil {
   122  				t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
   123  			}
   124  			if payload == nil {
   125  				t.Fatalf("Got nil payload, err: %v", err)
   126  			}
   127  			if payload.Audience != testAudience {
   128  				t.Fatalf("Validate(ctx, %s, %s): got %v, want %v", idToken, testAudience, payload.Audience, testAudience)
   129  			}
   130  			if len(payload.Claims) == 0 {
   131  				t.Fatalf("Validate(ctx, %s, %s): missing Claims map. payload.Claims = %+v", idToken, testAudience, payload.Claims)
   132  			}
   133  			if got, ok := payload.Claims["aud"]; !ok {
   134  				t.Fatalf("Validate(ctx, %s, %s): missing aud claim. payload.Claims = %+v", idToken, testAudience, payload.Claims)
   135  			} else {
   136  				got, ok := got.(string)
   137  				if !ok {
   138  					t.Fatalf("Validate(ctx, %s, %s): aud wasn't a string. payload.Claims = %+v", idToken, testAudience, payload.Claims)
   139  				}
   140  				if got != testAudience {
   141  					t.Fatalf("Validate(ctx, %s, %s): Payload[aud] want %v got %v", idToken, testAudience, testAudience, got)
   142  				}
   143  			}
   144  		})
   145  	}
   146  }
   147  
   148  func TestValidateES256(t *testing.T) {
   149  	idToken, pk := createES256JWT(t)
   150  	tests := []struct {
   151  		name    string
   152  		keyID   string
   153  		x       *big.Int
   154  		y       *big.Int
   155  		nowFunc func() time.Time
   156  		wantErr bool
   157  	}{
   158  		{
   159  			name:    "works",
   160  			keyID:   keyID,
   161  			x:       pk.X,
   162  			y:       pk.Y,
   163  			nowFunc: beforeExp,
   164  			wantErr: false,
   165  		},
   166  		{
   167  			name:    "no matching key",
   168  			keyID:   "5678",
   169  			x:       pk.X,
   170  			y:       pk.Y,
   171  			nowFunc: beforeExp,
   172  			wantErr: true,
   173  		},
   174  		{
   175  			name:    "sig does not match",
   176  			keyID:   keyID,
   177  			x:       new(big.Int),
   178  			y:       new(big.Int),
   179  			nowFunc: beforeExp,
   180  			wantErr: true,
   181  		},
   182  		{
   183  			name:    "token expired",
   184  			keyID:   keyID,
   185  			x:       pk.X,
   186  			y:       pk.Y,
   187  			nowFunc: afterExp,
   188  			wantErr: true,
   189  		},
   190  	}
   191  	for _, tt := range tests {
   192  		t.Run(tt.name, func(t *testing.T) {
   193  			client := &http.Client{
   194  				Transport: RoundTripFn(func(req *http.Request) *http.Response {
   195  					cr := certResponse{
   196  						Keys: []jwk{
   197  							{
   198  								Kid: tt.keyID,
   199  								X:   base64.RawURLEncoding.EncodeToString(tt.x.Bytes()),
   200  								Y:   base64.RawURLEncoding.EncodeToString(tt.y.Bytes()),
   201  							},
   202  						},
   203  					}
   204  					b, err := json.Marshal(&cr)
   205  					if err != nil {
   206  						t.Fatalf("unable to marshal response: %v", err)
   207  					}
   208  					return &http.Response{
   209  						StatusCode: 200,
   210  						Body:       io.NopCloser(bytes.NewReader(b)),
   211  						Header:     make(http.Header),
   212  					}
   213  				}),
   214  			}
   215  			oldNow := now
   216  			defer func() { now = oldNow }()
   217  			now = tt.nowFunc
   218  
   219  			v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
   220  			if err != nil {
   221  				t.Fatalf("NewValidator(...) = %q, want nil", err)
   222  			}
   223  			payload, err := v.Validate(context.Background(), idToken, testAudience)
   224  			if !tt.wantErr && err != nil {
   225  				t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
   226  			}
   227  			if !tt.wantErr && payload.Audience != testAudience {
   228  				t.Fatalf("got %v, want %v", payload.Audience, testAudience)
   229  			}
   230  		})
   231  	}
   232  }
   233  
   234  func TestParsePayload(t *testing.T) {
   235  	idToken, _ := createRS256JWT(t)
   236  	tests := []struct {
   237  		name                string
   238  		token               string
   239  		wantPayloadAudience string
   240  		wantErr             bool
   241  	}{{
   242  		name:                "valid token",
   243  		token:               idToken,
   244  		wantPayloadAudience: testAudience,
   245  	}, {
   246  		name:    "unparseable token",
   247  		token:   "aaa.bbb.ccc",
   248  		wantErr: true,
   249  	}}
   250  
   251  	for _, tt := range tests {
   252  		t.Run(tt.name, func(t *testing.T) {
   253  			payload, err := ParsePayload(tt.token)
   254  			gotErr := err != nil
   255  			if gotErr != tt.wantErr {
   256  				t.Errorf("ParsePayload(%q) got error %v, wantErr = %v", tt.token, err, tt.wantErr)
   257  			}
   258  			if tt.wantPayloadAudience != "" {
   259  				if payload == nil || payload.Audience != tt.wantPayloadAudience {
   260  					t.Errorf("ParsePayload(%q) got payload %+v, want payload with audience = %q", tt.token, payload, tt.wantPayloadAudience)
   261  				}
   262  			}
   263  		})
   264  	}
   265  }
   266  
   267  func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) {
   268  	t.Helper()
   269  	token := commonToken(t, "ES256")
   270  	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   271  	if err != nil {
   272  		t.Fatalf("unable to generate key: %v", err)
   273  	}
   274  	r, s, err := ecdsa.Sign(rand.Reader, privateKey, token.hashedContent())
   275  	if err != nil {
   276  		t.Fatalf("unable to sign content: %v", err)
   277  	}
   278  	rb := r.Bytes()
   279  	lPadded := make([]byte, es256KeySize)
   280  	copy(lPadded[es256KeySize-len(rb):], rb)
   281  	var sig []byte
   282  	sig = append(sig, lPadded...)
   283  	sig = append(sig, s.Bytes()...)
   284  	token.signature = base64.RawURLEncoding.EncodeToString(sig)
   285  	return token.String(), privateKey.PublicKey
   286  }
   287  
   288  func createRS256JWT(t *testing.T) (string, rsa.PublicKey) {
   289  	t.Helper()
   290  	token := commonToken(t, "RS256")
   291  	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
   292  	if err != nil {
   293  		t.Fatalf("unable to generate key: %v", err)
   294  	}
   295  	sig, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, token.hashedContent())
   296  	if err != nil {
   297  		t.Fatalf("unable to sign content: %v", err)
   298  	}
   299  	token.signature = base64.RawURLEncoding.EncodeToString(sig)
   300  	return token.String(), privateKey.PublicKey
   301  }
   302  
   303  func commonToken(t *testing.T, alg string) *jwt {
   304  	t.Helper()
   305  	header := jwtHeader{
   306  		KeyID:     keyID,
   307  		Algorithm: alg,
   308  		Type:      "JWT",
   309  	}
   310  	payload := Payload{
   311  		Issuer:   "example.com",
   312  		Audience: testAudience,
   313  		Expires:  expiry,
   314  	}
   315  
   316  	hb, err := json.Marshal(&header)
   317  	if err != nil {
   318  		t.Fatalf("unable to marshall header: %v", err)
   319  	}
   320  	pb, err := json.Marshal(&payload)
   321  	if err != nil {
   322  		t.Fatalf("unable to marshall payload: %v", err)
   323  	}
   324  	eb := base64.RawURLEncoding.EncodeToString(hb)
   325  	ep := base64.RawURLEncoding.EncodeToString(pb)
   326  	return &jwt{
   327  		header:  eb,
   328  		payload: ep,
   329  	}
   330  }
   331  
   332  type RoundTripFn func(req *http.Request) *http.Response
   333  
   334  func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil }
   335  

View as plain text