...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/authservice/integration/auth_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/authservice/integration

     1  package integration
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"os"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"edge-infra.dev/pkg/lib/fog"
    16  	"edge-infra.dev/pkg/sds/emergencyaccess/authservice"
    17  	"edge-infra.dev/pkg/sds/emergencyaccess/authservice/server"
    18  	"edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database"
    19  	datasql "edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database/sql"
    20  	"edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
    21  	"edge-infra.dev/test/f2"
    22  	"edge-infra.dev/test/f2/x/postgres"
    23  
    24  	"github.com/gin-gonic/gin"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
// contextVal is a string-based named type. It is not referenced anywhere in
// this file; NOTE(review): presumably meant as a collision-free key type for
// context values elsewhere in the package. Confirm it is used, or remove it.
type contextVal string

// f is the shared f2 test framework for this package. TestMain builds it
// (postgres extension plus a CreateTables hook before each test), and every
// Test function in this file runs its feature through f.Test.
var f f2.Framework
    32  
// StringAssertionFunc can be used to make any kind of assertion against a
// string. It returns whether the assertion held. The implementations in this
// file (JSONEq, StringEq) delegate to testify, which reports failures on t.
//
// TODO: Given these are integration tests we should be using the provided
// ctx.Assert rather than passing in testing.T manually.
type StringAssertionFunc func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool
    39  
    40  func JSONEq(expected string) StringAssertionFunc {
    41  	return func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool {
    42  		return assert.JSONEq(t, expected, actual, msgAndArgs...)
    43  	}
    44  }
    45  
    46  func StringEq(expected string) StringAssertionFunc {
    47  	return func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool {
    48  		return assert.Equal(t, expected, actual, msgAndArgs...)
    49  	}
    50  }
    51  
    52  func TestMain(m *testing.M) {
    53  	f = f2.New(
    54  		context.Background(),
    55  		f2.WithExtensions(
    56  			postgres.New(),
    57  		),
    58  	).
    59  		BeforeEachTest(
    60  			func(ctx f2.Context, t *testing.T) (f2.Context, error) {
    61  				return CreateTables(ctx, t)
    62  			},
    63  		)
    64  
    65  	os.Exit(f.Run(m))
    66  }
    67  
    68  func setupAuthservice(t *testing.T, db *sql.DB) *gin.Engine {
    69  	gin.SetMode(gin.TestMode)
    70  	router := gin.New()
    71  
    72  	log := fog.New()
    73  
    74  	config := authservice.Config{}
    75  	ds := database.New(log, db)
    76  	as, err := authservice.New(config, ds, nil)
    77  	server := server.New(router, log, as)
    78  	require.NoError(t, err)
    79  
    80  	return server.GinEngine
    81  }
    82  
    83  // helper function which sets well known auth headers to any request
    84  func setAuthHeaders(req *http.Request) {
    85  	req.Header.Set(eaconst.HeaderAuthKeyUsername, "username")
    86  	req.Header.Set(eaconst.HeaderAuthKeyEmail, "email")
    87  	req.Header.Set(eaconst.HeaderAuthKeyRoles, "role")
    88  	req.Header.Set(eaconst.HeaderAuthKeyBanners, "banner")
    89  }
    90  
    91  func newAuthRequestWithContext(ctx context.Context, method string, url string, body io.Reader) (*http.Request, error) {
    92  	req, err := http.NewRequestWithContext(ctx, method, url, body)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	setAuthHeaders(req)
    97  	return req, nil
    98  }
    99  
