...

Source file src/github.com/Azure/go-ntlmssp/negotiator.go

Documentation: github.com/Azure/go-ntlmssp

     1  package ntlmssp
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"strings"
    10  )
    11  
    12  // GetDomain : parse domain name from based on slashes in the input
    13  // Need to check for upn as well
    14  func GetDomain(user string) (string, string, bool) {
    15  	domain := ""
    16  	domainNeeded := false
    17  
    18  	if strings.Contains(user, "\\") {
    19  		ucomponents := strings.SplitN(user, "\\", 2)
    20  		domain = ucomponents[0]
    21  		user = ucomponents[1]
    22  		domainNeeded = true
    23  	} else if strings.Contains(user, "@") {
    24  		domainNeeded = false
    25  	} else {
    26  		domainNeeded = true
    27  	}
    28  	return user, domain, domainNeeded
    29  }
    30  
    31  //Negotiator is a http.Roundtripper decorator that automatically
    32  //converts basic authentication to NTLM/Negotiate authentication when appropriate.
    33  type Negotiator struct{ http.RoundTripper }
    34  
    35  //RoundTrip sends the request to the server, handling any authentication
    36  //re-sends as needed.
    37  func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
    38  	// Use default round tripper if not provided
    39  	rt := l.RoundTripper
    40  	if rt == nil {
    41  		rt = http.DefaultTransport
    42  	}
    43  	// If it is not basic auth, just round trip the request as usual
    44  	reqauth := authheader(req.Header.Values("Authorization"))
    45  	if !reqauth.IsBasic() {
    46  		return rt.RoundTrip(req)
    47  	}
    48  	reqauthBasic := reqauth.Basic()
    49  	// Save request body
    50  	body := bytes.Buffer{}
    51  	if req.Body != nil {
    52  		_, err = body.ReadFrom(req.Body)
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  
    57  		req.Body.Close()
    58  		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
    59  	}
    60  	// first try anonymous, in case the server still finds us
    61  	// authenticated from previous traffic
    62  	req.Header.Del("Authorization")
    63  	res, err = rt.RoundTrip(req)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	if res.StatusCode != http.StatusUnauthorized {
    68  		return res, err
    69  	}
    70  	resauth := authheader(res.Header.Values("Www-Authenticate"))
    71  	if !resauth.IsNegotiate() && !resauth.IsNTLM() {
    72  		// Unauthorized, Negotiate not requested, let's try with basic auth
    73  		req.Header.Set("Authorization", string(reqauthBasic))
    74  		io.Copy(ioutil.Discard, res.Body)
    75  		res.Body.Close()
    76  		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
    77  
    78  		res, err = rt.RoundTrip(req)
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  		if res.StatusCode != http.StatusUnauthorized {
    83  			return res, err
    84  		}
    85  		resauth = authheader(res.Header.Values("Www-Authenticate"))
    86  	}
    87  
    88  	if resauth.IsNegotiate() || resauth.IsNTLM() {
    89  		// 401 with request:Basic and response:Negotiate
    90  		io.Copy(ioutil.Discard, res.Body)
    91  		res.Body.Close()
    92  
    93  		// recycle credentials
    94  		u, p, err := reqauth.GetBasicCreds()
    95  		if err != nil {
    96  			return nil, err
    97  		}
    98  
    99  		// get domain from username
   100  		domain := ""
   101  		u, domain, domainNeeded := GetDomain(u)
   102  
   103  		// send negotiate
   104  		negotiateMessage, err := NewNegotiateMessage(domain, "")
   105  		if err != nil {
   106  			return nil, err
   107  		}
   108  		if resauth.IsNTLM() {
   109  			req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
   110  		} else {
   111  			req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
   112  		}
   113  
   114  		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
   115  
   116  		res, err = rt.RoundTrip(req)
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  
   121  		// receive challenge?
   122  		resauth = authheader(res.Header.Values("Www-Authenticate"))
   123  		challengeMessage, err := resauth.GetData()
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  		if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
   128  			// Negotiation failed, let client deal with response
   129  			return res, nil
   130  		}
   131  		io.Copy(ioutil.Discard, res.Body)
   132  		res.Body.Close()
   133  
   134  		// send authenticate
   135  		authenticateMessage, err := ProcessChallenge(challengeMessage, u, p, domainNeeded)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  		if resauth.IsNTLM() {
   140  			req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
   141  		} else {
   142  			req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
   143  		}
   144  
   145  		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
   146  
   147  		return rt.RoundTrip(req)
   148  	}
   149  
   150  	return res, err
   151  }
   152  

View as plain text