...

Source file src/github.com/Azure/go-autorest/autorest/azure/example/main.go

Documentation: github.com/Azure/go-autorest/autorest/azure/example

     1  package main
     2  
     3  // Copyright 2017 Microsoft Corporation
     4  //
     5  //  Licensed under the Apache License, Version 2.0 (the "License");
     6  //  you may not use this file except in compliance with the License.
     7  //  You may obtain a copy of the License at
     8  //
     9  //      http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  //  Unless required by applicable law or agreed to in writing, software
    12  //  distributed under the License is distributed on an "AS IS" BASIS,
    13  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  //  See the License for the specific language governing permissions and
    15  //  limitations under the License.
    16  
    17  import (
    18  	"crypto/rsa"
    19  	"crypto/x509"
    20  	"encoding/json"
    21  	"flag"
    22  	"fmt"
    23  	"io/ioutil"
    24  	"log"
    25  	"net/http"
    26  	"strings"
    27  
    28  	"github.com/Azure/go-autorest/autorest"
    29  	"github.com/Azure/go-autorest/autorest/adal"
    30  	"github.com/Azure/go-autorest/autorest/azure"
    31  	"golang.org/x/crypto/pkcs12"
    32  )
    33  
    34  const (
    35  	resourceGroupURLTemplate = "https://management.azure.com"
    36  	apiVersion               = "2015-01-01"
    37  	nativeAppClientID        = "a87032a7-203c-4bf7-913c-44c50d23409a"
    38  	resource                 = "https://management.core.windows.net/"
    39  )
    40  
    41  var (
    42  	mode           string
    43  	tenantID       string
    44  	subscriptionID string
    45  	applicationID  string
    46  
    47  	tokenCachePath string
    48  	forceRefresh   bool
    49  	impatient      bool
    50  
    51  	certificatePath string
    52  )
    53  
    54  func init() {
    55  	flag.StringVar(&mode, "mode", "device", "mode of operation for SPT creation")
    56  	flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/pfx certificate")
    57  	flag.StringVar(&applicationID, "applicationId", "", "application id")
    58  	flag.StringVar(&tenantID, "tenantId", "", "tenant id")
    59  	flag.StringVar(&subscriptionID, "subscriptionId", "", "subscription id")
    60  	flag.StringVar(&tokenCachePath, "tokenCachePath", "", "location of oauth token cache")
    61  	flag.BoolVar(&forceRefresh, "forceRefresh", false, "pass true to force a token refresh")
    62  
    63  	flag.Parse()
    64  
    65  	log.Printf("mode(%s) certPath(%s) appID(%s) tenantID(%s), subID(%s)\n",
    66  		mode, certificatePath, applicationID, tenantID, subscriptionID)
    67  
    68  	if mode == "certificate" &&
    69  		(strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "") {
    70  		log.Fatalln("Bad usage. Using certificate mode. Please specify tenantID, subscriptionID")
    71  	}
    72  
    73  	if mode != "certificate" && mode != "device" {
    74  		log.Fatalln("Bad usage. Mode must be one of 'certificate' or 'device'.")
    75  	}
    76  
    77  	if mode == "device" && strings.TrimSpace(applicationID) == "" {
    78  		log.Println("Using device mode auth. Will use `azkube` clientID since none was specified on the comand line.")
    79  		applicationID = nativeAppClientID
    80  	}
    81  
    82  	if mode == "certificate" && strings.TrimSpace(certificatePath) == "" {
    83  		log.Fatalln("Bad usage. Mode 'certificate' requires the 'certificatePath' argument.")
    84  	}
    85  
    86  	if strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "" || strings.TrimSpace(applicationID) == "" {
    87  		log.Fatalln("Bad usage. Must specify the 'tenantId' and 'subscriptionId'")
    88  	}
    89  }
    90  
    91  func getSptFromCachedToken(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
    92  	token, err := adal.LoadToken(tokenCachePath)
    93  	if err != nil {
    94  		return nil, fmt.Errorf("failed to load token from cache: %v", err)
    95  	}
    96  
    97  	spt, _ := adal.NewServicePrincipalTokenFromManualToken(
    98  		oauthConfig,
    99  		clientID,
   100  		resource,
   101  		*token,
   102  		callbacks...)
   103  
   104  	return spt, nil
   105  }
   106  
   107  func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
   108  	privateKey, certificate, err := pkcs12.Decode(pkcs, password)
   109  	if err != nil {
   110  		return nil, nil, err
   111  	}
   112  
   113  	rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
   114  	if !isRsaKey {
   115  		return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
   116  	}
   117  
   118  	return certificate, rsaPrivateKey, nil
   119  }
   120  
   121  func getSptFromCertificate(oauthConfig adal.OAuthConfig, clientID, resource, certicatePath string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
   122  	certData, err := ioutil.ReadFile(certificatePath)
   123  	if err != nil {
   124  		return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
   125  	}
   126  
   127  	certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
   128  	if err != nil {
   129  		return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
   130  	}
   131  
   132  	spt, _ := adal.NewServicePrincipalTokenFromCertificate(
   133  		oauthConfig,
   134  		clientID,
   135  		certificate,
   136  		rsaPrivateKey,
   137  		resource,
   138  		callbacks...)
   139  
   140  	return spt, nil
   141  }
   142  
   143  func getSptFromDeviceFlow(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
   144  	oauthClient := &autorest.Client{}
   145  	deviceCode, err := adal.InitiateDeviceAuth(oauthClient, oauthConfig, clientID, resource)
   146  	if err != nil {
   147  		return nil, fmt.Errorf("failed to start device auth flow: %s", err)
   148  	}
   149  
   150  	fmt.Println(*deviceCode.Message)
   151  
   152  	token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
   153  	if err != nil {
   154  		return nil, fmt.Errorf("failed to finish device auth flow: %s", err)
   155  	}
   156  
   157  	spt, err := adal.NewServicePrincipalTokenFromManualToken(
   158  		oauthConfig,
   159  		clientID,
   160  		resource,
   161  		*token,
   162  		callbacks...)
   163  	if err != nil {
   164  		return nil, fmt.Errorf("failed to get oauth token from device flow: %v", err)
   165  	}
   166  
   167  	return spt, nil
   168  }
   169  
   170  func printResourceGroups(client *autorest.Client) error {
   171  	p := map[string]interface{}{"subscription-id": subscriptionID}
   172  	q := map[string]interface{}{"api-version": apiVersion}
   173  
   174  	req, _ := autorest.Prepare(&http.Request{},
   175  		autorest.AsGet(),
   176  		autorest.WithBaseURL(resourceGroupURLTemplate),
   177  		autorest.WithPathParameters("/subscriptions/{subscription-id}/resourcegroups", p),
   178  		autorest.WithQueryParameters(q))
   179  
   180  	resp, err := autorest.SendWithSender(client, req)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	value := struct {
   186  		ResourceGroups []struct {
   187  			Name string `json:"name"`
   188  		} `json:"value"`
   189  	}{}
   190  
   191  	defer resp.Body.Close()
   192  	dec := json.NewDecoder(resp.Body)
   193  	err = dec.Decode(&value)
   194  	if err != nil {
   195  		return err
   196  	}
   197  
   198  	var groupNames = make([]string, len(value.ResourceGroups))
   199  	for i, name := range value.ResourceGroups {
   200  		groupNames[i] = name.Name
   201  	}
   202  
   203  	log.Println("Groups:", strings.Join(groupNames, ", "))
   204  	return err
   205  }
   206  
   207  func saveToken(spt adal.Token) {
   208  	if tokenCachePath != "" {
   209  		err := adal.SaveToken(tokenCachePath, 0600, spt)
   210  		if err != nil {
   211  			log.Println("error saving token", err)
   212  		} else {
   213  			log.Println("saved token to", tokenCachePath)
   214  		}
   215  	}
   216  }
   217  
   218  func main() {
   219  	var spt *adal.ServicePrincipalToken
   220  	var err error
   221  
   222  	callback := func(t adal.Token) error {
   223  		log.Println("refresh callback was called")
   224  		saveToken(t)
   225  		return nil
   226  	}
   227  
   228  	oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, tenantID)
   229  	if err != nil {
   230  		panic(err)
   231  	}
   232  
   233  	if tokenCachePath != "" {
   234  		log.Println("tokenCachePath specified; attempting to load from", tokenCachePath)
   235  		spt, err = getSptFromCachedToken(*oauthConfig, applicationID, resource, callback)
   236  		if err != nil {
   237  			spt = nil // just in case, this is the condition below
   238  			log.Println("loading from cache failed:", err)
   239  		}
   240  	}
   241  
   242  	if spt == nil {
   243  		log.Println("authenticating via 'mode'", mode)
   244  		switch mode {
   245  		case "device":
   246  			spt, err = getSptFromDeviceFlow(*oauthConfig, applicationID, resource, callback)
   247  		case "certificate":
   248  			spt, err = getSptFromCertificate(*oauthConfig, applicationID, resource, certificatePath, callback)
   249  		}
   250  		if err != nil {
   251  			log.Fatalln("failed to retrieve token:", err)
   252  		}
   253  
   254  		// should save it as soon as you get it since Refresh won't be called for some time
   255  		if tokenCachePath != "" {
   256  			saveToken(spt.Token())
   257  		}
   258  	}
   259  
   260  	client := &autorest.Client{}
   261  	client.Authorizer = autorest.NewBearerAuthorizer(spt)
   262  
   263  	printResourceGroups(client)
   264  
   265  	if forceRefresh {
   266  		err = spt.Refresh()
   267  		if err != nil {
   268  			panic(err)
   269  		}
   270  		printResourceGroups(client)
   271  	}
   272  }
   273  

View as plain text