1 package basic
2
3 import (
4 "context"
5 "encoding/base64"
6 "fmt"
7 "testing"
8
9 httptransport "github.com/go-kit/kit/transport/http"
10 )
11
12 func TestWithBasicAuth(t *testing.T) {
13 requiredUser := "test-user"
14 requiredPassword := "test-pass"
15 realm := "test realm"
16
17 type want struct {
18 result interface{}
19 err error
20 }
21 tests := []struct {
22 name string
23 authHeader interface{}
24 want want
25 }{
26 {"Isn't valid with nil header", nil, want{nil, AuthError{realm}}},
27 {"Isn't valid with non-string header", 42, want{nil, AuthError{realm}}},
28 {"Isn't valid without authHeader", "", want{nil, AuthError{realm}}},
29 {"Isn't valid for wrong user", makeAuthString("wrong-user", requiredPassword), want{nil, AuthError{realm}}},
30 {"Isn't valid for wrong password", makeAuthString(requiredUser, "wrong-password"), want{nil, AuthError{realm}}},
31 {"Is valid for correct creds", makeAuthString(requiredUser, requiredPassword), want{true, nil}},
32 }
33 for _, tt := range tests {
34 t.Run(tt.name, func(t *testing.T) {
35 ctx := context.WithValue(context.TODO(), httptransport.ContextKeyRequestAuthorization, tt.authHeader)
36
37 result, err := AuthMiddleware(requiredUser, requiredPassword, realm)(passedValidation)(ctx, nil)
38 if result != tt.want.result || err != tt.want.err {
39 t.Errorf("WithBasicAuth() = result: %v, err: %v, want result: %v, want error: %v", result, err, tt.want.result, tt.want.err)
40 }
41 })
42 }
43 }
44
// makeAuthString builds an HTTP Basic Authorization header value for the
// given credentials. It joins user and password with a colon,
// base64-encodes the pair using the standard alphabet with padding, and
// prefixes the result with "Basic ".
func makeAuthString(user string, password string) string {
	credentials := fmt.Sprintf("%s:%s", user, password)
	encoded := base64.StdEncoding.EncodeToString([]byte(credentials))
	return "Basic " + encoded
}
49
// passedValidation is a stub endpoint used as the handler wrapped by the
// middleware under test. It ignores its context and request and always
// reports success by returning true with a nil error.
func passedValidation(ctx context.Context, request interface{}) (response interface{}, err error) {
	response = true
	return response, nil
}
53
View as plain text