...

Source file src/edge-infra.dev/pkg/sds/remoteaccess/authserver/handlers_test.go

Documentation: edge-infra.dev/pkg/sds/remoteaccess/authserver

     1  package authserver
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"os"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/gin-contrib/sessions"
    13  	"github.com/gin-gonic/gin"
    14  	"github.com/google/uuid"
    15  	"github.com/stretchr/testify/assert"
    16  
    17  	"edge-infra.dev/pkg/edge/api/sql"
    18  	"edge-infra.dev/pkg/edge/api/testutils/seededpostgres"
    19  	"edge-infra.dev/pkg/edge/auth-proxy/session"
    20  	"edge-infra.dev/pkg/lib/fog"
    21  	vncconst "edge-infra.dev/pkg/sds/vnc/constants"
    22  )
    23  
    24  var (
    25  	dbName         = "testauthserver"
    26  	username       = "test-user"
    27  	password       = "test123"
    28  	seededPostgres *seededpostgres.SeededPostgres
    29  	args           []string
    30  )
    31  
    32  func TestMain(m *testing.M) {
    33  	seededpg, err := seededpostgres.NewWithUser(dbName, username, password)
    34  	if err != nil {
    35  		_ = seededpg.Close()
    36  		os.Exit(1)
    37  	}
    38  	seededPostgres = seededpg
    39  	args = []string{
    40  		fmt.Sprintf("--gin-mode=%s", gin.TestMode),
    41  		"--database-host=localhost",
    42  		fmt.Sprintf("--database-name=%s", dbName),
    43  		fmt.Sprintf("--database-username=%s", username),
    44  		fmt.Sprintf("--database-password=%s", password),
    45  		fmt.Sprintf("--database-port=%v", seededPostgres.Port()),
    46  		"--session-secret=test",
    47  	}
    48  	m.Run()
    49  }
    50  
    51  func TestHandleNotAuthenticated(t *testing.T) {
    52  	rtr := gin.Default()
    53  	r := httptest.NewRecorder()
    54  	_, ginTestEngine := getTestGinContext(r)
    55  	as, err := NewAuthServer(args, ginTestEngine)
    56  	assert.NoError(t, err)
    57  	rtr.GET("/ping", func(c *gin.Context) {
    58  		as.handleNotAuthenticated(c, time.Now())
    59  	})
    60  	req, _ := http.NewRequest("GET", "/ping", nil)
    61  	rtr.ServeHTTP(r, req)
    62  	assert.Equal(t, http.StatusUnauthorized, r.Code)
    63  }
    64  
    65  func TestHandleEnforceRoleAccess_EDGE_BANNER_ADMIN(t *testing.T) {
    66  	rtr := gin.Default()
    67  	r := httptest.NewRecorder()
    68  	ctx, ginTestEngine := getTestGinContext(r)
    69  	ctx.Request = &http.Request{
    70  		RequestURI: "/remoteaccess/",
    71  	}
    72  	as, err := NewAuthServer(args, ginTestEngine)
    73  	assert.NoError(t, err)
    74  	rtr.GET("/ping", func(_ *gin.Context) {
    75  		resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_ADMIN"})
    76  		assert.True(t, resp)
    77  	})
    78  	req, _ := http.NewRequest("GET", "/ping", nil)
    79  	rtr.ServeHTTP(r, req)
    80  }
    81  
    82  func TestHandleEnforceRoleAccess_EDGE_BANNER_OPERATOR(t *testing.T) {
    83  	rtr := gin.Default()
    84  	r := httptest.NewRecorder()
    85  	ctx, ginTestEngine := getTestGinContext(r)
    86  	ctx.Request = &http.Request{
    87  		RequestURI: "/remoteaccess/",
    88  	}
    89  	as, err := NewAuthServer(args, ginTestEngine)
    90  	assert.NoError(t, err)
    91  	rtr.GET("/ping", func(_ *gin.Context) {
    92  		resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_OPERATOR"})
    93  		assert.True(t, resp)
    94  	})
    95  	req, _ := http.NewRequest("GET", "/ping", nil)
    96  	rtr.ServeHTTP(r, req)
    97  }
    98  
    99  func TestHandleEnforceRoleAccess_EDGE_BANNER_VIEWER(t *testing.T) {
   100  	rtr := gin.Default()
   101  	r := httptest.NewRecorder()
   102  	ctx, ginTestEngine := getTestGinContext(r)
   103  	ctx.Request = &http.Request{
   104  		RequestURI: "/remoteaccess/",
   105  	}
   106  	as, err := NewAuthServer(args, ginTestEngine)
   107  	assert.NoError(t, err)
   108  	rtr.GET("/ping", func(_ *gin.Context) {
   109  		resp := as.enforceRoleAccess(ctx, "test-org", "test-org", []string{"EDGE_BANNER_VIEWER"})
   110  		assert.True(t, resp)
   111  	})
   112  	req, err := http.NewRequest("GET", "/ping", nil)
   113  	assert.NoError(t, err)
   114  	rtr.ServeHTTP(r, req)
   115  }
   116  
   117  func TestHandleNoExpiration(t *testing.T) {
   118  	rtr := gin.Default()
   119  	r := httptest.NewRecorder()
   120  	_, ginTestEngine := getTestGinContext(r)
   121  	as, err := NewAuthServer(args, ginTestEngine)
   122  	assert.NoError(t, err)
   123  	rtr.GET("/ping", func(c *gin.Context) {
   124  		as.handleNoExpiration(c, time.Now())
   125  	})
   126  	req, _ := http.NewRequest("GET", "/ping", nil)
   127  	rtr.ServeHTTP(r, req)
   128  	assert.Equal(t, http.StatusInternalServerError, r.Code)
   129  }
   130  
   131  func TestHandleExpiration(t *testing.T) {
   132  	rtr := gin.Default()
   133  	r := httptest.NewRecorder()
   134  	_, ginTestEngine := getTestGinContext(r)
   135  	as, err := NewAuthServer(args, ginTestEngine)
   136  	assert.NoError(t, err)
   137  	rtr.GET("/ping", func(c *gin.Context) {
   138  		as.handleExpiration(c, time.Now())
   139  	})
   140  	req, _ := http.NewRequest("GET", "/ping", nil)
   141  	rtr.ServeHTTP(r, req)
   142  	assert.Equal(t, http.StatusUnauthorized, r.Code)
   143  }
   144  
   145  func TestHandleValidSessionCaseOne(t *testing.T) {
   146  	rtr := gin.Default()
   147  	r := httptest.NewRecorder()
   148  	_, ginTestEngine := getTestGinContext(r)
   149  	as, err := NewAuthServer(args, ginTestEngine)
   150  	assert.NoError(t, err)
   151  	mockSessions := session.NewMockSessions()
   152  	rtr.GET("/ping", func(c *gin.Context) {
   153  		as.handleValidSession(c, time.Now(), mockSessions)
   154  	})
   155  	req, _ := http.NewRequest("GET", "/ping", nil)
   156  	rtr.ServeHTTP(r, req)
   157  	assert.Equal(t, http.StatusUnauthorized, r.Code)
   158  }
   159  
   160  func TestHandleValidSessionCaseTwoOrgAdmin(t *testing.T) {
   161  	rtr := gin.Default()
   162  	r := httptest.NewRecorder()
   163  	ctx, ginTestEngine := getTestGinContext(r)
   164  	ctx.Request = &http.Request{
   165  		RequestURI: "/remoteaccess/",
   166  	}
   167  	as, err := NewAuthServer(args, ginTestEngine)
   168  	assert.NoError(t, err)
   169  	mockSessions := session.NewMockSessions()
   170  	mockSessions.Set("username", "test")
   171  	rtr.GET("/ping", func(c *gin.Context) {
   172  		as.handleValidSession(c, time.Now(), mockSessions)
   173  	})
   174  	mockSessions.Set("organization", "test-org")
   175  	mockSessions.Set("roles", []string{"EDGE_ORG_ADMIN"})
   176  	assert.NoError(t, mockSessions.Save())
   177  	req, _ := http.NewRequest("GET", "/ping", nil)
   178  	rtr.ServeHTTP(r, req)
   179  	assert.Equal(t, http.StatusOK, r.Code)
   180  }
   181  
   182  func TestHandleValidSession_EDGE_BANNER_ADMIN(t *testing.T) { //nolint
   183  	rtr := gin.Default()
   184  	r := httptest.NewRecorder()
   185  	ctx, ginTestEngine := getTestGinContext(r)
   186  	ctx.Request = &http.Request{
   187  		RequestURI: "/remoteaccess",
   188  	}
   189  	as, err := NewAuthServer(args, ginTestEngine)
   190  	assert.NoError(t, err)
   191  	mockSessions := session.NewMockSessions()
   192  	rtr.GET("/ping", func(c *gin.Context) {
   193  		as.handleValidSession(c, time.Now(), mockSessions)
   194  	})
   195  	bannerName := uuid.NewString()
   196  	mockSessions.Set("organization", bannerName)
   197  	mockSessions.Set("roles", []string{"EDGE_BANNER_ADMIN"})
   198  	mockSessions.Set("username", "test")
   199  	assert.NoError(t, mockSessions.Save())
   200  	tenantOrgID := uuid.NewString()
   201  	_, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName)
   202  	assert.NoError(t, err)
   203  	tenanantEdgeID := ""
   204  	row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID)
   205  	assert.NoError(t, row.Scan(&tenanantEdgeID))
   206  	bannerEdgeID := uuid.NewString()
   207  	projectID := uuid.NewString()
   208  	_, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner")
   209  	assert.NoError(t, err)
   210  	req, _ := http.NewRequest("GET", "/ping", nil)
   211  	req.Header.Set(bannerHeaderName, bannerEdgeID)
   212  	rtr.ServeHTTP(r, req)
   213  	assert.Equal(t, http.StatusOK, r.Code)
   214  }
   215  
   216  func TestHandleValidSession_EDGE_BANNER_VIEWER(t *testing.T) { //nolint
   217  	rtr := gin.Default()
   218  	r := httptest.NewRecorder()
   219  	_, ginTestEngine := getTestGinContext(r)
   220  	as, err := NewAuthServer(args, ginTestEngine)
   221  	assert.NoError(t, err)
   222  	mockSessions := session.NewMockSessions()
   223  	rtr.GET("/ping", func(c *gin.Context) {
   224  		as.handleValidSession(c, time.Now(), mockSessions)
   225  	})
   226  	bannerName := uuid.NewString()
   227  	mockSessions.Set("organization", bannerName)
   228  	mockSessions.Set("roles", []string{"EDGE_BANNER_VIEWER"})
   229  	mockSessions.Set("username", "test")
   230  	assert.NoError(t, mockSessions.Save())
   231  	tenantOrgID := uuid.NewString()
   232  	projectID := uuid.NewString()
   233  	_, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName)
   234  	assert.NoError(t, err)
   235  	tenanantEdgeID := ""
   236  	row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID)
   237  	assert.NoError(t, row.Scan(&tenanantEdgeID))
   238  	bannerEdgeID := uuid.NewString()
   239  	_, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner")
   240  	assert.NoError(t, err)
   241  	req, _ := http.NewRequest("GET", "/ping", nil)
   242  	req.Header.Set(bannerHeaderName, bannerEdgeID)
   243  	rtr.ServeHTTP(r, req)
   244  	assert.Equal(t, http.StatusOK, r.Code)
   245  }
   246  
   247  func TestHandleValidSession_EDGE_BANNER_VIEWER_grafana(t *testing.T) { //nolint
   248  	rtr := gin.Default()
   249  	r := httptest.NewRecorder()
   250  	ctx, ginTestEngine := getTestGinContext(r)
   251  	ctx.Request = &http.Request{
   252  		RequestURI: "/grafana/",
   253  	}
   254  	as, err := NewAuthServer(args, ginTestEngine)
   255  	assert.NoError(t, err)
   256  	mockSessions := session.NewMockSessions()
   257  	rtr.GET("/grafana/", func(c *gin.Context) {
   258  		as.handleValidSession(c, time.Now(), mockSessions)
   259  	})
   260  	bannerName := uuid.NewString()
   261  	mockSessions.Set("organization", bannerName)
   262  	mockSessions.Set("roles", []string{"EDGE_BANNER_VIEWER"})
   263  	mockSessions.Set("username", "test")
   264  	assert.NoError(t, mockSessions.Save())
   265  	tenantOrgID := uuid.NewString()
   266  	projectID := uuid.NewString()
   267  	_, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName)
   268  	assert.NoError(t, err)
   269  	tenanantEdgeID := ""
   270  	row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID)
   271  	assert.NoError(t, row.Scan(&tenanantEdgeID))
   272  	bannerEdgeID := uuid.NewString()
   273  	_, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner")
   274  	assert.NoError(t, err)
   275  	req, _ := http.NewRequest("GET", "/grafana/", nil)
   276  	req.Header.Set(bannerHeaderName, bannerEdgeID)
   277  	rtr.ServeHTTP(r, req)
   278  	assert.Equal(t, http.StatusOK, r.Code)
   279  }
   280  
   281  func TestHandleValidSession_VNC(t *testing.T) {
   282  	tests := map[string]struct {
   283  		url                  string
   284  		role                 string
   285  		requestHeaderPresent bool
   286  	}{
   287  		"/novnc/": {
   288  			url:                  "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/",
   289  			role:                 "EDGE_BANNER_ADMIN",
   290  			requestHeaderPresent: false,
   291  		},
   292  		"/novnc/ws": {
   293  			url:                  "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/write/ws",
   294  			role:                 "EDGE_BANNER_ADMIN",
   295  			requestHeaderPresent: false,
   296  		},
   297  		"/notnovnc/": {
   298  			url:                  "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/notnovnc/",
   299  			role:                 "EDGE_BANNER_ADMIN",
   300  			requestHeaderPresent: false,
   301  		},
   302  		"/novnc/authorize with EDGE_BANNER_ADMIN role": {
   303  			url:                  "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/read/authorize",
   304  			role:                 "EDGE_BANNER_ADMIN",
   305  			requestHeaderPresent: true,
   306  		},
   307  		"/novnc/authorize with EDGE_ORG_ADMIN role": {
   308  			url:                  "/remoteaccess/56490f8a-f692-44f3-bdc7-2a880e135d2d/novnc/write/authorize",
   309  			role:                 "EDGE_ORG_ADMIN",
   310  			requestHeaderPresent: true,
   311  		},
   312  	}
   313  
   314  	for name, tc := range tests {
   315  		tc := tc
   316  		t.Run(name, func(t *testing.T) {
   317  			// Setup
   318  			rtr := gin.Default()
   319  			r := httptest.NewRecorder()
   320  			ctx, ginTestEngine := getTestGinContext(r)
   321  			ctx.Request = &http.Request{
   322  				RequestURI: tc.url,
   323  			}
   324  			as, err := NewAuthServer(args, ginTestEngine)
   325  			assert.NoError(t, err)
   326  			mockSessions := session.NewMockSessions()
   327  			rtr.GET(strings.Split(tc.url, "?")[0], func(c *gin.Context) {
   328  				as.handleValidSession(c, time.Now(), mockSessions)
   329  			})
   330  			bannerName := uuid.NewString()
   331  			mockSessions.Set("organization", bannerName)
   332  			mockSessions.Set("roles", []string{tc.role})
   333  			username := "test"
   334  			mockSessions.Set("username", username)
   335  			assert.NoError(t, mockSessions.Save())
   336  			tenantOrgID := uuid.NewString()
   337  			projectID := uuid.NewString()
   338  			_, err = as.db.Exec(sql.TenantInsertQuery, tenantOrgID, bannerName)
   339  			assert.NoError(t, err)
   340  			tenanantEdgeID := ""
   341  			row := as.db.QueryRow("SELECT tenant_edge_id FROM tenants WHERE org_id = $1", tenantOrgID)
   342  			assert.NoError(t, row.Scan(&tenanantEdgeID))
   343  			bannerEdgeID := uuid.NewString()
   344  			_, err = as.db.Exec(sql.BannerInsertQuery, "test-banner-bsl-id", bannerName, "org", projectID, tenanantEdgeID, bannerEdgeID, "test banner")
   345  			assert.NoError(t, err)
   346  
   347  			// Test
   348  			req, err := http.NewRequest("GET", tc.url, nil)
   349  			assert.NoError(t, err)
   350  			req.Header.Set(bannerHeaderName, bannerEdgeID)
   351  			rtr.ServeHTTP(r, req)
   352  			assert.Equal(t, http.StatusOK, r.Code)
   353  
   354  			resultHeader := r.Result().Header
   355  			assert.Equal(t, username, resultHeader.Get("X-WEBAUTH-USER"))
   356  
   357  			requestID := resultHeader.Get(vncconst.HeaderKeyRequestID)
   358  			if tc.requestHeaderPresent {
   359  				_, err = uuid.Parse(requestID)
   360  				assert.NoError(t, err)
   361  			} else {
   362  				assert.Empty(t, requestID)
   363  			}
   364  		})
   365  	}
   366  }
   367  
   368  func getTestGinContext(r *httptest.ResponseRecorder) (*gin.Context, *gin.Engine) {
   369  	gin.SetMode(gin.TestMode)
   370  	ctx, ginEngine := gin.CreateTestContext(r)
   371  	return ctx, ginEngine
   372  }
   373  
   374  func TestCheckFilterFunc(t *testing.T) {
   375  	tests := map[string]struct {
   376  		checks     []check
   377  		statusCode int
   378  	}{
   379  		"No error": {
   380  			checks: []check{
   381  				{
   382  					pathFilter: "/request",
   383  					checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
   384  						return nil
   385  					},
   386  				},
   387  			},
   388  			statusCode: http.StatusOK,
   389  		},
   390  		"Error": {
   391  			checks: []check{
   392  				{
   393  					pathFilter: "/request",
   394  					checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
   395  						return fmt.Errorf("error")
   396  					},
   397  				},
   398  			},
   399  			statusCode: 500,
   400  		},
   401  		"Unauthorized": {
   402  			checks: []check{
   403  				{
   404  					pathFilter: "/request",
   405  					checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
   406  						return &httpError{
   407  							statusCode: http.StatusUnauthorized,
   408  							err:        fmt.Errorf("error"),
   409  						}
   410  					},
   411  				},
   412  			},
   413  			statusCode: 401,
   414  		},
   415  		"Incorrect Path": {
   416  			checks: []check{
   417  				{
   418  					pathFilter: "/otherpath",
   419  					checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
   420  						return fmt.Errorf("error")
   421  					},
   422  				},
   423  			},
   424  			statusCode: http.StatusOK,
   425  		},
   426  		"Partial Path Match": {
   427  			checks: []check{
   428  				{
   429  					pathFilter: "/req",
   430  					checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
   431  						return fmt.Errorf("error")
   432  					},
   433  				},
   434  			},
   435  			statusCode: http.StatusInternalServerError,
   436  		},
   437  		"Match All Paths": {
   438  			checks: []check{
   439  				{
   440  					pathFilter: "",
   441  					checkFunc: func(*AuthServer, *gin.Context, sessions.Session) error {
   442  						return fmt.Errorf("error")
   443  					},
   444  				},
   445  			},
   446  			statusCode: http.StatusInternalServerError,
   447  		},
   448  	}
   449  
   450  	for name, tc := range tests {
   451  		tc := tc
   452  		t.Run(name, func(t *testing.T) {
   453  			// Setup
   454  			rtr := gin.Default()
   455  			r := httptest.NewRecorder()
   456  			_, ginTestEngine := getTestGinContext(r)
   457  
   458  			db, err := seededPostgres.DB()
   459  			assert.NoError(t, err)
   460  
   461  			as := AuthServer{
   462  				GinMode:   gin.TestMode,
   463  				GinEngine: ginTestEngine,
   464  				db:        db,
   465  				Log:       fog.New(),
   466  				checks:    tc.checks,
   467  			}
   468  
   469  			mockSessions := session.NewMockSessions()
   470  			bannerName := uuid.NewString()
   471  			mockSessions.Set("organization", bannerName)
   472  			mockSessions.Set("roles", []string{"EDGE_ORG_ADMIN"})
   473  			username := "test"
   474  			mockSessions.Set("username", username)
   475  			assert.NoError(t, mockSessions.Save())
   476  
   477  			rtr.GET("/request", func(c *gin.Context) {
   478  				as.handleValidSession(c, time.Now(), mockSessions)
   479  			})
   480  
   481  			bannerEdgeID := uuid.NewString()
   482  
   483  			// Test
   484  			req, _ := http.NewRequest("GET", "/request", nil)
   485  			req.Header.Set(bannerHeaderName, bannerEdgeID)
   486  			rtr.ServeHTTP(r, req)
   487  			assert.Equal(t, tc.statusCode, r.Code)
   488  		})
   489  	}
   490  }
   491  

View as plain text