...
1
2
3
4
5
6
7
8
9
10
11
12
13 package chttp
14
15 import (
16 "context"
17 "fmt"
18 "net/http"
19 "net/http/cookiejar"
20 "strings"
21 "time"
22
23 "golang.org/x/net/publicsuffix"
24
25 kivik "github.com/go-kivik/kivik/v4"
26 )
27
28
29
30
31
// cookieAuth implements CouchDB cookie (session) authentication for a Client.
// Authenticate records per-client state (client, transport) on the value, so
// one cookieAuth instance is tied to the client it was installed on; Apply
// hands out a fresh copy that carries only the credentials.
type cookieAuth struct {
	// Username and Password are the login credentials. The json tags name the
	// fields in the POST /_session request body (presumably JSON-encoded by
	// BodyEncoder — confirm).
	Username string `json:"name"`
	Password string `json:"password"`

	// client is the Client this auth was installed on by Authenticate.
	client *Client

	// transport is the RoundTripper the client used before Authenticate
	// replaced it with this cookieAuth (http.DefaultTransport if none was
	// set); RoundTrip forwards requests to it.
	transport http.RoundTripper
}
41
42 var (
43 _ authenticator = &cookieAuth{}
44 _ kivik.Option = (*cookieAuth)(nil)
45 )
46
47 func (a *cookieAuth) Apply(target interface{}) {
48 if auth, ok := target.(*authenticator); ok {
49
50
51 *auth = &cookieAuth{
52 Username: a.Username,
53 Password: a.Password,
54 }
55 }
56 }
57
58 func (a *cookieAuth) String() string {
59 return fmt.Sprintf("[CookieAuth{user:%s,pass:%s}]", a.Username, strings.Repeat("*", len(a.Password)))
60 }
61
62
63 func (a *cookieAuth) Authenticate(c *Client) error {
64 a.client = c
65 a.setCookieJar()
66 a.transport = c.Transport
67 if a.transport == nil {
68 a.transport = http.DefaultTransport
69 }
70 c.Transport = a
71 return nil
72 }
73
74
75 func (a *cookieAuth) shouldAuth(req *http.Request) bool {
76 if _, err := req.Cookie(kivik.SessionCookieName); err == nil {
77 return false
78 }
79 cookie := a.Cookie()
80 if cookie == nil {
81 return true
82 }
83 if !cookie.Expires.IsZero() {
84 return cookie.Expires.Before(time.Now().Add(time.Minute))
85 }
86
87
88
89
90
91 return false
92 }
93
94
95 func (a *cookieAuth) Cookie() *http.Cookie {
96 if a.client == nil {
97 return nil
98 }
99 for _, cookie := range a.client.Jar.Cookies(a.client.dsn) {
100 if cookie.Name == kivik.SessionCookieName {
101 return cookie
102 }
103 }
104 return nil
105 }
106
// authInProgress is the context key authenticate uses to flag its own
// POST /_session request. The nested RoundTrip for that request sees the flag
// and skips authentication instead of recursing. The key is the address of an
// unexported value, so code outside this package cannot produce an equal key.
var authInProgress = &struct{ name string }{name: "in progress"}
108
109
110
111
112
113 func (a *cookieAuth) RoundTrip(req *http.Request) (*http.Response, error) {
114 if err := a.authenticate(req); err != nil {
115 return nil, err
116 }
117
118 res, err := a.transport.RoundTrip(req)
119 if err != nil {
120 return res, err
121 }
122
123 if res != nil && res.StatusCode == http.StatusUnauthorized {
124 if cookie := a.Cookie(); cookie != nil {
125
126 cookie.Expires = time.Now().AddDate(0, 0, -1)
127 a.client.Jar.SetCookies(a.client.dsn, []*http.Cookie{cookie})
128 }
129 }
130 return res, nil
131 }
132
133 func (a *cookieAuth) authenticate(req *http.Request) error {
134 ctx := req.Context()
135 if inProg, _ := ctx.Value(authInProgress).(bool); inProg {
136 return nil
137 }
138 if !a.shouldAuth(req) {
139 return nil
140 }
141 a.client.authMU.Lock()
142 defer a.client.authMU.Unlock()
143 if c := a.Cookie(); c != nil {
144
145 req.AddCookie(c)
146 return nil
147 }
148 ctx = context.WithValue(ctx, authInProgress, true)
149 opts := &Options{
150 GetBody: BodyEncoder(a),
151 Header: http.Header{
152 HeaderIdempotencyKey: []string{},
153 },
154 }
155 if _, err := a.client.DoError(ctx, http.MethodPost, "/_session", opts); err != nil {
156 return err
157 }
158 if c := a.Cookie(); c != nil {
159 req.AddCookie(c)
160 }
161 return nil
162 }
163
164 func (a *cookieAuth) setCookieJar() {
165
166 if a.client.Jar != nil {
167 return
168 }
169
170 jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
171 a.client.Jar = jar
172 }
173
View as plain text