     1  package storage
     3  // Copyright (c) Microsoft Corporation. All rights reserved.
     4  // Licensed under the MIT License. See License.txt in the project root for license information.
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"net/http"
    13  	"net/url"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  )
    19  const (
    20  	tablesURIPath                  = "/Tables"
    21  	nextTableQueryParameter        = "NextTableName"
    22  	headerNextPartitionKey         = "x-ms-continuation-NextPartitionKey"
    23  	headerNextRowKey               = "x-ms-continuation-NextRowKey"
    24  	nextPartitionKeyQueryParameter = "NextPartitionKey"
    25  	nextRowKeyQueryParameter       = "NextRowKey"
    26  )
    28  // TableAccessPolicy are used for SETTING table policies
    29  type TableAccessPolicy struct {
    30  	ID         string
    31  	StartTime  time.Time
    32  	ExpiryTime time.Time
    33  	CanRead    bool
    34  	CanAppend  bool
    35  	CanUpdate  bool
    36  	CanDelete  bool
    37  }
// Table represents an Azure table.
//
// Name is the table name. It is serialized as "TableName" and is used to build
// request paths. Get and Create unmarshal the service's JSON response directly
// into the Table. The Odata* fields receive the matching "odata.*" properties
// of that response. Which of them are populated presumably depends on the
// MetadataLevel requested (TODO confirm).
type Table struct {
	// tsc is the client through which every request for this table is sent.
	tsc           *TableServiceClient
	Name          string `json:"TableName"`
	OdataEditLink string `json:"odata.editLink"`
	OdataID       string `json:"odata.id"`
	OdataMetadata string `json:"odata.metadata"`
	OdataType     string `json:"odata.type"`
}
// EntityQueryResult contains one page of results returned by
// Table.QueryEntities or EntityQueryResult.NextResults.
//
// Entities is decoded from the response's "value" array. Each entity is linked
// back to the table that was queried. The embedded QueryNextLink has a non-nil
// NextLink when the service signalled that more results are available.
type EntityQueryResult struct {
	OdataMetadata string    `json:"odata.metadata"`
	Entities      []*Entity `json:"value"`
	QueryNextLink
	// table is the table that produced this page. NextResults uses it to fetch
	// the following page.
	table *Table
}
    58  type continuationToken struct {
    59  	NextPartitionKey string
    60  	NextRowKey       string
    61  }
    63  func (t *Table) buildPath() string {
    64  	return fmt.Sprintf("/%s", t.Name)
    65  }
    67  func (t *Table) buildSpecificPath() string {
    68  	return fmt.Sprintf("%s('%s')", tablesURIPath, t.Name)
    69  }
    71  // Get gets the referenced table.
    72  // See: https://docs.microsoft.com/en-us/rest/api/storageservices/fileservices/querying-tables-and-entities
    73  func (t *Table) Get(timeout uint, ml MetadataLevel) error {
    74  	if ml == EmptyPayload {
    75  		return errEmptyPayload
    76  	}
    78  	query := url.Values{
    79  		"timeout": {strconv.FormatUint(uint64(timeout), 10)},
    80  	}
    81  	headers := t.tsc.client.getStandardHeaders()
    82  	headers[headerAccept] = string(ml)
    84  	uri := t.tsc.client.getEndpoint(tableServiceName, t.buildSpecificPath(), query)
    85  	resp, err := t.tsc.client.exec(http.MethodGet, uri, headers, nil, t.tsc.auth)
    86  	if err != nil {
    87  		return err
    88  	}
    89  	defer resp.Body.Close()
    91  	if err = checkRespCode(resp, []int{http.StatusOK}); err != nil {
    92  		return err
    93  	}
    95  	respBody, err := ioutil.ReadAll(resp.Body)
    96  	if err != nil {
    97  		return err
    98  	}
    99  	err = json.Unmarshal(respBody, t)
   100  	if err != nil {
   101  		return err
   102  	}
   103  	return nil
   104  }
   106  // Create creates the referenced table.
   107  // This function fails if the name is not compliant
   108  // with the specification or the tables already exists.
   109  // ml determines the level of detail of metadata in the operation response,
   110  // or no data at all.
   111  // See https://docs.microsoft.com/rest/api/storageservices/fileservices/create-table
   112  func (t *Table) Create(timeout uint, ml MetadataLevel, options *TableOptions) error {
   113  	uri := t.tsc.client.getEndpoint(tableServiceName, tablesURIPath, url.Values{
   114  		"timeout": {strconv.FormatUint(uint64(timeout), 10)},
   115  	})
   117  	type createTableRequest struct {
   118  		TableName string `json:"TableName"`
   119  	}
   120  	req := createTableRequest{TableName: t.Name}
   121  	buf := new(bytes.Buffer)
   122  	if err := json.NewEncoder(buf).Encode(req); err != nil {
   123  		return err
   124  	}
   126  	headers := t.tsc.client.getStandardHeaders()
   127  	headers = addReturnContentHeaders(headers, ml)
   128  	headers = addBodyRelatedHeaders(headers, buf.Len())
   129  	headers = options.addToHeaders(headers)
   131  	resp, err := t.tsc.client.exec(http.MethodPost, uri, headers, buf, t.tsc.auth)
   132  	if err != nil {
   133  		return err
   134  	}
   135  	defer resp.Body.Close()
   137  	if ml == EmptyPayload {
   138  		if err := checkRespCode(resp, []int{http.StatusNoContent}); err != nil {
   139  			return err
   140  		}
   141  	} else {
   142  		if err := checkRespCode(resp, []int{http.StatusCreated}); err != nil {
   143  			return err
   144  		}
   145  	}
   147  	if ml != EmptyPayload {
   148  		data, err := ioutil.ReadAll(resp.Body)
   149  		if err != nil {
   150  			return err
   151  		}
   152  		err = json.Unmarshal(data, t)
   153  		if err != nil {
   154  			return err
   155  		}
   156  	}
   158  	return nil
   159  }
   161  // Delete deletes the referenced table.
   162  // This function fails if the table is not present.
   163  // Be advised: Delete deletes all the entries that may be present.
   164  // See https://docs.microsoft.com/rest/api/storageservices/fileservices/delete-table
   165  func (t *Table) Delete(timeout uint, options *TableOptions) error {
   166  	uri := t.tsc.client.getEndpoint(tableServiceName, t.buildSpecificPath(), url.Values{
   167  		"timeout": {strconv.Itoa(int(timeout))},
   168  	})
   170  	headers := t.tsc.client.getStandardHeaders()
   171  	headers = addReturnContentHeaders(headers, EmptyPayload)
   172  	headers = options.addToHeaders(headers)
   174  	resp, err := t.tsc.client.exec(http.MethodDelete, uri, headers, nil, t.tsc.auth)
   175  	if err != nil {
   176  		return err
   177  	}
   178  	defer drainRespBody(resp)
   180  	return checkRespCode(resp, []int{http.StatusNoContent})
   181  }
