package authserver import ( "fmt" "net/http" "net/http/httptest" "os" "strings" "testing" "time" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/stretchr/testify/assert" "edge-infra.dev/pkg/edge/api/sql" "edge-infra.dev/pkg/edge/api/testutils/seededpostgres" "edge-infra.dev/pkg/edge/auth-proxy/session" "edge-infra.dev/pkg/lib/fog" vncconst "edge-infra.dev/pkg/sds/vnc/constants" ) var ( dbName = "testauthserver" username = "test-user" password = "test123" seededPostgres *seededpostgres.SeededPostgres args []string ) func TestMain(m *testing.M) { seededpg, err := seededpostgres.NewWithUser(dbName, username, password) if err != nil { _ = seededpg.Close() os.Exit(1) } seededPostgres = seededpg args = []string{ fmt.Sprintf("--gin-mode=%s", gin.TestMode), "--database-host=localhost", fmt.Sprintf("--database-name=%s", dbName), fmt.Sprintf("--database-username=%s", username), fmt.Sprintf("--database-password=%s", password), fmt.Sprintf("--database-port=%v", seededPostgres.Port()), "--session-secret=test", } m.Run() } func TestHandleNotAuthenticated(t *testing.T) { rtr := gin.Default() r := httptest.NewRecorder() _, ginTestEngine := getTestGinContext(r) as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) rtr.GET("/ping", func(c *gin.Context) { as.handleNotAuthenticated(c, time.Now()) }) req, _ := http.NewRequest("GET", "/ping", nil) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusUnauthorized, r.Code) } func TestHandleEnforceRoleAccess_EDGE_BANNER_ADMIN(t *testing.T) { rtr := gin.Default() r := httptest.NewRecorder() ctx, ginTestEngine := getTestGinContext(r) ctx.Request = &http.Request{ RequestURI: "/remoteaccess/", } as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) rtr.GET("/ping", func(_ *gin.Context) { resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_ADMIN"}) assert.True(t, resp) }) req, _ := http.NewRequest("GET", "/ping", nil) rtr.ServeHTTP(r, req) } func TestHandleEnforceRoleAccess_EDGE_BANNER_OPERATOR(t *testing.T) { rtr := gin.Default() r := httptest.NewRecorder() ctx, ginTestEngine := getTestGinContext(r) ctx.Request = &http.Request{ RequestURI: "/remoteaccess/", } as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) rtr.GET("/ping", func(_ *gin.Context) { resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_OPERATOR"}) assert.True(t, resp) }) req, _ := http.NewRequest("GET", "/ping", nil) rtr.ServeHTTP(r, req) } func TestHandleEnforceRoleAccess_EDGE_BANNER_VIEWER(t *testing.T) { rtr := gin.Default() r := httptest.NewRecorder() ctx, ginTestEngine := getTestGinContext(r) ctx.Request = &http.Request{ RequestURI: "/remoteaccess/", } as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) rtr.GET("/ping", func(_ *gin.Context) { resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_VIEWER"}) assert.True(t, resp) }) req, err := http.NewRequest("GET", "/ping", nil) assert.NoError(t, err) rtr.ServeHTTP(r, req) } func TestHandleNoExpiration(t *testing.T) { rtr := gin.Default() r := httptest.NewRecorder() _, ginTestEngine := getTestGinContext(r) as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) rtr.GET("/ping", func(c *gin.Context) { as.handleNoExpiration(c, time.Now()) }) req, _ := http.NewRequest("GET", "/ping", nil) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusInternalServerError, r.Code) } func TestHandleExpiration(t *testing.T) { rtr := gin.Default() r := httptest.NewRecorder() _, ginTestEngine := getTestGinContext(r) as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) rtr.GET("/ping", func(c *gin.Context) { as.handleExpiration(c, time.Now()) }) req, _ := http.NewRequest("GET", "/ping", nil) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusUnauthorized, r.Code) } func TestHandleValidSessionCaseOne(t *testing.T) { rtr := gin.Default() r := httptest.NewRecorder() _, ginTestEngine := getTestGinContext(r) as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) mockSessions := session.NewMockSessions() rtr.GET("/ping", func(c *gin.Context) { as.handleValidSession(c, time.Now(), mockSessions) }) req, _ := http.NewRequest("GET", "/ping", nil) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusUnauthorized, r.Code) } func TestHandleValidSessionCaseTwoOrgAdmin(t *testing.T) { rtr := gin.Default() r := httptest.NewRecorder() ctx, ginTestEngine := getTestGinContext(r) ctx.Request = &http.Request{ RequestURI: "/remoteaccess/", } as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) mockSessions := session.NewMockSessions() mockSessions.Set("username", "test") rtr.GET("/ping", func(c *gin.Context) { as.handleValidSession(c, time.Now(), mockSessions) }) mockSessions.Set("organization", "test-org") mockSessions.Set("roles", []string{"EDGE_ORG_ADMIN"}) assert.NoError(t, mockSessions.Save()) req, _ := http.NewRequest("GET", "/ping", nil) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusOK, r.Code) } func TestHandleValidSession_EDGE_BANNER_ADMIN(t *testing.T) { //nolint rtr := gin.Default() r := httptest.NewRecorder() ctx, ginTestEngine := getTestGinContext(r) ctx.Request = &http.Request{ RequestURI: "/remoteaccess", } as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) mockSessions := session.NewMockSessions() rtr.GET("/ping", func(c *gin.Context) { as.handleValidSession(c, time.Now(), mockSessions) }) bannerName := uuid.NewString() mockSessions.Set("organization", bannerName) mockSessions.Set("roles", []string{"EDGE_BANNER_ADMIN"}) mockSessions.Set("username", "test") assert.NoError(t, mockSessions.Save()) tenantOrgID := uuid.NewString() _, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName) assert.NoError(t, err) tenanantEdgeID := "" row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID) assert.NoError(t, row.Scan(&tenanantEdgeID)) bannerEdgeID := uuid.NewString() projectID := uuid.NewString() _, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner") assert.NoError(t, err) req, _ := http.NewRequest("GET", "/ping", nil) req.Header.Set(bannerHeaderName, bannerEdgeID) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusOK, r.Code) } func TestHandleValidSession_EDGE_BANNER_VIEWER(t *testing.T) { //nolint rtr := gin.Default() r := httptest.NewRecorder() _, ginTestEngine := getTestGinContext(r) as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) mockSessions := session.NewMockSessions() rtr.GET("/ping", func(c *gin.Context) { as.handleValidSession(c, time.Now(), mockSessions) }) bannerName := uuid.NewString() mockSessions.Set("organization", bannerName) mockSessions.Set("roles", []string{"EDGE_BANNER_VIEWER"}) mockSessions.Set("username", "test") assert.NoError(t, mockSessions.Save()) tenantOrgID := uuid.NewString() projectID := uuid.NewString() _, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName) assert.NoError(t, err) tenanantEdgeID := "" row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID) assert.NoError(t, row.Scan(&tenanantEdgeID)) bannerEdgeID := uuid.NewString() _, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner") assert.NoError(t, err) req, _ := http.NewRequest("GET", "/ping", nil) req.Header.Set(bannerHeaderName, bannerEdgeID) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusOK, r.Code) } func TestHandleValidSession_EDGE_BANNER_VIEWER_grafana(t *testing.T) { //nolint rtr := gin.Default() r := httptest.NewRecorder() ctx, ginTestEngine := getTestGinContext(r) ctx.Request = &http.Request{ RequestURI: "/grafana/", } as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) mockSessions := session.NewMockSessions() rtr.GET("/grafana/", func(c *gin.Context) { as.handleValidSession(c, time.Now(), mockSessions) }) bannerName := uuid.NewString() mockSessions.Set("organization", bannerName) mockSessions.Set("roles", []string{"EDGE_BANNER_VIEWER"}) mockSessions.Set("username", "test") assert.NoError(t, mockSessions.Save()) tenantOrgID := uuid.NewString() projectID := uuid.NewString() _, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName) assert.NoError(t, err) tenanantEdgeID := "" row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID) assert.NoError(t, row.Scan(&tenanantEdgeID)) bannerEdgeID := uuid.NewString() _, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner") assert.NoError(t, err) req, _ := http.NewRequest("GET", "/grafana/", nil) req.Header.Set(bannerHeaderName, bannerEdgeID) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusOK, r.Code) } func TestHandleValidSession_VNC(t *testing.T) { tests := map[string]struct { url string role string requestHeaderPresent bool }{ "/novnc/": { url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/", role: "EDGE_BANNER_ADMIN", requestHeaderPresent: false, }, "/novnc/ws": { url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/write/ws", role: "EDGE_BANNER_ADMIN", requestHeaderPresent: false, }, "/notnovnc/": { url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/notnovnc/", role: "EDGE_BANNER_ADMIN", requestHeaderPresent: false, }, "/novnc/authorize with EDGE_BANNER_ADMIN role": { url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/read/authorize", role: "EDGE_BANNER_ADMIN", requestHeaderPresent: true, }, "/novnc/authorize with EDGE_ORG_ADMIN role": { url: "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/write/authorize", role: "EDGE_ORG_ADMIN", requestHeaderPresent: true, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { // Setup rtr := gin.Default() r := httptest.NewRecorder() ctx, ginTestEngine := getTestGinContext(r) ctx.Request = &http.Request{ RequestURI: tc.url, } as, err := NewAuthServer(args, ginTestEngine) assert.NoError(t, err) mockSessions := session.NewMockSessions() rtr.GET(strings.Split(tc.url, "?")[0], func(c *gin.Context) { as.handleValidSession(c, time.Now(), mockSessions) }) bannerName := uuid.NewString() mockSessions.Set("organization", bannerName) mockSessions.Set("roles", []string{tc.role}) username := "test" mockSessions.Set("username", username) assert.NoError(t, mockSessions.Save()) tenantOrgID := uuid.NewString() projectID := uuid.NewString() _, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName) assert.NoError(t, err) tenanantEdgeID := "" row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID) assert.NoError(t, row.Scan(&tenanantEdgeID)) bannerEdgeID := uuid.NewString() _, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner") assert.NoError(t, err) // Test req, err := http.NewRequest("GET", tc.url, nil) assert.NoError(t, err) req.Header.Set(bannerHeaderName, bannerEdgeID) rtr.ServeHTTP(r, req) assert.Equal(t, http.StatusOK, r.Code) resultHeader := r.Result().Header assert.Equal(t, username, resultHeader.Get("X-WEBAUTH-USER")) requestID := resultHeader.Get(vncconst.HeaderKeyRequestID) if tc.requestHeaderPresent { _, err = uuid.Parse(requestID) assert.NoError(t, err) } else { assert.Empty(t, requestID) } }) } } func getTestGinContext(r *httptest.ResponseRecorder) (*gin.Context, *gin.Engine) { gin.SetMode(gin.TestMode) ctx, ginEngine := gin.CreateTestContext(r) return ctx, ginEngine } func TestCheckFilterFunc(t *testing.T) { tests := map[string]struct { checks []check statusCode int }{ "No error": { checks: []check{ { pathFilter: "/request", checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error { return nil }, }, }, statusCode: http.StatusOK, }, "Error": { checks: []check{ { pathFilter: "/request", checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error { return fmt.Errorf("error") }, }, }, statusCode: 500, }, "Unauthorized": { checks: []check{ { pathFilter: "/request", checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error { return &httpError{ statusCode: http.StatusUnauthorized, err: fmt.Errorf("error"), } }, }, }, statusCode: 401, }, "Incorrect Path": { checks: []check{ { pathFilter: "/otherpath", checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error { return fmt.Errorf("error") }, }, }, statusCode: http.StatusOK, }, "Partial Path Match": { checks: []check{ { pathFilter: "/req", checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error { return fmt.Errorf("error") }, }, }, statusCode: http.StatusInternalServerError, }, "Match All Paths": { checks: []check{ { pathFilter: "", checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error { return fmt.Errorf("error") }, }, }, statusCode: http.StatusInternalServerError, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { // Setup rtr := gin.Default() r := httptest.NewRecorder() _, ginTestEngine := getTestGinContext(r) db, err := seededPostgres.DB() assert.NoError(t, err) as := AuthServer{ GinMode: gin.TestMode, GinEngine: ginTestEngine, db: db, Log: fog.New(), checks: tc.checks, } mockSessions := session.NewMockSessions() bannerName := uuid.NewString() mockSessions.Set("organization", bannerName) mockSessions.Set("roles", []string{"EDGE_ORG_ADMIN"}) username := "test" mockSessions.Set("username", username) assert.NoError(t, mockSessions.Save()) rtr.GET("/request", func(c *gin.Context) { as.handleValidSession(c, time.Now(), mockSessions) }) bannerEdgeID := uuid.NewString() // Test req, _ := http.NewRequest("GET", "/request", nil) req.Header.Set(bannerHeaderName, bannerEdgeID) rtr.ServeHTTP(r, req) assert.Equal(t, tc.statusCode, r.Code) }) } }