1
2
3 package sts
4
5 import (
6 "context"
7 "github.com/aws/aws-sdk-go-v2/aws"
8 "os"
9 "reflect"
10 "testing"
11 )
12
// mockConfigSource is a test double placed in aws.Config.ConfigSources. Each
// instance stands in for one configuration layer in the test below (one is
// built from the env* case fields, one from the config* case fields).
type mockConfigSource struct {
	// global and service hold the layer's global and service-specific
	// endpoint URLs. Only service is read by the methods in this file.
	global, service string

	// ignore is reported by GetIgnoreConfiguredEndpoints as both the value
	// and the "found" flag.
	ignore bool
}
18
19
20
21 func (m mockConfigSource) GetIgnoreConfiguredEndpoints(context.Context) (bool, bool, error) {
22 return m.ignore, m.ignore, nil
23 }
24
25
26
27 func (m mockConfigSource) GetServiceBaseEndpoint(ctx context.Context, sdkID string) (string, bool, error) {
28 if m.service != "" {
29 return m.service, true, nil
30 }
31 return "", false, nil
32 }
33
34 func TestResolveBaseEndpoint(t *testing.T) {
35 cases := map[string]struct {
36 envGlobal string
37 envService string
38 envIgnore bool
39 configGlobal string
40 configService string
41 configIgnore bool
42 clientEndpoint *string
43 expectURL *string
44 }{
45 "env ignore": {
46 envGlobal: "https://env-global.dev",
47 envService: "https://env-sts.dev",
48 envIgnore: true,
49 configGlobal: "http://config-global.dev",
50 configService: "http://config-sts.dev",
51 expectURL: nil,
52 },
53 "env global": {
54 envGlobal: "https://env-global.dev",
55 configGlobal: "http://config-global.dev",
56 configService: "http://config-sts.dev",
57 expectURL: aws.String("https://env-global.dev"),
58 },
59 "env service": {
60 envGlobal: "https://env-global.dev",
61 envService: "https://env-sts.dev",
62 configGlobal: "http://config-global.dev",
63 configService: "http://config-sts.dev",
64 expectURL: aws.String("https://env-sts.dev"),
65 },
66 "config ignore": {
67 envGlobal: "https://env-global.dev",
68 envService: "https://env-sts.dev",
69 configGlobal: "http://config-global.dev",
70 configService: "http://config-sts.dev",
71 configIgnore: true,
72 expectURL: nil,
73 },
74 "config global": {
75 configGlobal: "http://config-global.dev",
76 expectURL: aws.String("http://config-global.dev"),
77 },
78 "config service": {
79 configGlobal: "http://config-global.dev",
80 configService: "http://config-sts.dev",
81 expectURL: aws.String("http://config-sts.dev"),
82 },
83 "client": {
84 envGlobal: "https://env-global.dev",
85 envService: "https://env-sts.dev",
86 configGlobal: "http://config-global.dev",
87 configService: "http://config-sts.dev",
88 clientEndpoint: aws.String("https://client-sts.dev"),
89 expectURL: aws.String("https://client-sts.dev"),
90 },
91 }
92
93 for name, c := range cases {
94 t.Run(name, func(t *testing.T) {
95 os.Clearenv()
96
97 awsConfig := aws.Config{}
98 ignore := c.envIgnore || c.configIgnore
99
100 if c.configGlobal != "" && !ignore {
101 awsConfig.BaseEndpoint = aws.String(c.configGlobal)
102 }
103
104 if c.envGlobal != "" {
105 t.Setenv("AWS_ENDPOINT_URL", c.envGlobal)
106 if !ignore {
107 awsConfig.BaseEndpoint = aws.String(c.envGlobal)
108 }
109 }
110
111 if c.envService != "" {
112 t.Setenv("AWS_ENDPOINT_URL_STS", c.envService)
113 }
114
115 awsConfig.ConfigSources = []interface{}{
116 mockConfigSource{
117 global: c.envGlobal,
118 service: c.envService,
119 ignore: c.envIgnore,
120 },
121 mockConfigSource{
122 global: c.configGlobal,
123 service: c.configService,
124 ignore: c.configIgnore,
125 },
126 }
127
128 client := NewFromConfig(awsConfig, func(o *Options) {
129 if c.clientEndpoint != nil {
130 o.BaseEndpoint = c.clientEndpoint
131 }
132 })
133
134 if e, a := c.expectURL, client.options.BaseEndpoint; !reflect.DeepEqual(e, a) {
135 t.Errorf("expect endpoint %v , got %v", e, a)
136 }
137 })
138 }
139 }
140
View as plain text