...

Source file src/github.com/go-resty/resty/v2/middleware.go

Documentation: github.com/go-resty/resty/v2

     1  // Copyright (c) 2015-2021 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
     2  // resty source code and usage is governed by a MIT style
     3  // license that can be found in the LICENSE file.
     4  
     5  package resty
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"mime/multipart"
    14  	"net/http"
    15  	"net/url"
    16  	"os"
    17  	"path/filepath"
    18  	"reflect"
    19  	"strings"
    20  	"time"
    21  )
    22  
    23  const debugRequestLogKey = "__restyDebugRequestLog"
    24  
    25  //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
    26  // Request Middleware(s)
    27  //_______________________________________________________________________
    28  
    29  func parseRequestURL(c *Client, r *Request) error {
    30  	// GitHub #103 Path Params
    31  	if len(r.PathParams) > 0 {
    32  		for p, v := range r.PathParams {
    33  			r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1)
    34  		}
    35  	}
    36  	if len(c.PathParams) > 0 {
    37  		for p, v := range c.PathParams {
    38  			r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1)
    39  		}
    40  	}
    41  
    42  	// Parsing request URL
    43  	reqURL, err := url.Parse(r.URL)
    44  	if err != nil {
    45  		return err
    46  	}
    47  
    48  	// If Request.URL is relative path then added c.HostURL into
    49  	// the request URL otherwise Request.URL will be used as-is
    50  	if !reqURL.IsAbs() {
    51  		r.URL = reqURL.String()
    52  		if len(r.URL) > 0 && r.URL[0] != '/' {
    53  			r.URL = "/" + r.URL
    54  		}
    55  
    56  		reqURL, err = url.Parse(c.HostURL + r.URL)
    57  		if err != nil {
    58  			return err
    59  		}
    60  	}
    61  
    62  	// GH #407 && #318
    63  	if reqURL.Scheme == "" && len(c.scheme) > 0 {
    64  		reqURL.Scheme = c.scheme
    65  	}
    66  
    67  	// Adding Query Param
    68  	query := make(url.Values)
    69  	for k, v := range c.QueryParam {
    70  		for _, iv := range v {
    71  			query.Add(k, iv)
    72  		}
    73  	}
    74  
    75  	for k, v := range r.QueryParam {
    76  		// remove query param from client level by key
    77  		// since overrides happens for that key in the request
    78  		query.Del(k)
    79  
    80  		for _, iv := range v {
    81  			query.Add(k, iv)
    82  		}
    83  	}
    84  
    85  	// GitHub #123 Preserve query string order partially.
    86  	// Since not feasible in `SetQuery*` resty methods, because
    87  	// standard package `url.Encode(...)` sorts the query params
    88  	// alphabetically
    89  	if len(query) > 0 {
    90  		if IsStringEmpty(reqURL.RawQuery) {
    91  			reqURL.RawQuery = query.Encode()
    92  		} else {
    93  			reqURL.RawQuery = reqURL.RawQuery + "&" + query.Encode()
    94  		}
    95  	}
    96  
    97  	r.URL = reqURL.String()
    98  
    99  	return nil
   100  }
   101  
   102  func parseRequestHeader(c *Client, r *Request) error {
   103  	hdr := make(http.Header)
   104  	for k := range c.Header {
   105  		hdr[k] = append(hdr[k], c.Header[k]...)
   106  	}
   107  
   108  	for k := range r.Header {
   109  		hdr.Del(k)
   110  		hdr[k] = append(hdr[k], r.Header[k]...)
   111  	}
   112  
   113  	if IsStringEmpty(hdr.Get(hdrUserAgentKey)) {
   114  		hdr.Set(hdrUserAgentKey, hdrUserAgentValue)
   115  	}
   116  
   117  	ct := hdr.Get(hdrContentTypeKey)
   118  	if IsStringEmpty(hdr.Get(hdrAcceptKey)) && !IsStringEmpty(ct) &&
   119  		(IsJSONType(ct) || IsXMLType(ct)) {
   120  		hdr.Set(hdrAcceptKey, hdr.Get(hdrContentTypeKey))
   121  	}
   122  
   123  	r.Header = hdr
   124  
   125  	return nil
   126  }
   127  
   128  func parseRequestBody(c *Client, r *Request) (err error) {
   129  	if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
   130  		// Handling Multipart
   131  		if r.isMultiPart && !(r.Method == MethodPatch) {
   132  			if err = handleMultipart(c, r); err != nil {
   133  				return
   134  			}
   135  
   136  			goto CL
   137  		}
   138  
   139  		// Handling Form Data
   140  		if len(c.FormData) > 0 || len(r.FormData) > 0 {
   141  			handleFormData(c, r)
   142  
   143  			goto CL
   144  		}
   145  
   146  		// Handling Request body
   147  		if r.Body != nil {
   148  			handleContentType(c, r)
   149  
   150  			if err = handleRequestBody(c, r); err != nil {
   151  				return
   152  			}
   153  		}
   154  	}
   155  
   156  CL:
   157  	// by default resty won't set content length, you can if you want to :)
   158  	if (c.setContentLength || r.setContentLength) && r.bodyBuf != nil {
   159  		r.Header.Set(hdrContentLengthKey, fmt.Sprintf("%d", r.bodyBuf.Len()))
   160  	}
   161  
   162  	return
   163  }
   164  
   165  func createHTTPRequest(c *Client, r *Request) (err error) {
   166  	if r.bodyBuf == nil {
   167  		if reader, ok := r.Body.(io.Reader); ok {
   168  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
   169  		} else if c.setContentLength || r.setContentLength {
   170  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody)
   171  		} else {
   172  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil)
   173  		}
   174  	} else {
   175  		r.RawRequest, err = http.NewRequest(r.Method, r.URL, r.bodyBuf)
   176  	}
   177  
   178  	if err != nil {
   179  		return
   180  	}
   181  
   182  	// Assign close connection option
   183  	r.RawRequest.Close = c.closeConnection
   184  
   185  	// Add headers into http request
   186  	r.RawRequest.Header = r.Header
   187  
   188  	// Add cookies from client instance into http request
   189  	for _, cookie := range c.Cookies {
   190  		r.RawRequest.AddCookie(cookie)
   191  	}
   192  
   193  	// Add cookies from request instance into http request
   194  	for _, cookie := range r.Cookies {
   195  		r.RawRequest.AddCookie(cookie)
   196  	}
   197  
   198  	// Enable trace
   199  	if c.trace || r.trace {
   200  		r.clientTrace = &clientTrace{}
   201  		r.ctx = r.clientTrace.createContext(r.Context())
   202  	}
   203  
   204  	// Use context if it was specified
   205  	if r.ctx != nil {
   206  		r.RawRequest = r.RawRequest.WithContext(r.ctx)
   207  	}
   208  
   209  	bodyCopy, err := getBodyCopy(r)
   210  	if err != nil {
   211  		return err
   212  	}
   213  
   214  	// assign get body func for the underlying raw request instance
   215  	r.RawRequest.GetBody = func() (io.ReadCloser, error) {
   216  		if bodyCopy != nil {
   217  			return ioutil.NopCloser(bytes.NewReader(bodyCopy.Bytes())), nil
   218  		}
   219  		return nil, nil
   220  	}
   221  
   222  	return
   223  }
   224  
   225  func addCredentials(c *Client, r *Request) error {
   226  	var isBasicAuth bool
   227  	// Basic Auth
   228  	if r.UserInfo != nil { // takes precedence
   229  		r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
   230  		isBasicAuth = true
   231  	} else if c.UserInfo != nil {
   232  		r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password)
   233  		isBasicAuth = true
   234  	}
   235  
   236  	if !c.DisableWarn {
   237  		if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
   238  			c.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
   239  		}
   240  	}
   241  
   242  	// Set the Authorization Header Scheme
   243  	var authScheme string
   244  	if !IsStringEmpty(r.AuthScheme) {
   245  		authScheme = r.AuthScheme
   246  	} else if !IsStringEmpty(c.AuthScheme) {
   247  		authScheme = c.AuthScheme
   248  	} else {
   249  		authScheme = "Bearer"
   250  	}
   251  
   252  	// Build the Token Auth header
   253  	if !IsStringEmpty(r.Token) { // takes precedence
   254  		r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+r.Token)
   255  	} else if !IsStringEmpty(c.Token) {
   256  		r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+c.Token)
   257  	}
   258  
   259  	return nil
   260  }
   261  
   262  func requestLogger(c *Client, r *Request) error {
   263  	if c.Debug {
   264  		rr := r.RawRequest
   265  		rl := &RequestLog{Header: copyHeaders(rr.Header), Body: r.fmtBodyString(c.debugBodySizeLimit)}
   266  		if c.requestLog != nil {
   267  			if err := c.requestLog(rl); err != nil {
   268  				return err
   269  			}
   270  		}
   271  		// fmt.Sprintf("COOKIES:\n%s\n", composeCookies(c.GetClient().Jar, *rr.URL)) +
   272  
   273  		reqLog := "\n==============================================================================\n" +
   274  			"~~~ REQUEST ~~~\n" +
   275  			fmt.Sprintf("%s  %s  %s\n", r.Method, rr.URL.RequestURI(), rr.Proto) +
   276  			fmt.Sprintf("HOST   : %s\n", rr.URL.Host) +
   277  			fmt.Sprintf("HEADERS:\n%s\n", composeHeaders(c, r, rl.Header)) +
   278  			fmt.Sprintf("BODY   :\n%v\n", rl.Body) +
   279  			"------------------------------------------------------------------------------\n"
   280  
   281  		r.initValuesMap()
   282  		r.values[debugRequestLogKey] = reqLog
   283  	}
   284  
   285  	return nil
   286  }
   287  
   288  //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
   289  // Response Middleware(s)
   290  //_______________________________________________________________________
   291  
   292  func responseLogger(c *Client, res *Response) error {
   293  	if c.Debug {
   294  		rl := &ResponseLog{Header: copyHeaders(res.Header()), Body: res.fmtBodyString(c.debugBodySizeLimit)}
   295  		if c.responseLog != nil {
   296  			if err := c.responseLog(rl); err != nil {
   297  				return err
   298  			}
   299  		}
   300  
   301  		debugLog := res.Request.values[debugRequestLogKey].(string)
   302  		debugLog += "~~~ RESPONSE ~~~\n" +
   303  			fmt.Sprintf("STATUS       : %s\n", res.Status()) +
   304  			fmt.Sprintf("PROTO        : %s\n", res.RawResponse.Proto) +
   305  			fmt.Sprintf("RECEIVED AT  : %v\n", res.ReceivedAt().Format(time.RFC3339Nano)) +
   306  			fmt.Sprintf("TIME DURATION: %v\n", res.Time()) +
   307  			"HEADERS      :\n" +
   308  			composeHeaders(c, res.Request, rl.Header) + "\n"
   309  		if res.Request.isSaveResponse {
   310  			debugLog += "BODY         :\n***** RESPONSE WRITTEN INTO FILE *****\n"
   311  		} else {
   312  			debugLog += fmt.Sprintf("BODY         :\n%v\n", rl.Body)
   313  		}
   314  		debugLog += "==============================================================================\n"
   315  
   316  		c.log.Debugf("%s", debugLog)
   317  	}
   318  
   319  	return nil
   320  }
   321  
   322  func parseResponseBody(c *Client, res *Response) (err error) {
   323  	if res.StatusCode() == http.StatusNoContent {
   324  		return
   325  	}
   326  	// Handles only JSON or XML content type
   327  	ct := firstNonEmpty(res.Request.forceContentType, res.Header().Get(hdrContentTypeKey), res.Request.fallbackContentType)
   328  	if IsJSONType(ct) || IsXMLType(ct) {
   329  		// HTTP status code > 199 and < 300, considered as Result
   330  		if res.IsSuccess() {
   331  			res.Request.Error = nil
   332  			if res.Request.Result != nil {
   333  				err = Unmarshalc(c, ct, res.body, res.Request.Result)
   334  				return
   335  			}
   336  		}
   337  
   338  		// HTTP status code > 399, considered as Error
   339  		if res.IsError() {
   340  			// global error interface
   341  			if res.Request.Error == nil && c.Error != nil {
   342  				res.Request.Error = reflect.New(c.Error).Interface()
   343  			}
   344  
   345  			if res.Request.Error != nil {
   346  				err = Unmarshalc(c, ct, res.body, res.Request.Error)
   347  			}
   348  		}
   349  	}
   350  
   351  	return
   352  }
   353  
   354  func handleMultipart(c *Client, r *Request) (err error) {
   355  	r.bodyBuf = acquireBuffer()
   356  	w := multipart.NewWriter(r.bodyBuf)
   357  
   358  	for k, v := range c.FormData {
   359  		for _, iv := range v {
   360  			if err = w.WriteField(k, iv); err != nil {
   361  				return err
   362  			}
   363  		}
   364  	}
   365  
   366  	for k, v := range r.FormData {
   367  		for _, iv := range v {
   368  			if strings.HasPrefix(k, "@") { // file
   369  				err = addFile(w, k[1:], iv)
   370  				if err != nil {
   371  					return
   372  				}
   373  			} else { // form value
   374  				if err = w.WriteField(k, iv); err != nil {
   375  					return err
   376  				}
   377  			}
   378  		}
   379  	}
   380  
   381  	// #21 - adding io.Reader support
   382  	if len(r.multipartFiles) > 0 {
   383  		for _, f := range r.multipartFiles {
   384  			err = addFileReader(w, f)
   385  			if err != nil {
   386  				return
   387  			}
   388  		}
   389  	}
   390  
   391  	// GitHub #130 adding multipart field support with content type
   392  	if len(r.multipartFields) > 0 {
   393  		for _, mf := range r.multipartFields {
   394  			if err = addMultipartFormField(w, mf); err != nil {
   395  				return
   396  			}
   397  		}
   398  	}
   399  
   400  	r.Header.Set(hdrContentTypeKey, w.FormDataContentType())
   401  	err = w.Close()
   402  
   403  	return
   404  }
   405  
   406  func handleFormData(c *Client, r *Request) {
   407  	formData := url.Values{}
   408  
   409  	for k, v := range c.FormData {
   410  		for _, iv := range v {
   411  			formData.Add(k, iv)
   412  		}
   413  	}
   414  
   415  	for k, v := range r.FormData {
   416  		// remove form data field from client level by key
   417  		// since overrides happens for that key in the request
   418  		formData.Del(k)
   419  
   420  		for _, iv := range v {
   421  			formData.Add(k, iv)
   422  		}
   423  	}
   424  
   425  	r.bodyBuf = bytes.NewBuffer([]byte(formData.Encode()))
   426  	r.Header.Set(hdrContentTypeKey, formContentType)
   427  	r.isFormData = true
   428  }
   429  
   430  func handleContentType(c *Client, r *Request) {
   431  	contentType := r.Header.Get(hdrContentTypeKey)
   432  	if IsStringEmpty(contentType) {
   433  		contentType = DetectContentType(r.Body)
   434  		r.Header.Set(hdrContentTypeKey, contentType)
   435  	}
   436  }
   437  
   438  func handleRequestBody(c *Client, r *Request) (err error) {
   439  	var bodyBytes []byte
   440  	contentType := r.Header.Get(hdrContentTypeKey)
   441  	kind := kindOf(r.Body)
   442  	r.bodyBuf = nil
   443  
   444  	if reader, ok := r.Body.(io.Reader); ok {
   445  		if c.setContentLength || r.setContentLength { // keep backward compatibility
   446  			r.bodyBuf = acquireBuffer()
   447  			_, err = r.bodyBuf.ReadFrom(reader)
   448  			r.Body = nil
   449  		} else {
   450  			// Otherwise buffer less processing for `io.Reader`, sounds good.
   451  			return
   452  		}
   453  	} else if b, ok := r.Body.([]byte); ok {
   454  		bodyBytes = b
   455  	} else if s, ok := r.Body.(string); ok {
   456  		bodyBytes = []byte(s)
   457  	} else if IsJSONType(contentType) &&
   458  		(kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
   459  		r.bodyBuf, err = jsonMarshal(c, r, r.Body)
   460  		if err != nil {
   461  			return
   462  		}
   463  	} else if IsXMLType(contentType) && (kind == reflect.Struct) {
   464  		bodyBytes, err = c.XMLMarshal(r.Body)
   465  		if err != nil {
   466  			return
   467  		}
   468  	}
   469  
   470  	if bodyBytes == nil && r.bodyBuf == nil {
   471  		err = errors.New("unsupported 'Body' type/value")
   472  	}
   473  
   474  	// if any errors during body bytes handling, return it
   475  	if err != nil {
   476  		return
   477  	}
   478  
   479  	// []byte into Buffer
   480  	if bodyBytes != nil && r.bodyBuf == nil {
   481  		r.bodyBuf = acquireBuffer()
   482  		_, _ = r.bodyBuf.Write(bodyBytes)
   483  	}
   484  
   485  	return
   486  }
   487  
   488  func saveResponseIntoFile(c *Client, res *Response) error {
   489  	if res.Request.isSaveResponse {
   490  		file := ""
   491  
   492  		if len(c.outputDirectory) > 0 && !filepath.IsAbs(res.Request.outputFile) {
   493  			file += c.outputDirectory + string(filepath.Separator)
   494  		}
   495  
   496  		file = filepath.Clean(file + res.Request.outputFile)
   497  		if err := createDirectory(filepath.Dir(file)); err != nil {
   498  			return err
   499  		}
   500  
   501  		outFile, err := os.Create(file)
   502  		if err != nil {
   503  			return err
   504  		}
   505  		defer closeq(outFile)
   506  
   507  		// io.Copy reads maximum 32kb size, it is perfect for large file download too
   508  		defer closeq(res.RawResponse.Body)
   509  
   510  		written, err := io.Copy(outFile, res.RawResponse.Body)
   511  		if err != nil {
   512  			return err
   513  		}
   514  
   515  		res.size = written
   516  	}
   517  
   518  	return nil
   519  }
   520  
   521  func getBodyCopy(r *Request) (*bytes.Buffer, error) {
   522  	// If r.bodyBuf present, return the copy
   523  	if r.bodyBuf != nil {
   524  		return bytes.NewBuffer(r.bodyBuf.Bytes()), nil
   525  	}
   526  
   527  	// Maybe body is `io.Reader`.
   528  	// Note: Resty user have to watchout for large body size of `io.Reader`
   529  	if r.RawRequest.Body != nil {
   530  		b, err := ioutil.ReadAll(r.RawRequest.Body)
   531  		if err != nil {
   532  			return nil, err
   533  		}
   534  
   535  		// Restore the Body
   536  		closeq(r.RawRequest.Body)
   537  		r.RawRequest.Body = ioutil.NopCloser(bytes.NewBuffer(b))
   538  
   539  		// Return the Body bytes
   540  		return bytes.NewBuffer(b), nil
   541  	}
   542  	return nil, nil
   543  }
   544  

View as plain text