// testCase describes one HTTP request sent by testEndpoint and the response
// it is expected to produce.
type testCase struct {
	url            string              // request target, e.g. "/resolveTarget"
	method         string              // HTTP method, e.g. http.MethodPost
	body           io.Reader           // request body
	expectedStatus int                 // expected HTTP status code
	expectedOut    StringAssertionFunc // assertion applied to the response body
}
   107  
   108  func testEndpoint(ctx f2.Context, t *testing.T, rulesEngine *gin.Engine, test testCase) f2.Context {
   109  	t.Helper()
   110  
   111  	r := httptest.NewRecorder()
   112  	c, cancel := context.WithTimeout(ctx, 10*time.Second)
   113  	defer cancel()
   114  	req, err := newAuthRequestWithContext(c, test.method, test.url, test.body)
   115  	assert.NoError(t, err)
   116  
   117  	rulesEngine.ServeHTTP(r, req)
   118  
   119  	resp := r.Result()
   120  
   121  	assert.Equal(t, test.expectedStatus, resp.StatusCode)
   122  	test.expectedOut(t, r.Body.String())
   123  
   124  	return ctx
   125  }
   126  
   127  func TestQueries(t *testing.T) {
   128  	var (
   129  		authService *gin.Engine
   130  	)
   131  
   132  	feat := f2.NewFeature("Resolve Target").
   133  		Setup("Create Auth Service server", func(ctx f2.Context, t *testing.T) f2.Context {
   134  			db := postgres.FromContextT(ctx, t).DB()
   135  			authService = setupAuthservice(t, db)
   136  			_ = authService
   137  			return ctx
   138  		}).
   139  		Setup("Add some data", func(ctx f2.Context, t *testing.T) f2.Context {
   140  			var (
   141  				db = postgres.FromContextT(ctx, t).DB()
   142  			)
   143  
   144  			err := PopulateTables(ctx, db)
   145  			require.NoError(t, err)
   146  
   147  			return ctx
   148  		}).
   149  		Test("SelectProjectIDAndBannerID query", func(ctx f2.Context, t *testing.T) f2.Context {
   150  			var (
   151  				db = postgres.FromContextT(ctx, t).DB()
   152  			)
   153  
   154  			expectedProject := UUIDs["projects"][0]
   155  			expectedBanner := UUIDs["banners"][0]
   156  			rows, err := db.QueryContext(ctx, datasql.SelectProjectIDAndBannerID, expectedBanner, expectedBanner)
   157  			require.NoError(t, err)
   158  
   159  			var projectID, bannerID string
   160  			require.True(t, rows.Next())
   161  			err = rows.Scan(&projectID, &bannerID)
   162  			require.NoError(t, err)
   163  
   164  			require.Equal(t, expectedProject, projectID)
   165  			require.Equal(t, expectedBanner, bannerID)
   166  
   167  			return ctx
   168  		}).
   169  		Test("SelectStoreID query", func(ctx f2.Context, t *testing.T) f2.Context {
   170  			var (
   171  				db = postgres.FromContextT(ctx, t).DB()
   172  			)
   173  
   174  			expected := UUIDs["clusters"][0]
   175  			bannerID := UUIDs["banners"][0]
   176  			rows, err := db.QueryContext(ctx, datasql.SelectStoreID, expected, expected, bannerID)
   177  			require.NoError(t, err)
   178  
   179  			var storeID string
   180  			require.True(t, rows.Next())
   181  			err = rows.Scan(&storeID)
   182  			require.NoError(t, err)
   183  
   184  			require.Equal(t, expected, storeID)
   185  
   186  			return ctx
   187  		}).
   188  		Test("SelectTerminalID query", func(ctx f2.Context, t *testing.T) f2.Context {
   189  			var (
   190  				db = postgres.FromContextT(ctx, t).DB()
   191  			)
   192  
   193  			expected := UUIDs["terminals"][0]
   194  			storeID := UUIDs["clusters"][0]
   195  			rows, err := db.QueryContext(ctx, datasql.SelectTerminalID, expected, expected, storeID)
   196  			require.NoError(t, err)
   197  
   198  			var terminalID string
   199  			require.True(t, rows.Next())
   200  			err = rows.Scan(&terminalID)
   201  			require.NoError(t, err)
   202  
   203  			require.Equal(t, expected, terminalID)
   204  
   205  			return ctx
   206  		}).
   207  		Feature()
   208  
   209  	f.Test(t, feat)
   210  }
   211  
   212  func TestResolveTarget(t *testing.T) {
   213  	var (
   214  		authService *gin.Engine
   215  	)
   216  
   217  	feat := f2.NewFeature("Resolve Target").
   218  		Setup("Create Auth Service server", func(ctx f2.Context, t *testing.T) f2.Context {
   219  			var db = postgres.FromContextT(ctx, t).DB()
   220  			authService = setupAuthservice(t, db)
   221  			_ = authService
   222  			return ctx
   223  		}).
   224  		Setup("Add some data", func(ctx f2.Context, t *testing.T) f2.Context {
   225  			var (
   226  				db = postgres.FromContextT(ctx, t).DB()
   227  			)
   228  
   229  			err := PopulateTables(ctx, db)
   230  			require.NoError(t, err)
   231  
   232  			return ctx
   233  		}).
   234  		Test("Valid Target (IDs)", func(ctx f2.Context, t *testing.T) f2.Context {
   235  			projectID := UUIDs["projects"][0]
   236  			bannerID := UUIDs["banners"][0]
   237  			storeID := UUIDs["clusters"][0]
   238  			terminalID := UUIDs["terminals"][0]
   239  
   240  			payload := fmt.Sprintf(`{
   241  				"target": {
   242  					"bannerid": "%s",
   243  					"storeid": "%s",
   244  					"terminalid": "%s"
   245  				}
   246  			}`, bannerID, storeID, terminalID)
   247  			expected := fmt.Sprintf(`{
   248  				"target": {
   249  					"projectid": "%s",
   250  					"bannerid": "%s",
   251  					"storeid": "%s",
   252  					"terminalid": "%s"
   253  				}
   254  			}`, projectID, bannerID, storeID, terminalID)
   255  
   256  			test := testCase{
   257  				url:            "/resolveTarget",
   258  				method:         http.MethodPost,
   259  				body:           strings.NewReader(payload),
   260  				expectedStatus: http.StatusOK,
   261  				expectedOut:    JSONEq(expected),
   262  			}
   263  			ctx = testEndpoint(ctx, t, authService, test)
   264  			return ctx
   265  		}).
   266  		Test("Valid Target (Names)", func(ctx f2.Context, t *testing.T) f2.Context {
   267  			projectID := UUIDs["projects"][0]
   268  			bannerID, bannerName := UUIDs["banners"][0], Names["banners"][0]
   269  			storeID, storeName := UUIDs["clusters"][0], Names["clusters"][0]
   270  			terminalID, terminalName := UUIDs["terminals"][0], Names["terminals"][0]
   271  
   272  			payload := fmt.Sprintf(`{
   273  				"target": {
   274  					"bannerid": "%s",
   275  					"storeid": "%s",
   276  					"terminalid": "%s"
   277  				}
   278  			}`, bannerName, storeName, terminalName)
   279  			expected := fmt.Sprintf(`{
   280  				"target": {
   281  					"projectid": "%s",
   282  					"bannerid": "%s",
   283  					"storeid": "%s",
   284  					"terminalid": "%s"
   285  				}
   286  			}`, projectID, bannerID, storeID, terminalID)
   287  
   288  			test := testCase{
   289  				url:            "/resolveTarget",
   290  				method:         http.MethodPost,
   291  				body:           strings.NewReader(payload),
   292  				expectedStatus: http.StatusOK,
   293  				expectedOut:    JSONEq(expected),
   294  			}
   295  
   296  			ctx = testEndpoint(ctx, t, authService, test)
   297  			return ctx
   298  		}).
   299  		Test("Valid Target (Mixed)", func(ctx f2.Context, t *testing.T) f2.Context {
   300  			projectID := UUIDs["projects"][0]
   301  			bannerID, bannerName := UUIDs["banners"][0], Names["banners"][0]
   302  			storeID := UUIDs["clusters"][0]
   303  			terminalID, terminalName := UUIDs["terminals"][0], Names["terminals"][0]
   304  
   305  			payload := fmt.Sprintf(`{
   306  				"target": {
   307  					"bannerid": "%s",
   308  					"storeid": "%s",
   309  					"terminalid": "%s"
   310  				}
   311  			}`, bannerName, storeID, terminalName)
   312  			expected := fmt.Sprintf(`{
   313  				"target": {
   314  					"projectid": "%s",
   315  					"bannerid": "%s",
   316  					"storeid": "%s",
   317  					"terminalid": "%s"
   318  				}
   319  			}`, projectID, bannerID, storeID, terminalID)
   320  
   321  			test := testCase{
   322  				url:            "/resolveTarget",
   323  				method:         http.MethodPost,
   324  				body:           strings.NewReader(payload),
   325  				expectedStatus: http.StatusOK,
   326  				expectedOut:    JSONEq(expected),
   327  			}
   328  			ctx = testEndpoint(ctx, t, authService, test)
   329  			return ctx
   330  		}).
   331  		Test("ProjectID In Payload", func(ctx f2.Context, t *testing.T) f2.Context {
   332  			projectID := UUIDs["projects"][0]
   333  			bannerID := UUIDs["banners"][0]
   334  			storeID := UUIDs["clusters"][0]
   335  			terminalID := UUIDs["terminals"][0]
   336  
   337  			payload := fmt.Sprintf(`{
   338  				"target": {
   339  					"projectid": "irrelevant-project-details",
   340  					"bannerid": "%s",
   341  					"storeid": "%s",
   342  					"terminalid": "%s"
   343  				}
   344  			}`, bannerID, storeID, terminalID)
   345  			expected := fmt.Sprintf(`{
   346  				"target": {
   347  					"projectid": "%s",
   348  					"bannerid": "%s",
   349  					"storeid": "%s",
   350  					"terminalid": "%s"
   351  				}
   352  			}`, projectID, bannerID, storeID, terminalID)
   353  
   354  			test := testCase{
   355  				url:            "/resolveTarget",
   356  				method:         http.MethodPost,
   357  				body:           strings.NewReader(payload),
   358  				expectedStatus: http.StatusOK,
   359  				expectedOut:    JSONEq(expected),
   360  			}
   361  			ctx = testEndpoint(ctx, t, authService, test)
   362  			return ctx
   363  		}).
   364  		Test("Invalid Banner", func(ctx f2.Context, t *testing.T) f2.Context {
   365  			storeName := Names["clusters"][0]
   366  			terminalName := Names["terminals"][0]
   367  
   368  			payload := fmt.Sprintf(`{
   369  				"target": {
   370  					"bannerid": "%s",
   371  					"storeid": "%s",
   372  					"terminalid": "%s"
   373  				}
   374  			}`, "an-invalid-banner", storeName, terminalName)
   375  			expected := `{
   376  				"details": [
   377  					"Banner an-invalid-banner not found"
   378  				],
   379  				"errorCode": 61202,
   380  				"errorMessage": "Request Error - Invalid Target properties" 
   381  			}`
   382  
   383  			test := testCase{
   384  				url:            "/resolveTarget",
   385  				method:         http.MethodPost,
   386  				body:           strings.NewReader(payload),
   387  				expectedStatus: http.StatusBadRequest,
   388  				expectedOut:    JSONEq(expected),
   389  			}
   390  			ctx = testEndpoint(ctx, t, authService, test)
   391  			return ctx
   392  		}).
   393  		Test("Invalid Store", func(ctx f2.Context, t *testing.T) f2.Context {
   394  			bannerName := Names["banners"][0]
   395  			terminalName := Names["terminals"][0]
   396  
   397  			payload := fmt.Sprintf(`{
   398  				"target": {
   399  					"bannerid": "%s",
   400  					"storeid": "%s",
   401  					"terminalid": "%s"
   402  				}
   403  			}`, bannerName, "an-invalid-store", terminalName)
   404  			expected := fmt.Sprintf(`{
   405  				"details": [
   406  					"Store an-invalid-store not found in given banner %s"
   407  				],
   408  				"errorCode": 61202,
   409  				"errorMessage": "Request Error - Invalid Target properties" 
   410  			}`, bannerName)
   411  
   412  			test := testCase{
   413  				url:            "/resolveTarget",
   414  				method:         http.MethodPost,
   415  				body:           strings.NewReader(payload),
   416  				expectedStatus: http.StatusBadRequest,
   417  				expectedOut:    JSONEq(expected),
   418  			}
   419  			ctx = testEndpoint(ctx, t, authService, test)
   420  			return ctx
   421  		}).
   422  		Test("Invalid Terminal", func(ctx f2.Context, t *testing.T) f2.Context {
   423  			bannerName := Names["banners"][0]
   424  			storeName := Names["clusters"][0]
   425  
   426  			payload := fmt.Sprintf(`{
   427  				"target": {
   428  					"bannerid": "%s",
   429  					"storeid": "%s",
   430  					"terminalid": "%s"
   431  				}
   432  			}`, bannerName, storeName, "an-invalid-terminal")
   433  			expected := fmt.Sprintf(`{
   434  				"details": [
   435  					"Terminal an-invalid-terminal not found in given store %s and banner %s"
   436  				],
   437  				"errorCode": 61202,
   438  				"errorMessage": "Request Error - Invalid Target properties" 
   439  			}`, storeName, bannerName)
   440  
   441  			test := testCase{
   442  				url:            "/resolveTarget",
   443  				method:         http.MethodPost,
   444  				body:           strings.NewReader(payload),
   445  				expectedStatus: http.StatusBadRequest,
   446  				expectedOut:    JSONEq(expected),
   447  			}
   448  			ctx = testEndpoint(ctx, t, authService, test)
   449  			return ctx
   450  		}).
   451  		Feature()
   452  
   453  	f.Test(t, feat)
   454  }
   455  
   456  func TestAuthDetailsVerification(t *testing.T) {
   457  	var (
   458  		authService  *gin.Engine
   459  		validPayload string
   460  	)
   461  
   462  	feat := f2.NewFeature("Resolve Target").
   463  		Setup("Create Auth Service server", func(ctx f2.Context, t *testing.T) f2.Context {
   464  			var db = postgres.FromContextT(ctx, t).DB()
   465  			authService = setupAuthservice(t, db)
   466  			_ = authService
   467  			return ctx
   468  		}).
   469  		Setup("Add some data", func(ctx f2.Context, t *testing.T) f2.Context {
   470  			db := postgres.FromContextT(ctx, t).DB()
   471  			err := PopulateTables(ctx, db)
   472  			require.NoError(t, err)
   473  			return ctx
   474  		}).
   475  		Setup("Set valid payload", func(ctx f2.Context, _ *testing.T) f2.Context {
   476  			bannerID := UUIDs["banners"][0]
   477  			storeID := UUIDs["clusters"][0]
   478  			terminalID := UUIDs["terminals"][0]
   479  			validPayload = fmt.Sprintf(`{
   480  				"target": {
   481  					"bannerid": "%s",
   482  					"storeid": "%s",
   483  					"terminalid": "%s"
   484  				}
   485  			}`, bannerID, storeID, terminalID)
   486  			return ctx
   487  		}).
   488  		Test("Without Auth Details", func(ctx f2.Context, t *testing.T) f2.Context {
   489  			r := httptest.NewRecorder()
   490  			c, cancel := context.WithTimeout(ctx, 10*time.Second)
   491  			defer cancel()
   492  
   493  			req, err := http.NewRequestWithContext(c, http.MethodPost, "/resolveTarget", strings.NewReader(validPayload))
   494  			assert.NoError(t, err)
   495  			authService.ServeHTTP(r, req)
   496  
   497  			resp := r.Result()
   498  			assert.Equal(t, http.StatusForbidden, resp.StatusCode)
   499  			return ctx
   500  		}).
   501  		Test("With Auth Details", func(ctx f2.Context, t *testing.T) f2.Context {
   502  			r := httptest.NewRecorder()
   503  			c, cancel := context.WithTimeout(ctx, 10*time.Second)
   504  			defer cancel()
   505  
   506  			req, err := http.NewRequestWithContext(c, http.MethodPost, "/resolveTarget", strings.NewReader(validPayload))
   507  			assert.NoError(t, err)
   508  			setAuthHeaders(req)
   509  			authService.ServeHTTP(r, req)
   510  
   511  			resp := r.Result()
   512  			assert.Equal(t, http.StatusOK, resp.StatusCode)
   513  			return ctx
   514  		}).
   515  		Feature()
   516  
   517  	f.Test(t, feat)
   518  }
   519  

View as plain text