...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/ssocreds/sso_cached_token.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/ssocreds

     1  package ssocreds
     2  
     3  import (
     4  	"crypto/sha1"
     5  	"encoding/hex"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"os"
    10  	"path/filepath"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
    16  	"github.com/aws/aws-sdk-go-v2/internal/shareddefaults"
    17  )
    18  
// osUserHomeDur is the function used to look up the user's home directory
// when deriving the cached token filepath. It is held in a package variable
// instead of being called directly.
// NOTE(review): presumably this indirection exists so tests can substitute
// the lookup — confirm against the package's tests.
var osUserHomeDur = shareddefaults.UserHomeDir
    20  
    21  // StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or
    22  // error if unable get derive the path. Key that will be used to compute a SHA1
    23  // value that is hex encoded.
    24  //
    25  // Derives the filepath using the Key as:
    26  //
    27  //	~/.aws/sso/cache/<sha1-hex-encoded-key>.json
    28  func StandardCachedTokenFilepath(key string) (string, error) {
    29  	homeDir := osUserHomeDur()
    30  	if len(homeDir) == 0 {
    31  		return "", fmt.Errorf("unable to get USER's home directory for cached token")
    32  	}
    33  	hash := sha1.New()
    34  	if _, err := hash.Write([]byte(key)); err != nil {
    35  		return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %w", err)
    36  	}
    37  
    38  	cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json"
    39  
    40  	return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil
    41  }
    42  
// tokenKnownFields lists the cached token members that this package reads and
// writes by name. Each member is omitted from the struct's default JSON
// encoding when empty.
type tokenKnownFields struct {
	// AccessToken and ExpiresAt must both be set (and ExpiresAt non-zero) for
	// loadCachedToken to accept a cached token file.
	AccessToken string   `json:"accessToken,omitempty"`
	ExpiresAt   *rfc3339 `json:"expiresAt,omitempty"`

	// NOTE(review): the remaining members are presumably the client
	// registration and refresh material used to renew the access token —
	// confirm against the code that consumes them.
	RefreshToken string `json:"refreshToken,omitempty"`
	ClientID     string `json:"clientId,omitempty"`
	ClientSecret string `json:"clientSecret,omitempty"`
}
    51  