// QueryOptions includes options for a query entities operation.
// Top, Filter and Select are OData query options. getParameters turns them
// into query-string parameters:
//   - Top is sent only when greater than zero.
//   - Filter is sent only when non-empty.
//   - Select is sent only when non-empty, as a comma-joined list.
//
// RequestID is passed to the x-ms-client-request-id header through
// addToHeaders. Presumably it is omitted when empty (TODO confirm against
// addToHeaders).
type QueryOptions struct {
	Top       uint
	Filter    string
	Select    []string
	RequestID string
}
   192  func (options *QueryOptions) getParameters() (url.Values, map[string]string) {
   193  	query := url.Values{}
   194  	headers := map[string]string{}
   195  	if options != nil {
   196  		if options.Top > 0 {
   197  			query.Add(OdataTop, strconv.FormatUint(uint64(options.Top), 10))
   198  		}
   199  		if options.Filter != "" {
   200  			query.Add(OdataFilter, options.Filter)
   201  		}
   202  		if len(options.Select) > 0 {
   203  			query.Add(OdataSelect, strings.Join(options.Select, ","))
   204  		}
   205  		headers = addToHeaders(headers, "x-ms-client-request-id", options.RequestID)
   206  	}
   207  	return query, headers
   208  }
   210  // QueryEntities returns the entities in the table.
   211  // You can use query options defined by the OData Protocol specification.
   212  //
   213  // See: https://docs.microsoft.com/en-us/rest/api/storageservices/fileservices/query-entities
   214  func (t *Table) QueryEntities(timeout uint, ml MetadataLevel, options *QueryOptions) (*EntityQueryResult, error) {
   215  	if ml == EmptyPayload {
   216  		return nil, errEmptyPayload
   217  	}
   218  	query, headers := options.getParameters()
   219  	query = addTimeout(query, timeout)
   220  	uri := t.tsc.client.getEndpoint(tableServiceName, t.buildPath(), query)
   221  	return t.queryEntities(uri, headers, ml)
   222  }
   224  // NextResults returns the next page of results
   225  // from a QueryEntities or NextResults operation.
   226  //
   227  // See: https://docs.microsoft.com/en-us/rest/api/storageservices/fileservices/query-entities
   228  // See https://docs.microsoft.com/rest/api/storageservices/fileservices/query-timeout-and-pagination
   229  func (eqr *EntityQueryResult) NextResults(options *TableOptions) (*EntityQueryResult, error) {
   230  	if eqr == nil {
   231  		return nil, errNilPreviousResult
   232  	}
   233  	if eqr.NextLink == nil {
   234  		return nil, errNilNextLink
   235  	}
   236  	headers := options.addToHeaders(map[string]string{})
   237  	return eqr.table.queryEntities(*eqr.NextLink, headers, eqr.ml)
   238  }
   240  // SetPermissions sets up table ACL permissions
   241  // See https://docs.microsoft.com/rest/api/storageservices/fileservices/Set-Table-ACL
   242  func (t *Table) SetPermissions(tap []TableAccessPolicy, timeout uint, options *TableOptions) error {
   243  	params := url.Values{"comp": {"acl"},
   244  		"timeout": {strconv.Itoa(int(timeout))},
   245  	}
   247  	uri := t.tsc.client.getEndpoint(tableServiceName, t.Name, params)
   248  	headers := t.tsc.client.getStandardHeaders()
   249  	headers = options.addToHeaders(headers)
   251  	body, length, err := generateTableACLPayload(tap)
   252  	if err != nil {
   253  		return err
   254  	}
   255  	headers["Content-Length"] = strconv.Itoa(length)
   257  	resp, err := t.tsc.client.exec(http.MethodPut, uri, headers, body, t.tsc.auth)
   258  	if err != nil {
   259  		return err
   260  	}
   261  	defer drainRespBody(resp)
   263  	return checkRespCode(resp, []int{http.StatusNoContent})
   264  }
   266  func generateTableACLPayload(policies []TableAccessPolicy) (io.Reader, int, error) {
   267  	sil := SignedIdentifiers{
   268  		SignedIdentifiers: []SignedIdentifier{},
   269  	}
   270  	for _, tap := range policies {
   271  		permission := generateTablePermissions(&tap)
   272  		signedIdentifier := convertAccessPolicyToXMLStructs(tap.ID, tap.StartTime, tap.ExpiryTime, permission)
   273  		sil.SignedIdentifiers = append(sil.SignedIdentifiers, signedIdentifier)
   274  	}
   275  	return xmlMarshal(sil)
   276  }
   278  // GetPermissions gets the table ACL permissions
   279  // See https://docs.microsoft.com/rest/api/storageservices/fileservices/get-table-acl
   280  func (t *Table) GetPermissions(timeout int, options *TableOptions) ([]TableAccessPolicy, error) {
   281  	params := url.Values{"comp": {"acl"},
   282  		"timeout": {strconv.Itoa(int(timeout))},
   283  	}
   285  	uri := t.tsc.client.getEndpoint(tableServiceName, t.Name, params)
   286  	headers := t.tsc.client.getStandardHeaders()
   287  	headers = options.addToHeaders(headers)
   289  	resp, err := t.tsc.client.exec(http.MethodGet, uri, headers, nil, t.tsc.auth)
   290  	if err != nil {
   291  		return nil, err
   292  	}
   293  	defer resp.Body.Close()
   295  	if err = checkRespCode(resp, []int{http.StatusOK}); err != nil {
   296  		return nil, err
   297  	}
   299  	var ap AccessPolicy
   300  	err = xmlUnmarshal(resp.Body, &ap.SignedIdentifiersList)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  	return updateTableAccessPolicy(ap), nil
   305  }
   307  func (t *Table) queryEntities(uri string, headers map[string]string, ml MetadataLevel) (*EntityQueryResult, error) {
   308  	headers = mergeHeaders(headers, t.tsc.client.getStandardHeaders())
   309  	if ml != EmptyPayload {
   310  		headers[headerAccept] = string(ml)
   311  	}
   313  	resp, err := t.tsc.client.exec(http.MethodGet, uri, headers, nil, t.tsc.auth)
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  	defer resp.Body.Close()
   319  	if err = checkRespCode(resp, []int{http.StatusOK}); err != nil {
   320  		return nil, err
   321  	}
   323  	data, err := ioutil.ReadAll(resp.Body)
   324  	if err != nil {
   325  		return nil, err
   326  	}
   327  	var entities EntityQueryResult
   328  	err = json.Unmarshal(data, &entities)
   329  	if err != nil {
   330  		return nil, err
   331  	}
   333  	for i := range entities.Entities {
   334  		entities.Entities[i].Table = t
   335  	}
   336  	entities.table = t
   338  	contToken := extractContinuationTokenFromHeaders(resp.Header)
   339  	if contToken == nil {
   340  		entities.NextLink = nil
   341  	} else {
   342  		originalURI, err := url.Parse(uri)
   343  		if err != nil {
   344  			return nil, err
   345  		}
   346  		v := originalURI.Query()
   347  		if contToken.NextPartitionKey != "" {
   348  			v.Set(nextPartitionKeyQueryParameter, contToken.NextPartitionKey)
   349  		}
   350  		if contToken.NextRowKey != "" {
   351  			v.Set(nextRowKeyQueryParameter, contToken.NextRowKey)
   352  		}
   353  		newURI := t.tsc.client.getEndpoint(tableServiceName, t.buildPath(), v)
   354  		entities.NextLink = &newURI
   355  		entities.ml = ml
   356  	}
   358  	return &entities, nil
   359  }
   361  func extractContinuationTokenFromHeaders(h http.Header) *continuationToken {
   362  	ct := continuationToken{
   363  		NextPartitionKey: h.Get(headerNextPartitionKey),
   364  		NextRowKey:       h.Get(headerNextRowKey),
   365  	}
   367  	if ct.NextPartitionKey != "" || ct.NextRowKey != "" {
   368  		return &ct
   369  	}
   370  	return nil
   371  }
   373  func updateTableAccessPolicy(ap AccessPolicy) []TableAccessPolicy {
   374  	taps := []TableAccessPolicy{}
   375  	for _, policy := range ap.SignedIdentifiersList.SignedIdentifiers {
   376  		tap := TableAccessPolicy{
   377  			ID:         policy.ID,
   378  			StartTime:  policy.AccessPolicy.StartTime,
   379  			ExpiryTime: policy.AccessPolicy.ExpiryTime,
   380  		}
   381  		tap.CanRead = updatePermissions(policy.AccessPolicy.Permission, "r")
   382  		tap.CanAppend = updatePermissions(policy.AccessPolicy.Permission, "a")
   383  		tap.CanUpdate = updatePermissions(policy.AccessPolicy.Permission, "u")
   384  		tap.CanDelete = updatePermissions(policy.AccessPolicy.Permission, "d")
   386  		taps = append(taps, tap)
   387  	}
   388  	return taps
   389  }
   391  func generateTablePermissions(tap *TableAccessPolicy) (permissions string) {
   392  	// generate the permissions string (raud).
   393  	// still want the end user API to have bool flags.
   394  	permissions = ""
   396  	if tap.CanRead {
   397  		permissions += "r"
   398  	}
   400  	if tap.CanAppend {
   401  		permissions += "a"
   402  	}
   404  	if tap.CanUpdate {
   405  		permissions += "u"
   406  	}
   408  	if tap.CanDelete {
   409  		permissions += "d"
   410  	}
   411  	return permissions
   412  }

