1 package main
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 import (
18 "crypto/rsa"
19 "crypto/x509"
20 "encoding/json"
21 "flag"
22 "fmt"
23 "io/ioutil"
24 "log"
25 "net/http"
26 "strings"
27
28 "github.com/Azure/go-autorest/autorest"
29 "github.com/Azure/go-autorest/autorest/adal"
30 "github.com/Azure/go-autorest/autorest/azure"
31 "golang.org/x/crypto/pkcs12"
32 )
33
// Azure Resource Manager request settings.
const (
	// resourceGroupURLTemplate is the Azure Resource Manager base URL; the
	// resource-group path is appended to it when building the request (it is
	// not itself a template).
	resourceGroupURLTemplate = "https://management.azure.com"
	// apiVersion is sent as the "api-version" query parameter.
	apiVersion = "2015-01-01"
)

// OAuth settings.
const (
	// nativeAppClientID is the client ID used for device-mode authentication
	// when no -applicationId flag is given (logged as the `azkube` client ID).
	nativeAppClientID = "a87032a7-203c-4bf7-913c-44c50d23409a"
	// resource is the resource passed to every token request.
	resource = "https://management.core.windows.net/"
)
40
// Command-line configuration; the flag-backed values are filled in by init.
var (
	// Authentication mode: "device" or "certificate".
	mode string
	// certificatePath is the PKCS#12 file used in certificate mode.
	certificatePath string

	// Azure identifiers.
	tenantID       string
	subscriptionID string
	applicationID  string

	// Token handling.
	tokenCachePath string
	forceRefresh   bool
	// NOTE(review): impatient has no flag and is never read in this file.
	impatient bool
)
53
54 func init() {
55 flag.StringVar(&mode, "mode", "device", "mode of operation for SPT creation")
56 flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/pfx certificate")
57 flag.StringVar(&applicationID, "applicationId", "", "application id")
58 flag.StringVar(&tenantID, "tenantId", "", "tenant id")
59 flag.StringVar(&subscriptionID, "subscriptionId", "", "subscription id")
60 flag.StringVar(&tokenCachePath, "tokenCachePath", "", "location of oauth token cache")
61 flag.BoolVar(&forceRefresh, "forceRefresh", false, "pass true to force a token refresh")
62
63 flag.Parse()
64
65 log.Printf("mode(%s) certPath(%s) appID(%s) tenantID(%s), subID(%s)\n",
66 mode, certificatePath, applicationID, tenantID, subscriptionID)
67
68 if mode == "certificate" &&
69 (strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "") {
70 log.Fatalln("Bad usage. Using certificate mode. Please specify tenantID, subscriptionID")
71 }
72
73 if mode != "certificate" && mode != "device" {
74 log.Fatalln("Bad usage. Mode must be one of 'certificate' or 'device'.")
75 }
76
77 if mode == "device" && strings.TrimSpace(applicationID) == "" {
78 log.Println("Using device mode auth. Will use `azkube` clientID since none was specified on the comand line.")
79 applicationID = nativeAppClientID
80 }
81
82 if mode == "certificate" && strings.TrimSpace(certificatePath) == "" {
83 log.Fatalln("Bad usage. Mode 'certificate' requires the 'certificatePath' argument.")
84 }
85
86 if strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "" || strings.TrimSpace(applicationID) == "" {
87 log.Fatalln("Bad usage. Must specify the 'tenantId' and 'subscriptionId'")
88 }
89 }
90
91 func getSptFromCachedToken(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
92 token, err := adal.LoadToken(tokenCachePath)
93 if err != nil {
94 return nil, fmt.Errorf("failed to load token from cache: %v", err)
95 }
96
97 spt, _ := adal.NewServicePrincipalTokenFromManualToken(
98 oauthConfig,
99 clientID,
100 resource,
101 *token,
102 callbacks...)
103
104 return spt, nil
105 }
106
107 func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
108 privateKey, certificate, err := pkcs12.Decode(pkcs, password)
109 if err != nil {
110 return nil, nil, err
111 }
112
113 rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
114 if !isRsaKey {
115 return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
116 }
117
118 return certificate, rsaPrivateKey, nil
119 }
120
121 func getSptFromCertificate(oauthConfig adal.OAuthConfig, clientID, resource, certicatePath string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
122 certData, err := ioutil.ReadFile(certificatePath)
123 if err != nil {
124 return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
125 }
126
127 certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
128 if err != nil {
129 return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
130 }
131
132 spt, _ := adal.NewServicePrincipalTokenFromCertificate(
133 oauthConfig,
134 clientID,
135 certificate,
136 rsaPrivateKey,
137 resource,
138 callbacks...)
139
140 return spt, nil
141 }
142
143 func getSptFromDeviceFlow(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
144 oauthClient := &autorest.Client{}
145 deviceCode, err := adal.InitiateDeviceAuth(oauthClient, oauthConfig, clientID, resource)
146 if err != nil {
147 return nil, fmt.Errorf("failed to start device auth flow: %s", err)
148 }
149
150 fmt.Println(*deviceCode.Message)
151
152 token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
153 if err != nil {
154 return nil, fmt.Errorf("failed to finish device auth flow: %s", err)
155 }
156
157 spt, err := adal.NewServicePrincipalTokenFromManualToken(
158 oauthConfig,
159 clientID,
160 resource,
161 *token,
162 callbacks...)
163 if err != nil {
164 return nil, fmt.Errorf("failed to get oauth token from device flow: %v", err)
165 }
166
167 return spt, nil
168 }
169
170 func printResourceGroups(client *autorest.Client) error {
171 p := map[string]interface{}{"subscription-id": subscriptionID}
172 q := map[string]interface{}{"api-version": apiVersion}
173
174 req, _ := autorest.Prepare(&http.Request{},
175 autorest.AsGet(),
176 autorest.WithBaseURL(resourceGroupURLTemplate),
177 autorest.WithPathParameters("/subscriptions/{subscription-id}/resourcegroups", p),
178 autorest.WithQueryParameters(q))
179
180 resp, err := autorest.SendWithSender(client, req)
181 if err != nil {
182 return err
183 }
184
185 value := struct {
186 ResourceGroups []struct {
187 Name string `json:"name"`
188 } `json:"value"`
189 }{}
190
191 defer resp.Body.Close()
192 dec := json.NewDecoder(resp.Body)
193 err = dec.Decode(&value)
194 if err != nil {
195 return err
196 }
197
198 var groupNames = make([]string, len(value.ResourceGroups))
199 for i, name := range value.ResourceGroups {
200 groupNames[i] = name.Name
201 }
202
203 log.Println("Groups:", strings.Join(groupNames, ", "))
204 return err
205 }
206
207 func saveToken(spt adal.Token) {
208 if tokenCachePath != "" {
209 err := adal.SaveToken(tokenCachePath, 0600, spt)
210 if err != nil {
211 log.Println("error saving token", err)
212 } else {
213 log.Println("saved token to", tokenCachePath)
214 }
215 }
216 }
217
218 func main() {
219 var spt *adal.ServicePrincipalToken
220 var err error
221
222 callback := func(t adal.Token) error {
223 log.Println("refresh callback was called")
224 saveToken(t)
225 return nil
226 }
227
228 oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, tenantID)
229 if err != nil {
230 panic(err)
231 }
232
233 if tokenCachePath != "" {
234 log.Println("tokenCachePath specified; attempting to load from", tokenCachePath)
235 spt, err = getSptFromCachedToken(*oauthConfig, applicationID, resource, callback)
236 if err != nil {
237 spt = nil
238 log.Println("loading from cache failed:", err)
239 }
240 }
241
242 if spt == nil {
243 log.Println("authenticating via 'mode'", mode)
244 switch mode {
245 case "device":
246 spt, err = getSptFromDeviceFlow(*oauthConfig, applicationID, resource, callback)
247 case "certificate":
248 spt, err = getSptFromCertificate(*oauthConfig, applicationID, resource, certificatePath, callback)
249 }
250 if err != nil {
251 log.Fatalln("failed to retrieve token:", err)
252 }
253
254
255 if tokenCachePath != "" {
256 saveToken(spt.Token())
257 }
258 }
259
260 client := &autorest.Client{}
261 client.Authorizer = autorest.NewBearerAuthorizer(spt)
262
263 printResourceGroups(client)
264
265 if forceRefresh {
266 err = spt.Refresh()
267 if err != nil {
268 panic(err)
269 }
270 printResourceGroups(client)
271 }
272 }
273
View as plain text