// token is the in-memory form of a cached SSO token. It embeds the known
// fields and keeps every other top-level JSON member in UnknownFields so the
// custom MarshalJSON can write those members back out.
type token struct {
	tokenKnownFields
	// UnknownFields is excluded from the default struct encoding; the custom
	// MarshalJSON and UnmarshalJSON methods populate and emit it.
	UnknownFields map[string]interface{} `json:"-"`
}
    56  
    57  func (t token) MarshalJSON() ([]byte, error) {
    58  	fields := map[string]interface{}{}
    59  
    60  	setTokenFieldString(fields, "accessToken", t.AccessToken)
    61  	setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt)
    62  
    63  	setTokenFieldString(fields, "refreshToken", t.RefreshToken)
    64  	setTokenFieldString(fields, "clientId", t.ClientID)
    65  	setTokenFieldString(fields, "clientSecret", t.ClientSecret)
    66  
    67  	for k, v := range t.UnknownFields {
    68  		if _, ok := fields[k]; ok {
    69  			return nil, fmt.Errorf("unknown token field %v, duplicates known field", k)
    70  		}
    71  		fields[k] = v
    72  	}
    73  
    74  	return json.Marshal(fields)
    75  }
    76  
    77  func setTokenFieldString(fields map[string]interface{}, key, value string) {
    78  	if value == "" {
    79  		return
    80  	}
    81  	fields[key] = value
    82  }
    83  func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) {
    84  	if value == nil {
    85  		return
    86  	}
    87  	fields[key] = value
    88  }
    89  
    90  func (t *token) UnmarshalJSON(b []byte) error {
    91  	var fields map[string]interface{}
    92  	if err := json.Unmarshal(b, &fields); err != nil {
    93  		return nil
    94  	}
    95  
    96  	t.UnknownFields = map[string]interface{}{}
    97  
    98  	for k, v := range fields {
    99  		var err error
   100  		switch k {
   101  		case "accessToken":
   102  			err = getTokenFieldString(v, &t.AccessToken)
   103  		case "expiresAt":
   104  			err = getTokenFieldRFC3339(v, &t.ExpiresAt)
   105  		case "refreshToken":
   106  			err = getTokenFieldString(v, &t.RefreshToken)
   107  		case "clientId":
   108  			err = getTokenFieldString(v, &t.ClientID)
   109  		case "clientSecret":
   110  			err = getTokenFieldString(v, &t.ClientSecret)
   111  		default:
   112  			t.UnknownFields[k] = v
   113  		}
   114  
   115  		if err != nil {
   116  			return fmt.Errorf("field %q, %w", k, err)
   117  		}
   118  	}
   119  
   120  	return nil
   121  }
   122  
   123  func getTokenFieldString(v interface{}, value *string) error {
   124  	var ok bool
   125  	*value, ok = v.(string)
   126  	if !ok {
   127  		return fmt.Errorf("expect value to be string, got %T", v)
   128  	}
   129  	return nil
   130  }
   131  
   132  func getTokenFieldRFC3339(v interface{}, value **rfc3339) error {
   133  	var stringValue string
   134  	if err := getTokenFieldString(v, &stringValue); err != nil {
   135  		return err
   136  	}
   137  
   138  	timeValue, err := parseRFC3339(stringValue)
   139  	if err != nil {
   140  		return err
   141  	}
   142  
   143  	*value = &timeValue
   144  	return nil
   145  }
   146  
   147  func loadCachedToken(filename string) (token, error) {
   148  	fileBytes, err := ioutil.ReadFile(filename)
   149  	if err != nil {
   150  		return token{}, fmt.Errorf("failed to read cached SSO token file, %w", err)
   151  	}
   152  
   153  	var t token
   154  	if err := json.Unmarshal(fileBytes, &t); err != nil {
   155  		return token{}, fmt.Errorf("failed to parse cached SSO token file, %w", err)
   156  	}
   157  
   158  	if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() {
   159  		return token{}, fmt.Errorf(
   160  			"cached SSO token must contain accessToken and expiresAt fields")
   161  	}
   162  
   163  	return t, nil
   164  }
   165  
   166  func storeCachedToken(filename string, t token, fileMode os.FileMode) (err error) {
   167  	tmpFilename := filename + ".tmp-" + strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
   168  	if err := writeCacheFile(tmpFilename, fileMode, t); err != nil {
   169  		return err
   170  	}
   171  
   172  	if err := os.Rename(tmpFilename, filename); err != nil {
   173  		return fmt.Errorf("failed to replace old cached SSO token file, %w", err)
   174  	}
   175  
   176  	return nil
   177  }
   178  
   179  func writeCacheFile(filename string, fileMode os.FileMode, t token) (err error) {
   180  	var f *os.File
   181  	f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode)
   182  	if err != nil {
   183  		return fmt.Errorf("failed to create cached SSO token file %w", err)
   184  	}
   185  
   186  	defer func() {
   187  		closeErr := f.Close()
   188  		if err == nil && closeErr != nil {
   189  			err = fmt.Errorf("failed to close cached SSO token file, %w", closeErr)
   190  		}
   191  	}()
   192  
   193  	encoder := json.NewEncoder(f)
   194  
   195  	if err = encoder.Encode(t); err != nil {
   196  		return fmt.Errorf("failed to serialize cached SSO token, %w", err)
   197  	}
   198  
   199  	return nil
   200  }
   201  
   202  type rfc3339 time.Time
   203  
   204  func parseRFC3339(v string) (rfc3339, error) {
   205  	parsed, err := time.Parse(time.RFC3339, v)
   206  	if err != nil {
   207  		return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %w", err)
   208  	}
   209  
   210  	return rfc3339(parsed), nil
   211  }
   212  
   213  func (r *rfc3339) UnmarshalJSON(bytes []byte) (err error) {
   214  	var value string
   215  
   216  	// Use JSON unmarshal to unescape the quoted value making use of JSON's
   217  	// unquoting rules.
   218  	if err = json.Unmarshal(bytes, &value); err != nil {
   219  		return err
   220  	}
   221  
   222  	*r, err = parseRFC3339(value)
   223  
   224  	return nil
   225  }
   226  
   227  func (r *rfc3339) MarshalJSON() ([]byte, error) {
   228  	value := time.Time(*r).Format(time.RFC3339)
   229  
   230  	// Use JSON unmarshal to unescape the quoted value making use of JSON's
   231  	// quoting rules.
   232  	return json.Marshal(value)
   233  }
   234  

View as plain text