...

Source file src/golang.org/x/oauth2/jwt/jwt_test.go

Documentation: golang.org/x/oauth2/jwt

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jwt
     6  
     7  import (
     8  	"context"
     9  	"encoding/base64"
    10  	"encoding/json"
    11  	"fmt"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"reflect"
    15  	"strings"
    16  	"testing"
    17  
    18  	"golang.org/x/oauth2"
    19  	"golang.org/x/oauth2/jws"
    20  )
    21  
    22  var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
    23  MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE
    24  DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY
    25  fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK
    26  1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr
    27  k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9
    28  /E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt
    29  3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn
    30  2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3
    31  nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK
    32  6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf
    33  5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e
    34  DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1
    35  M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g
    36  z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y
    37  1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK
    38  J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U
    39  f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx
    40  QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA
    41  cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr
    42  Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw
    43  5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg
    44  KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84
    45  OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd
    46  mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ
    47  5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg==
    48  -----END RSA PRIVATE KEY-----`)
    49  
    50  func TestJWTFetch_JSONResponse(t *testing.T) {
    51  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    52  		w.Header().Set("Content-Type", "application/json")
    53  		w.Write([]byte(`{
    54  			"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
    55  			"scope": "user",
    56  			"token_type": "bearer",
    57  			"expires_in": 3600
    58  		}`))
    59  	}))
    60  	defer ts.Close()
    61  
    62  	conf := &Config{
    63  		Email:      "aaa@xxx.com",
    64  		PrivateKey: dummyPrivateKey,
    65  		TokenURL:   ts.URL,
    66  	}
    67  	tok, err := conf.TokenSource(context.Background()).Token()
    68  	if err != nil {
    69  		t.Fatal(err)
    70  	}
    71  	if !tok.Valid() {
    72  		t.Errorf("got invalid token: %v", tok)
    73  	}
    74  	if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
    75  		t.Errorf("access token = %q; want %q", got, want)
    76  	}
    77  	if got, want := tok.TokenType, "bearer"; got != want {
    78  		t.Errorf("token type = %q; want %q", got, want)
    79  	}
    80  	if got := tok.Expiry.IsZero(); got {
    81  		t.Errorf("token expiry = %v, want none", got)
    82  	}
    83  	scope := tok.Extra("scope")
    84  	if got, want := scope, "user"; got != want {
    85  		t.Errorf("scope = %q; want %q", got, want)
    86  	}
    87  }
    88  
    89  func TestJWTFetch_BadResponse(t *testing.T) {
    90  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    91  		w.Header().Set("Content-Type", "application/json")
    92  		w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
    93  	}))
    94  	defer ts.Close()
    95  
    96  	conf := &Config{
    97  		Email:      "aaa@xxx.com",
    98  		PrivateKey: dummyPrivateKey,
    99  		TokenURL:   ts.URL,
   100  	}
   101  	tok, err := conf.TokenSource(context.Background()).Token()
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  	if tok == nil {
   106  		t.Fatalf("got nil token; want token")
   107  	}
   108  	if tok.Valid() {
   109  		t.Errorf("got invalid token: %v", tok)
   110  	}
   111  	if got, want := tok.AccessToken, ""; got != want {
   112  		t.Errorf("access token = %q; want %q", got, want)
   113  	}
   114  	if got, want := tok.TokenType, "bearer"; got != want {
   115  		t.Errorf("token type = %q; want %q", got, want)
   116  	}
   117  	scope := tok.Extra("scope")
   118  	if got, want := scope, "user"; got != want {
   119  		t.Errorf("token scope = %q; want %q", got, want)
   120  	}
   121  }
   122  
   123  func TestJWTFetch_BadResponseType(t *testing.T) {
   124  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   125  		w.Header().Set("Content-Type", "application/json")
   126  		w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
   127  	}))
   128  	defer ts.Close()
   129  	conf := &Config{
   130  		Email:      "aaa@xxx.com",
   131  		PrivateKey: dummyPrivateKey,
   132  		TokenURL:   ts.URL,
   133  	}
   134  	tok, err := conf.TokenSource(context.Background()).Token()
   135  	if err == nil {
   136  		t.Error("got a token; expected error")
   137  		if got, want := tok.AccessToken, ""; got != want {
   138  			t.Errorf("access token = %q; want %q", got, want)
   139  		}
   140  	}
   141  }
   142  
   143  func TestJWTFetch_Assertion(t *testing.T) {
   144  	var assertion string
   145  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   146  		r.ParseForm()
   147  		assertion = r.Form.Get("assertion")
   148  
   149  		w.Header().Set("Content-Type", "application/json")
   150  		w.Write([]byte(`{
   151  			"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
   152  			"scope": "user",
   153  			"token_type": "bearer",
   154  			"expires_in": 3600
   155  		}`))
   156  	}))
   157  	defer ts.Close()
   158  
   159  	conf := &Config{
   160  		Email:        "aaa@xxx.com",
   161  		PrivateKey:   dummyPrivateKey,
   162  		PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   163  		TokenURL:     ts.URL,
   164  	}
   165  
   166  	_, err := conf.TokenSource(context.Background()).Token()
   167  	if err != nil {
   168  		t.Fatalf("Failed to fetch token: %v", err)
   169  	}
   170  
   171  	parts := strings.Split(assertion, ".")
   172  	if len(parts) != 3 {
   173  		t.Fatalf("assertion = %q; want 3 parts", assertion)
   174  	}
   175  	gotjson, err := base64.RawURLEncoding.DecodeString(parts[0])
   176  	if err != nil {
   177  		t.Fatalf("invalid token header; err = %v", err)
   178  	}
   179  
   180  	got := jws.Header{}
   181  	if err := json.Unmarshal(gotjson, &got); err != nil {
   182  		t.Errorf("failed to unmarshal json token header = %q; err = %v", gotjson, err)
   183  	}
   184  
   185  	want := jws.Header{
   186  		Algorithm: "RS256",
   187  		Typ:       "JWT",
   188  		KeyID:     "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   189  	}
   190  	if got != want {
   191  		t.Errorf("access token header = %q; want %q", got, want)
   192  	}
   193  }
   194  
   195  func TestJWTFetch_AssertionPayload(t *testing.T) {
   196  	var assertion string
   197  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   198  		r.ParseForm()
   199  		assertion = r.Form.Get("assertion")
   200  
   201  		w.Header().Set("Content-Type", "application/json")
   202  		w.Write([]byte(`{
   203  			"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
   204  			"scope": "user",
   205  			"token_type": "bearer",
   206  			"expires_in": 3600
   207  		}`))
   208  	}))
   209  	defer ts.Close()
   210  
   211  	for _, conf := range []*Config{
   212  		{
   213  			Email:        "aaa1@xxx.com",
   214  			PrivateKey:   dummyPrivateKey,
   215  			PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   216  			TokenURL:     ts.URL,
   217  		},
   218  		{
   219  			Email:        "aaa2@xxx.com",
   220  			PrivateKey:   dummyPrivateKey,
   221  			PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   222  			TokenURL:     ts.URL,
   223  			Audience:     "https://example.com",
   224  		},
   225  		{
   226  			Email:        "aaa2@xxx.com",
   227  			PrivateKey:   dummyPrivateKey,
   228  			PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   229  			TokenURL:     ts.URL,
   230  			PrivateClaims: map[string]interface{}{
   231  				"private0": "claim0",
   232  				"private1": "claim1",
   233  			},
   234  		},
   235  	} {
   236  		t.Run(conf.Email, func(t *testing.T) {
   237  			_, err := conf.TokenSource(context.Background()).Token()
   238  			if err != nil {
   239  				t.Fatalf("Failed to fetch token: %v", err)
   240  			}
   241  
   242  			parts := strings.Split(assertion, ".")
   243  			if len(parts) != 3 {
   244  				t.Fatalf("assertion = %q; want 3 parts", assertion)
   245  			}
   246  			gotjson, err := base64.RawURLEncoding.DecodeString(parts[1])
   247  			if err != nil {
   248  				t.Fatalf("invalid token payload; err = %v", err)
   249  			}
   250  
   251  			claimSet := jws.ClaimSet{}
   252  			if err := json.Unmarshal(gotjson, &claimSet); err != nil {
   253  				t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err)
   254  			}
   255  
   256  			if got, want := claimSet.Iss, conf.Email; got != want {
   257  				t.Errorf("payload email = %q; want %q", got, want)
   258  			}
   259  			if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want {
   260  				t.Errorf("payload scope = %q; want %q", got, want)
   261  			}
   262  			aud := conf.TokenURL
   263  			if conf.Audience != "" {
   264  				aud = conf.Audience
   265  			}
   266  			if got, want := claimSet.Aud, aud; got != want {
   267  				t.Errorf("payload audience = %q; want %q", got, want)
   268  			}
   269  			if got, want := claimSet.Sub, conf.Subject; got != want {
   270  				t.Errorf("payload subject = %q; want %q", got, want)
   271  			}
   272  			if got, want := claimSet.Prn, conf.Subject; got != want {
   273  				t.Errorf("payload prn = %q; want %q", got, want)
   274  			}
   275  			if len(conf.PrivateClaims) > 0 {
   276  				var got interface{}
   277  				if err := json.Unmarshal(gotjson, &got); err != nil {
   278  					t.Errorf("failed to parse payload; err = %q", err)
   279  				}
   280  				m := got.(map[string]interface{})
   281  				for v, k := range conf.PrivateClaims {
   282  					if !reflect.DeepEqual(m[v], k) {
   283  						t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k)
   284  					}
   285  				}
   286  			}
   287  		})
   288  	}
   289  }
   290  
   291  func TestTokenRetrieveError(t *testing.T) {
   292  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   293  		w.Header().Set("Content-type", "application/json")
   294  		w.WriteHeader(http.StatusBadRequest)
   295  		w.Write([]byte(`{"error": "invalid_grant"}`))
   296  	}))
   297  	defer ts.Close()
   298  
   299  	conf := &Config{
   300  		Email:      "aaa@xxx.com",
   301  		PrivateKey: dummyPrivateKey,
   302  		TokenURL:   ts.URL,
   303  	}
   304  
   305  	_, err := conf.TokenSource(context.Background()).Token()
   306  	if err == nil {
   307  		t.Fatalf("got no error, expected one")
   308  	}
   309  	_, ok := err.(*oauth2.RetrieveError)
   310  	if !ok {
   311  		t.Fatalf("got %T error, expected *RetrieveError", err)
   312  	}
   313  	// Test error string for backwards compatibility
   314  	expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
   315  	if errStr := err.Error(); errStr != expected {
   316  		t.Fatalf("got %#v, expected %#v", errStr, expected)
   317  	}
   318  }
   319  

View as plain text