...
1 package ntlmssp
2
3 import (
4 "bytes"
5 "encoding/base64"
6 "io"
7 "io/ioutil"
8 "net/http"
9 "strings"
10 )
11
12
13
// GetDomain splits a Windows-style "DOMAIN\user" logon name into its parts.
//
// It returns three values:
//   - the user name with any "DOMAIN\" prefix removed;
//   - the domain, which is empty when the name contains no backslash;
//   - a flag that is false only for names with no backslash that contain '@'.
//
// Only the first backslash acts as a separator, so `a\b\c` yields user `b\c`
// and domain "a".
//
// NOTE(review): RoundTrip forwards the flag to ProcessChallenge as
// domainNeeded. Presumably a user@realm (UPN-style) name already carries its
// domain; confirm against ProcessChallenge.
func GetDomain(user string) (string, string, bool) {
	if sep := strings.Index(user, `\`); sep >= 0 {
		return user[sep+1:], user[:sep], true
	}
	return user, "", !strings.Contains(user, "@")
}
30
31
32
// Negotiator is an http.RoundTripper that turns HTTP Basic credentials on a
// request into an NTLM or Negotiate handshake when the server asks for one.
//
// The embedded RoundTripper carries out the individual HTTP exchanges; if it
// is nil, http.DefaultTransport is used instead.
type Negotiator struct {
	http.RoundTripper // underlying transport; nil means http.DefaultTransport
}
34
35
36
37 func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
38
39 rt := l.RoundTripper
40 if rt == nil {
41 rt = http.DefaultTransport
42 }
43
44 reqauth := authheader(req.Header.Values("Authorization"))
45 if !reqauth.IsBasic() {
46 return rt.RoundTrip(req)
47 }
48 reqauthBasic := reqauth.Basic()
49
50 body := bytes.Buffer{}
51 if req.Body != nil {
52 _, err = body.ReadFrom(req.Body)
53 if err != nil {
54 return nil, err
55 }
56
57 req.Body.Close()
58 req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
59 }
60
61
62 req.Header.Del("Authorization")
63 res, err = rt.RoundTrip(req)
64 if err != nil {
65 return nil, err
66 }
67 if res.StatusCode != http.StatusUnauthorized {
68 return res, err
69 }
70 resauth := authheader(res.Header.Values("Www-Authenticate"))
71 if !resauth.IsNegotiate() && !resauth.IsNTLM() {
72
73 req.Header.Set("Authorization", string(reqauthBasic))
74 io.Copy(ioutil.Discard, res.Body)
75 res.Body.Close()
76 req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
77
78 res, err = rt.RoundTrip(req)
79 if err != nil {
80 return nil, err
81 }
82 if res.StatusCode != http.StatusUnauthorized {
83 return res, err
84 }
85 resauth = authheader(res.Header.Values("Www-Authenticate"))
86 }
87
88 if resauth.IsNegotiate() || resauth.IsNTLM() {
89
90 io.Copy(ioutil.Discard, res.Body)
91 res.Body.Close()
92
93
94 u, p, err := reqauth.GetBasicCreds()
95 if err != nil {
96 return nil, err
97 }
98
99
100 domain := ""
101 u, domain, domainNeeded := GetDomain(u)
102
103
104 negotiateMessage, err := NewNegotiateMessage(domain, "")
105 if err != nil {
106 return nil, err
107 }
108 if resauth.IsNTLM() {
109 req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
110 } else {
111 req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
112 }
113
114 req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
115
116 res, err = rt.RoundTrip(req)
117 if err != nil {
118 return nil, err
119 }
120
121
122 resauth = authheader(res.Header.Values("Www-Authenticate"))
123 challengeMessage, err := resauth.GetData()
124 if err != nil {
125 return nil, err
126 }
127 if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
128
129 return res, nil
130 }
131 io.Copy(ioutil.Discard, res.Body)
132 res.Body.Close()
133
134
135 authenticateMessage, err := ProcessChallenge(challengeMessage, u, p, domainNeeded)
136 if err != nil {
137 return nil, err
138 }
139 if resauth.IsNTLM() {
140 req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
141 } else {
142 req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
143 }
144
145 req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
146
147 return rt.RoundTrip(req)
148 }
149
150 return res, err
151 }
152
View as plain text