package integration import ( "context" "database/sql" "fmt" "io" "net/http" "net/http/httptest" "os" "strings" "testing" "time" "edge-infra.dev/pkg/lib/fog" "edge-infra.dev/pkg/sds/emergencyaccess/authservice" "edge-infra.dev/pkg/sds/emergencyaccess/authservice/server" "edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database" datasql "edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database/sql" "edge-infra.dev/pkg/sds/emergencyaccess/eaconst" "edge-infra.dev/test/f2" "edge-infra.dev/test/f2/x/postgres" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type contextVal string var f f2.Framework // StringAssertionFunc can be used to make any kind of assertion against a string // // TODO: Given these are integration tests we should be using the provided ctx.Assert // // rather than passing in testing.T manually type StringAssertionFunc func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool func JSONEq(expected string) StringAssertionFunc { return func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool { return assert.JSONEq(t, expected, actual, msgAndArgs...) } } func StringEq(expected string) StringAssertionFunc { return func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool { return assert.Equal(t, expected, actual, msgAndArgs...) } } func TestMain(m *testing.M) { f = f2.New( context.Background(), f2.WithExtensions( postgres.New(), ), ). BeforeEachTest( func(ctx f2.Context, t *testing.T) (f2.Context, error) { return CreateTables(ctx, t) }, ) os.Exit(f.Run(m)) } func setupAuthservice(t *testing.T, db *sql.DB) *gin.Engine { gin.SetMode(gin.TestMode) router := gin.New() log := fog.New() config := authservice.Config{} ds := database.New(log, db) as, err := authservice.New(config, ds, nil) server := server.New(router, log, as) require.NoError(t, err) return server.GinEngine } // helper function which sets well known auth headers to any request func setAuthHeaders(req *http.Request) { req.Header.Set(eaconst.HeaderAuthKeyUsername, "username") req.Header.Set(eaconst.HeaderAuthKeyEmail, "email") req.Header.Set(eaconst.HeaderAuthKeyRoles, "role") req.Header.Set(eaconst.HeaderAuthKeyBanners, "banner") } func newAuthRequestWithContext(ctx context.Context, method string, url string, body io.Reader) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, err } setAuthHeaders(req) return req, nil } type testCase struct { url string method string body io.Reader expectedStatus int expectedOut StringAssertionFunc } func testEndpoint(ctx f2.Context, t *testing.T, rulesEngine *gin.Engine, test testCase) f2.Context { t.Helper() r := httptest.NewRecorder() c, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() req, err := newAuthRequestWithContext(c, test.method, test.url, test.body) assert.NoError(t, err) rulesEngine.ServeHTTP(r, req) resp := r.Result() assert.Equal(t, test.expectedStatus, resp.StatusCode) test.expectedOut(t, r.Body.String()) return ctx } func TestQueries(t *testing.T) { var ( authService *gin.Engine ) feat := f2.NewFeature("Resolve Target"). Setup("Create Auth Service server", func(ctx f2.Context, t *testing.T) f2.Context { db := postgres.FromContextT(ctx, t).DB() authService = setupAuthservice(t, db) _ = authService return ctx }). Setup("Add some data", func(ctx f2.Context, t *testing.T) f2.Context { var ( db = postgres.FromContextT(ctx, t).DB() ) err := PopulateTables(ctx, db) require.NoError(t, err) return ctx }). Test("SelectProjectIDAndBannerID query", func(ctx f2.Context, t *testing.T) f2.Context { var ( db = postgres.FromContextT(ctx, t).DB() ) expectedProject := UUIDs["projects"][0] expectedBanner := UUIDs["banners"][0] rows, err := db.QueryContext(ctx, datasql.SelectProjectIDAndBannerID, expectedBanner, expectedBanner) require.NoError(t, err) var projectID, bannerID string require.True(t, rows.Next()) err = rows.Scan(&projectID, &bannerID) require.NoError(t, err) require.Equal(t, expectedProject, projectID) require.Equal(t, expectedBanner, bannerID) return ctx }). Test("SelectStoreID query", func(ctx f2.Context, t *testing.T) f2.Context { var ( db = postgres.FromContextT(ctx, t).DB() ) expected := UUIDs["clusters"][0] bannerID := UUIDs["banners"][0] rows, err := db.QueryContext(ctx, datasql.SelectStoreID, expected, expected, bannerID) require.NoError(t, err) var storeID string require.True(t, rows.Next()) err = rows.Scan(&storeID) require.NoError(t, err) require.Equal(t, expected, storeID) return ctx }). Test("SelectTerminalID query", func(ctx f2.Context, t *testing.T) f2.Context { var ( db = postgres.FromContextT(ctx, t).DB() ) expected := UUIDs["terminals"][0] storeID := UUIDs["clusters"][0] rows, err := db.QueryContext(ctx, datasql.SelectTerminalID, expected, expected, storeID) require.NoError(t, err) var terminalID string require.True(t, rows.Next()) err = rows.Scan(&terminalID) require.NoError(t, err) require.Equal(t, expected, terminalID) return ctx }). Feature() f.Test(t, feat) } func TestResolveTarget(t *testing.T) { var ( authService *gin.Engine ) feat := f2.NewFeature("Resolve Target"). Setup("Create Auth Service server", func(ctx f2.Context, t *testing.T) f2.Context { var db = postgres.FromContextT(ctx, t).DB() authService = setupAuthservice(t, db) _ = authService return ctx }). Setup("Add some data", func(ctx f2.Context, t *testing.T) f2.Context { var ( db = postgres.FromContextT(ctx, t).DB() ) err := PopulateTables(ctx, db) require.NoError(t, err) return ctx }). Test("Valid Target (IDs)", func(ctx f2.Context, t *testing.T) f2.Context { projectID := UUIDs["projects"][0] bannerID := UUIDs["banners"][0] storeID := UUIDs["clusters"][0] terminalID := UUIDs["terminals"][0] payload := fmt.Sprintf(`{ "target": { "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, bannerID, storeID, terminalID) expected := fmt.Sprintf(`{ "target": { "projectid": "%s", "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, projectID, bannerID, storeID, terminalID) test := testCase{ url: "/resolveTarget", method: http.MethodPost, body: strings.NewReader(payload), expectedStatus: http.StatusOK, expectedOut: JSONEq(expected), } ctx = testEndpoint(ctx, t, authService, test) return ctx }). Test("Valid Target (Names)", func(ctx f2.Context, t *testing.T) f2.Context { projectID := UUIDs["projects"][0] bannerID, bannerName := UUIDs["banners"][0], Names["banners"][0] storeID, storeName := UUIDs["clusters"][0], Names["clusters"][0] terminalID, terminalName := UUIDs["terminals"][0], Names["terminals"][0] payload := fmt.Sprintf(`{ "target": { "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, bannerName, storeName, terminalName) expected := fmt.Sprintf(`{ "target": { "projectid": "%s", "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, projectID, bannerID, storeID, terminalID) test := testCase{ url: "/resolveTarget", method: http.MethodPost, body: strings.NewReader(payload), expectedStatus: http.StatusOK, expectedOut: JSONEq(expected), } ctx = testEndpoint(ctx, t, authService, test) return ctx }). Test("Valid Target (Mixed)", func(ctx f2.Context, t *testing.T) f2.Context { projectID := UUIDs["projects"][0] bannerID, bannerName := UUIDs["banners"][0], Names["banners"][0] storeID := UUIDs["clusters"][0] terminalID, terminalName := UUIDs["terminals"][0], Names["terminals"][0] payload := fmt.Sprintf(`{ "target": { "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, bannerName, storeID, terminalName) expected := fmt.Sprintf(`{ "target": { "projectid": "%s", "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, projectID, bannerID, storeID, terminalID) test := testCase{ url: "/resolveTarget", method: http.MethodPost, body: strings.NewReader(payload), expectedStatus: http.StatusOK, expectedOut: JSONEq(expected), } ctx = testEndpoint(ctx, t, authService, test) return ctx }). Test("ProjectID In Payload", func(ctx f2.Context, t *testing.T) f2.Context { projectID := UUIDs["projects"][0] bannerID := UUIDs["banners"][0] storeID := UUIDs["clusters"][0] terminalID := UUIDs["terminals"][0] payload := fmt.Sprintf(`{ "target": { "projectid": "irrelevant-project-details", "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, bannerID, storeID, terminalID) expected := fmt.Sprintf(`{ "target": { "projectid": "%s", "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, projectID, bannerID, storeID, terminalID) test := testCase{ url: "/resolveTarget", method: http.MethodPost, body: strings.NewReader(payload), expectedStatus: http.StatusOK, expectedOut: JSONEq(expected), } ctx = testEndpoint(ctx, t, authService, test) return ctx }). Test("Invalid Banner", func(ctx f2.Context, t *testing.T) f2.Context { storeName := Names["clusters"][0] terminalName := Names["terminals"][0] payload := fmt.Sprintf(`{ "target": { "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, "an-invalid-banner", storeName, terminalName) expected := `{ "details": [ "Banner an-invalid-banner not found" ], "errorCode": 61202, "errorMessage": "Request Error - Invalid Target properties" }` test := testCase{ url: "/resolveTarget", method: http.MethodPost, body: strings.NewReader(payload), expectedStatus: http.StatusBadRequest, expectedOut: JSONEq(expected), } ctx = testEndpoint(ctx, t, authService, test) return ctx }). Test("Invalid Store", func(ctx f2.Context, t *testing.T) f2.Context { bannerName := Names["banners"][0] terminalName := Names["terminals"][0] payload := fmt.Sprintf(`{ "target": { "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, bannerName, "an-invalid-store", terminalName) expected := fmt.Sprintf(`{ "details": [ "Store an-invalid-store not found in given banner %s" ], "errorCode": 61202, "errorMessage": "Request Error - Invalid Target properties" }`, bannerName) test := testCase{ url: "/resolveTarget", method: http.MethodPost, body: strings.NewReader(payload), expectedStatus: http.StatusBadRequest, expectedOut: JSONEq(expected), } ctx = testEndpoint(ctx, t, authService, test) return ctx }). Test("Invalid Terminal", func(ctx f2.Context, t *testing.T) f2.Context { bannerName := Names["banners"][0] storeName := Names["clusters"][0] payload := fmt.Sprintf(`{ "target": { "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, bannerName, storeName, "an-invalid-terminal") expected := fmt.Sprintf(`{ "details": [ "Terminal an-invalid-terminal not found in given store %s and banner %s" ], "errorCode": 61202, "errorMessage": "Request Error - Invalid Target properties" }`, storeName, bannerName) test := testCase{ url: "/resolveTarget", method: http.MethodPost, body: strings.NewReader(payload), expectedStatus: http.StatusBadRequest, expectedOut: JSONEq(expected), } ctx = testEndpoint(ctx, t, authService, test) return ctx }). Feature() f.Test(t, feat) } func TestAuthDetailsVerification(t *testing.T) { var ( authService *gin.Engine validPayload string ) feat := f2.NewFeature("Resolve Target"). Setup("Create Auth Service server", func(ctx f2.Context, t *testing.T) f2.Context { var db = postgres.FromContextT(ctx, t).DB() authService = setupAuthservice(t, db) _ = authService return ctx }). Setup("Add some data", func(ctx f2.Context, t *testing.T) f2.Context { db := postgres.FromContextT(ctx, t).DB() err := PopulateTables(ctx, db) require.NoError(t, err) return ctx }). Setup("Set valid payload", func(ctx f2.Context, _ *testing.T) f2.Context { bannerID := UUIDs["banners"][0] storeID := UUIDs["clusters"][0] terminalID := UUIDs["terminals"][0] validPayload = fmt.Sprintf(`{ "target": { "bannerid": "%s", "storeid": "%s", "terminalid": "%s" } }`, bannerID, storeID, terminalID) return ctx }). Test("Without Auth Details", func(ctx f2.Context, t *testing.T) f2.Context { r := httptest.NewRecorder() c, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() req, err := http.NewRequestWithContext(c, http.MethodPost, "/resolveTarget", strings.NewReader(validPayload)) assert.NoError(t, err) authService.ServeHTTP(r, req) resp := r.Result() assert.Equal(t, http.StatusForbidden, resp.StatusCode) return ctx }). Test("With Auth Details", func(ctx f2.Context, t *testing.T) f2.Context { r := httptest.NewRecorder() c, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() req, err := http.NewRequestWithContext(c, http.MethodPost, "/resolveTarget", strings.NewReader(validPayload)) assert.NoError(t, err) setAuthHeaders(req) authService.ServeHTTP(r, req) resp := r.Result() assert.Equal(t, http.StatusOK, resp.StatusCode) return ctx }). Feature() f.Test(t, feat) }