1 package authserver
2
3 import (
4 "fmt"
5 "net/http"
6 "net/http/httptest"
7 "os"
8 "strings"
9 "testing"
10 "time"
11
12 "github.com/gin-contrib/sessions"
13 "github.com/gin-gonic/gin"
14 "github.com/google/uuid"
15 "github.com/stretchr/testify/assert"
16
17 "edge-infra.dev/pkg/edge/api/sql"
18 "edge-infra.dev/pkg/edge/api/testutils/seededpostgres"
19 "edge-infra.dev/pkg/edge/auth-proxy/session"
20 "edge-infra.dev/pkg/lib/fog"
21 vncconst "edge-infra.dev/pkg/sds/vnc/constants"
22 )
23
24 var (
25 dbName = "testauthserver"
26 username = "test-user"
27 password = "test123"
28 seededPostgres *seededpostgres.SeededPostgres
29 args []string
30 )
31
32 func TestMain(m *testing.M) {
33 seededpg, err := seededpostgres.NewWithUser(dbName, username, password)
34 if err != nil {
35 _ = seededpg.Close()
36 os.Exit(1)
37 }
38 seededPostgres = seededpg
39 args = []string{
40 fmt.Sprintf("--gin-mode=%s", gin.TestMode),
41 "--database-host=localhost",
42 fmt.Sprintf("--database-name=%s", dbName),
43 fmt.Sprintf("--database-username=%s", username),
44 fmt.Sprintf("--database-password=%s", password),
45 fmt.Sprintf("--database-port=%v", seededPostgres.Port()),
46 "--session-secret=test",
47 }
48 m.Run()
49 }
50
51 func TestHandleNotAuthenticated(t *testing.T) {
52 rtr := gin.Default()
53 r := httptest.NewRecorder()
54 _, ginTestEngine := getTestGinContext(r)
55 as, err := NewAuthServer(args, ginTestEngine)
56 assert.NoError(t, err)
57 rtr.GET("/ping", func(c *gin.Context) {
58 as.handleNotAuthenticated(c, time.Now())
59 })
60 req, _ := http.NewRequest("GET", "/ping", nil)
61 rtr.ServeHTTP(r, req)
62 assert.Equal(t, http.StatusUnauthorized, r.Code)
63 }
64
65 func TestHandleEnforceRoleAccess_EDGE_BANNER_ADMIN(t *testing.T) {
66 rtr := gin.Default()
67 r := httptest.NewRecorder()
68 ctx, ginTestEngine := getTestGinContext(r)
69 ctx.Request = &http.Request{
70 RequestURI: "/remoteaccess/",
71 }
72 as, err := NewAuthServer(args, ginTestEngine)
73 assert.NoError(t, err)
74 rtr.GET("/ping", func(_ *gin.Context) {
75 resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_ADMIN"})
76 assert.True(t, resp)
77 })
78 req, _ := http.NewRequest("GET", "/ping", nil)
79 rtr.ServeHTTP(r, req)
80 }
81
82 func TestHandleEnforceRoleAccess_EDGE_BANNER_OPERATOR(t *testing.T) {
83 rtr := gin.Default()
84 r := httptest.NewRecorder()
85 ctx, ginTestEngine := getTestGinContext(r)
86 ctx.Request = &http.Request{
87 RequestURI: "/remoteaccess/",
88 }
89 as, err := NewAuthServer(args, ginTestEngine)
90 assert.NoError(t, err)
91 rtr.GET("/ping", func(_ *gin.Context) {
92 resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_OPERATOR"})
93 assert.True(t, resp)
94 })
95 req, _ := http.NewRequest("GET", "/ping", nil)
96 rtr.ServeHTTP(r, req)
97 }
98
99 func TestHandleEnforceRoleAccess_EDGE_BANNER_VIEWER(t *testing.T) {
100 rtr := gin.Default()
101 r := httptest.NewRecorder()
102 ctx, ginTestEngine := getTestGinContext(r)
103 ctx.Request = &http.Request{
104 RequestURI: "/remoteaccess/",
105 }
106 as, err := NewAuthServer(args, ginTestEngine)
107 assert.NoError(t, err)
108 rtr.GET("/ping", func(_ *gin.Context) {
109 resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_VIEWER"})
110 assert.True(t, resp)
111 })
112 req, err := http.NewRequest("GET", "/ping", nil)
113 assert.NoError(t, err)
114 rtr.ServeHTTP(r, req)
115 }
116
117 func TestHandleNoExpiration(t *testing.T) {
118 rtr := gin.Default()
119 r := httptest.NewRecorder()
120 _, ginTestEngine := getTestGinContext(r)
121 as, err := NewAuthServer(args, ginTestEngine)
122 assert.NoError(t, err)
123 rtr.GET("/ping", func(c *gin.Context) {
124 as.handleNoExpiration(c, time.Now())
125 })
126 req, _ := http.NewRequest("GET", "/ping", nil)
127 rtr.ServeHTTP(r, req)
128 assert.Equal(t, http.StatusInternalServerError, r.Code)
129 }
130
131 func TestHandleExpiration(t *testing.T) {
132 rtr := gin.Default()
133 r := httptest.NewRecorder()
134 _, ginTestEngine := getTestGinContext(r)
135 as, err := NewAuthServer(args, ginTestEngine)
136 assert.NoError(t, err)
137 rtr.GET("/ping", func(c *gin.Context) {
138 as.handleExpiration(c, time.Now())
139 })
140 req, _ := http.NewRequest("GET", "/ping", nil)
141 rtr.ServeHTTP(r, req)
142 assert.Equal(t, http.StatusUnauthorized, r.Code)
143 }
144
145 func TestHandleValidSessionCaseOne(t *testing.T) {
146 rtr := gin.Default()
147 r := httptest.NewRecorder()
148 _, ginTestEngine := getTestGinContext(r)
149 as, err := NewAuthServer(args, ginTestEngine)
150 assert.NoError(t, err)
151 mockSessions := session.NewMockSessions()
152 rtr.GET("/ping", func(c *gin.Context) {
153 as.handleValidSession(c, time.Now(), mockSessions)
154 })
155 req, _ := http.NewRequest("GET", "/ping", nil)
156 rtr.ServeHTTP(r, req)
157 assert.Equal(t, http.StatusUnauthorized, r.Code)
158 }
159
160 func TestHandleValidSessionCaseTwoOrgAdmin(t *testing.T) {
161 rtr := gin.Default()
162 r := httptest.NewRecorder()
163 ctx, ginTestEngine := getTestGinContext(r)
164 ctx.Request = &http.Request{
165 RequestURI: "/remoteaccess/",
166 }
167 as, err := NewAuthServer(args, ginTestEngine)
168 assert.NoError(t, err)
169 mockSessions := session.NewMockSessions()
170 mockSessions.Set("username", "test")
171 rtr.GET("/ping", func(c *gin.Context) {
172 as.handleValidSession(c, time.Now(), mockSessions)
173 })
174 mockSessions.Set("organization", "test-org")
175 mockSessions.Set("roles", []string{"EDGE_ORG_ADMIN"})
176 assert.NoError(t, mockSessions.Save())
177 req, _ := http.NewRequest("GET", "/ping", nil)
178 rtr.ServeHTTP(r, req)
179 assert.Equal(t, http.StatusOK, r.Code)
180 }
181
182 func TestHandleValidSession_EDGE_BANNER_ADMIN(t *testing.T) {
183 rtr := gin.Default()
184 r := httptest.NewRecorder()
185 ctx, ginTestEngine := getTestGinContext(r)
186 ctx.Request = &http.Request{
187 RequestURI: "/remoteaccess",
188 }
189 as, err := NewAuthServer(args, ginTestEngine)
190 assert.NoError(t, err)
191 mockSessions := session.NewMockSessions()
192 rtr.GET("/ping", func(c *gin.Context) {
193 as.handleValidSession(c, time.Now(), mockSessions)
194 })
195 bannerName := uuid.NewString()
196 mockSessions.Set("organization", bannerName)
197 mockSessions.Set("roles", []string{"EDGE_BANNER_ADMIN"})
198 mockSessions.Set("username", "test")
199 assert.NoError(t, mockSessions.Save())
200 tenantOrgID := uuid.NewString()
201 _, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName)
202 assert.NoError(t, err)
203 tenanantEdgeID := ""
204 row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID)
205 assert.NoError(t, row.Scan(&tenanantEdgeID))
206 bannerEdgeID := uuid.NewString()
207 projectID := uuid.NewString()
208 _, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner")
209 assert.NoError(t, err)
210 req, _ := http.NewRequest("GET", "/ping", nil)
211 req.Header.Set(bannerHeaderName, bannerEdgeID)
212 rtr.ServeHTTP(r, req)
213 assert.Equal(t, http.StatusOK, r.Code)
214 }
215
216 func TestHandleValidSession_EDGE_BANNER_VIEWER(t *testing.T) {
217 rtr := gin.Default()
218 r := httptest.NewRecorder()
219 _, ginTestEngine := getTestGinContext(r)
220 as, err := NewAuthServer(args, ginTestEngine)
221 assert.NoError(t, err)
222 mockSessions := session.NewMockSessions()
223 rtr.GET("/ping", func(c *gin.Context) {
224 as.handleValidSession(c, time.Now(), mockSessions)
225 })
226 bannerName := uuid.NewString()
227 mockSessions.Set("organization", bannerName)
228 mockSessions.Set("roles", []string{"EDGE_BANNER_VIEWER"})
229 mockSessions.Set("username", "test")
230 assert.NoError(t, mockSessions.Save())
231 tenantOrgID := uuid.NewString()
232 projectID := uuid.NewString()
233 _, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName)
234 assert.NoError(t, err)
235 tenanantEdgeID := ""
236 row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID)
237 assert.NoError(t, row.Scan(&tenanantEdgeID))
238 bannerEdgeID := uuid.NewString()
239 _, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner")
240 assert.NoError(t, err)
241 req, _ := http.NewRequest("GET", "/ping", nil)
242 req.Header.Set(bannerHeaderName, bannerEdgeID)
243 rtr.ServeHTTP(r, req)
244 assert.Equal(t, http.StatusOK, r.Code)
245 }
246
247 func TestHandleValidSession_EDGE_BANNER_VIEWER_grafana(t *testing.T) {
248 rtr := gin.Default()
249 r := httptest.NewRecorder()
250 ctx, ginTestEngine := getTestGinContext(r)
251 ctx.Request = &http.Request{
252 RequestURI: "/grafana/",
253 }
254 as, err := NewAuthServer(args, ginTestEngine)
255 assert.NoError(t, err)
256 mockSessions := session.NewMockSessions()
257 rtr.GET("/grafana/", func(c *gin.Context) {
258 as.handleValidSession(c, time.Now(), mockSessions)
259 })
260 bannerName := uuid.NewString()
261 mockSessions.Set("organization", bannerName)
262 mockSessions.Set("roles", []string{"EDGE_BANNER_VIEWER"})
263 mockSessions.Set("username", "test")
264 assert.NoError(t, mockSessions.Save())
265 tenantOrgID := uuid.NewString()
266 projectID := uuid.NewString()
267 _, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName)
268 assert.NoError(t, err)
269 tenanantEdgeID := ""
270 row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID)
271 assert.NoError(t, row.Scan(&tenanantEdgeID))
272 bannerEdgeID := uuid.NewString()
273 _, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner")
274 assert.NoError(t, err)
275 req, _ := http.NewRequest("GET", "/grafana/", nil)
276 req.Header.Set(bannerHeaderName, bannerEdgeID)
277 rtr.ServeHTTP(r, req)
278 assert.Equal(t, http.StatusOK, r.Code)
279 }
280
281 func TestHandleValidSession_VNC(t *testing.T) {
282 tests := map[string]struct {
283 url string
284 role string
285 requestHeaderPresent bool
286 }{
287 "/novnc/": {
288 url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/",
289 role: "EDGE_BANNER_ADMIN",
290 requestHeaderPresent: false,
291 },
292 "/novnc/ws": {
293 url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/write/ws",
294 role: "EDGE_BANNER_ADMIN",
295 requestHeaderPresent: false,
296 },
297 "/notnovnc/": {
298 url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/notnovnc/",
299 role: "EDGE_BANNER_ADMIN",
300 requestHeaderPresent: false,
301 },
302 "/novnc/authorize with EDGE_BANNER_ADMIN role": {
303 url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/read/authorize",
304 role: "EDGE_BANNER_ADMIN",
305 requestHeaderPresent: true,
306 },
307 "/novnc/authorize with EDGE_ORG_ADMIN role": {
308 url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/write/authorize",
309 role: "EDGE_ORG_ADMIN",
310 requestHeaderPresent: true,
311 },
312 }
313
314 for name, tc := range tests {
315 tc := tc
316 t.Run(name, func(t *testing.T) {
317
318 rtr := gin.Default()
319 r := httptest.NewRecorder()
320 ctx, ginTestEngine := getTestGinContext(r)
321 ctx.Request = &http.Request{
322 RequestURI: tc.url,
323 }
324 as, err := NewAuthServer(args, ginTestEngine)
325 assert.NoError(t, err)
326 mockSessions := session.NewMockSessions()
327 rtr.GET(strings.Split(tc.url, "?")[0], func(c *gin.Context) {
328 as.handleValidSession(c, time.Now(), mockSessions)
329 })
330 bannerName := uuid.NewString()
331 mockSessions.Set("organization", bannerName)
332 mockSessions.Set("roles", []string{tc.role})
333 username := "test"
334 mockSessions.Set("username", username)
335 assert.NoError(t, mockSessions.Save())
336 tenantOrgID := uuid.NewString()
337 projectID := uuid.NewString()
338 _, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName)
339 assert.NoError(t, err)
340 tenanantEdgeID := ""
341 row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID)
342 assert.NoError(t, row.Scan(&tenanantEdgeID))
343 bannerEdgeID := uuid.NewString()
344 _, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner")
345 assert.NoError(t, err)
346
347
348 req, err := http.NewRequest("GET", tc.url, nil)
349 assert.NoError(t, err)
350 req.Header.Set(bannerHeaderName, bannerEdgeID)
351 rtr.ServeHTTP(r, req)
352 assert.Equal(t, http.StatusOK, r.Code)
353
354 resultHeader := r.Result().Header
355 assert.Equal(t, username, resultHeader.Get("X-WEBAUTH-USER"))
356
357 requestID := resultHeader.Get(vncconst.HeaderKeyRequestID)
358 if tc.requestHeaderPresent {
359 _, err = uuid.Parse(requestID)
360 assert.NoError(t, err)
361 } else {
362 assert.Empty(t, requestID)
363 }
364 })
365 }
366 }
367
368 func getTestGinContext(r *httptest.ResponseRecorder) (*gin.Context, *gin.Engine) {
369 gin.SetMode(gin.TestMode)
370 ctx, ginEngine := gin.CreateTestContext(r)
371 return ctx, ginEngine
372 }
373
374 func TestCheckFilterFunc(t *testing.T) {
375 tests := map[string]struct {
376 checks []check
377 statusCode int
378 }{
379 "No error": {
380 checks: []check{
381 {
382 pathFilter: "/request",
383 checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
384 return nil
385 },
386 },
387 },
388 statusCode: http.StatusOK,
389 },
390 "Error": {
391 checks: []check{
392 {
393 pathFilter: "/request",
394 checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
395 return fmt.Errorf("error")
396 },
397 },
398 },
399 statusCode: 500,
400 },
401 "Unauthorized": {
402 checks: []check{
403 {
404 pathFilter: "/request",
405 checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
406 return &httpError{
407 statusCode: http.StatusUnauthorized,
408 err: fmt.Errorf("error"),
409 }
410 },
411 },
412 },
413 statusCode: 401,
414 },
415 "Incorrect Path": {
416 checks: []check{
417 {
418 pathFilter: "/otherpath",
419 checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
420 return fmt.Errorf("error")
421 },
422 },
423 },
424 statusCode: http.StatusOK,
425 },
426 "Partial Path Match": {
427 checks: []check{
428 {
429 pathFilter: "/req",
430 checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
431 return fmt.Errorf("error")
432 },
433 },
434 },
435 statusCode: http.StatusInternalServerError,
436 },
437 "Match All Paths": {
438 checks: []check{
439 {
440 pathFilter: "",
441 checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
442 return fmt.Errorf("error")
443 },
444 },
445 },
446 statusCode: http.StatusInternalServerError,
447 },
448 }
449
450 for name, tc := range tests {
451 tc := tc
452 t.Run(name, func(t *testing.T) {
453
454 rtr := gin.Default()
455 r := httptest.NewRecorder()
456 _, ginTestEngine := getTestGinContext(r)
457
458 db, err := seededPostgres.DB()
459 assert.NoError(t, err)
460
461 as := AuthServer{
462 GinMode: gin.TestMode,
463 GinEngine: ginTestEngine,
464 db: db,
465 Log: fog.New(),
466 checks: tc.checks,
467 }
468
469 mockSessions := session.NewMockSessions()
470 bannerName := uuid.NewString()
471 mockSessions.Set("organization", bannerName)
472 mockSessions.Set("roles", []string{"EDGE_ORG_ADMIN"})
473 username := "test"
474 mockSessions.Set("username", username)
475 assert.NoError(t, mockSessions.Save())
476
477 rtr.GET("/request", func(c *gin.Context) {
478 as.handleValidSession(c, time.Now(), mockSessions)
479 })
480
481 bannerEdgeID := uuid.NewString()
482
483
484 req, _ := http.NewRequest("GET", "/request", nil)
485 req.Header.Set(bannerHeaderName, bannerEdgeID)
486 rtr.ServeHTTP(r, req)
487 assert.Equal(t, tc.statusCode, r.Code)
488 })
489 }
490 }
491
View as plain text