1 package middleware
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "net/http"
8 "net/http/httptest"
9 "testing"
10
11 "github.com/gin-gonic/gin"
12 "github.com/stretchr/testify/assert"
13
14 "edge-infra.dev/pkg/lib/fog"
15 "edge-infra.dev/pkg/sds/emergencyaccess/types"
16 )
17
18 func TestSaveAuth(t *testing.T) {
19 t.Parallel()
20
21 tests := map[string]struct {
22 addReqHeaders func(*http.Request)
23 user types.User
24 ok bool
25 }{
26 "Comma separated headers": {
27
28 addReqHeaders: func(r *http.Request) {
29 r.Header.Add("X-Auth-Username", "abcd")
30 r.Header.Set("X-Auth-Email", "abcd@efgh.xyz")
31 r.Header.Set("X-Auth-Roles", "role1,role2, role3")
32 r.Header.Set("X-Auth-Banners", "banner1,banner2")
33 },
34 user: types.User{
35 Username: "abcd",
36 Email: "abcd@efgh.xyz",
37 Roles: []string{"role1,role2, role3"},
38 Banners: []string{"banner1,banner2"},
39 },
40 ok: true,
41 },
42 "Multiple headers": {
43
44
45
46 addReqHeaders: func(r *http.Request) {
47 r.Header.Add("X-Auth-Username", "abcd")
48 r.Header.Add("X-Auth-Username", "efgh")
49 r.Header.Set("X-Auth-Email", "abcd@efgh.xyz")
50 r.Header.Add("X-Auth-Email", "efgh@ijkl.xyz")
51 r.Header.Set("X-Auth-Roles", "role1")
52 r.Header.Add("X-Auth-Roles", "role2")
53 r.Header.Add("X-Auth-Banners", "banner1")
54 r.Header.Add("X-Auth-Banners", "banner2")
55 },
56 user: types.User{
57 Username: "abcd",
58 Email: "abcd@efgh.xyz",
59 Roles: []string{"role1", "role2"},
60 Banners: []string{"banner1", "banner2"},
61 },
62 ok: true,
63 },
64 "Missing Headers": {
65
66 addReqHeaders: func(_ *http.Request) {},
67 user: types.User{},
68 ok: true,
69 },
70 "Only Some Headers": {
71
72
73
74 addReqHeaders: func(r *http.Request) {
75 r.Header.Add("X-Auth-Username", "abcd")
76 r.Header.Set("X-Auth-Roles", "role1")
77 },
78 user: types.User{
79 Username: "abcd",
80 Email: "",
81 Roles: []string{"role1"},
82 Banners: nil,
83 },
84 ok: true,
85 },
86 }
87
88 for name, tc := range tests {
89 tc := tc
90 t.Run(name, func(t *testing.T) {
91 t.Parallel()
92
93 r := httptest.NewRecorder()
94 req := httptest.NewRequest(http.MethodGet, "/", nil)
95
96 tc.addReqHeaders(req)
97
98 _, router := gin.CreateTestContext(r)
99 router.ContextWithFallback = true
100 router.Use(SaveAuthToContext())
101
102 var user types.User
103 var ok bool
104 router.Any("/", func(ctx *gin.Context) {
105 user, ok = types.UserFromContext(ctx)
106 })
107
108 router.ServeHTTP(r, req)
109
110 assert.Equal(t, tc.user, user)
111 assert.Equal(t, tc.ok, ok)
112 })
113 }
114 }
115
116 func TestVerifyUserDetailsInContext(t *testing.T) {
117 t.Parallel()
118
119 r := httptest.NewRecorder()
120 _, router := gin.CreateTestContext(r)
121 router.ContextWithFallback = true
122 router.Use(VerifyUserDetailsInContext())
123
124 user := types.User{
125 Username: "username",
126 Email: "email@domain",
127 }
128 ctx := types.UserIntoContext(context.Background(), user)
129 req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
130 assert.NoError(t, err)
131
132 router.Any("/", func(ctx *gin.Context) {
133 ctx.Status(http.StatusOK)
134 })
135 router.ServeHTTP(r, req)
136
137 assert.Equal(t, http.StatusOK, r.Result().StatusCode)
138 }
139
140 func TestVerifyUserDetailsInContextFail(t *testing.T) {
141 t.Parallel()
142
143 tests := map[string]struct {
144 user types.User
145 expMsg string
146 }{
147 "No username": {
148 user: types.User{
149 Email: "email@domain",
150 },
151 expMsg: "60001: User Authorization Failure - User not found. Error: username not found in authservice request",
152 },
153 "No email": {
154 user: types.User{
155 Username: "username",
156 },
157 expMsg: "60001: User Authorization Failure - User not found. Error: email not found in authservice request",
158 },
159 "Neither username nor email": {
160 expMsg: "60001: User Authorization Failure - User not found. Error: username not found in authservice request\nemail not found in authservice request",
161 },
162 }
163
164 for name, tc := range tests {
165 tc := tc
166 t.Run(name, func(t *testing.T) {
167 t.Parallel()
168
169 r := httptest.NewRecorder()
170 _, router := gin.CreateTestContext(r)
171 router.ContextWithFallback = true
172 router.Use(VerifyUserDetailsInContext())
173
174 buf := bytes.Buffer{}
175 ctx := fog.IntoContext(context.Background(), fog.New(fog.To(&buf)))
176 ctx = types.UserIntoContext(ctx, tc.user)
177 req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
178 assert.NoError(t, err)
179
180 router.Any("/", func(ctx *gin.Context) {
181 ctx.Status(http.StatusOK)
182 })
183
184 router.ServeHTTP(r, req)
185
186 m := map[string]interface{}{}
187 err = json.Unmarshal(buf.Bytes(), &m)
188 assert.NoError(t, err)
189
190 assert.Equal(t, http.StatusForbidden, r.Result().StatusCode)
191 assert.Equal(t, tc.expMsg, m["error"])
192 })
193 }
194 }
195
View as plain text