package middleware

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"

	"edge-infra.dev/pkg/lib/fog"
	"edge-infra.dev/pkg/sds/emergencyaccess/types"
)

// TestSaveAuth checks how SaveAuthToContext turns the X-Auth-* request
// headers into a types.User that a downstream handler can read back with
// types.UserFromContext.
func TestSaveAuth(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		addReqHeaders func(*http.Request)
		user          types.User
		ok            bool
	}{
		"Comma separated headers": {
			// Comma separated headers are interpreted as a single value
			addReqHeaders: func(r *http.Request) {
				r.Header.Add("X-Auth-Username", "abcd")
				r.Header.Set("X-Auth-Email", "abcd@efgh.xyz")
				r.Header.Set("X-Auth-Roles", "role1,role2, role3")
				r.Header.Set("X-Auth-Banners", "banner1,banner2")
			},
			user: types.User{
				Username: "abcd",
				Email:    "abcd@efgh.xyz",
				Roles:    []string{"role1,role2, role3"},
				Banners:  []string{"banner1,banner2"},
			},
			ok: true,
		},
		"Multiple headers": {
			// Multiple role and banner headers are interpreted as multiple
			// entries in the slice.
			// Only the first username and email headers are read.
			addReqHeaders: func(r *http.Request) {
				r.Header.Add("X-Auth-Username", "abcd")
				r.Header.Add("X-Auth-Username", "efgh")
				r.Header.Set("X-Auth-Email", "abcd@efgh.xyz")
				r.Header.Add("X-Auth-Email", "efgh@ijkl.xyz")
				r.Header.Set("X-Auth-Roles", "role1")
				r.Header.Add("X-Auth-Roles", "role2")
				r.Header.Add("X-Auth-Banners", "banner1")
				r.Header.Add("X-Auth-Banners", "banner2")
			},
			user: types.User{
				Username: "abcd",
				Email:    "abcd@efgh.xyz",
				Roles:    []string{"role1", "role2"},
				Banners:  []string{"banner1", "banner2"},
			},
			ok: true,
		},
		"Missing Headers": {
			// Missing headers stores an empty User in the context
			addReqHeaders: func(_ *http.Request) {},
			user:          types.User{},
			ok:            true,
		},
		"Only Some Headers": {
			// Currently if only some headers have been set in the incoming
			// request we do a best effort of setting up the User and saving it
			// to context
			addReqHeaders: func(r *http.Request) {
				r.Header.Add("X-Auth-Username", "abcd")
				r.Header.Set("X-Auth-Roles", "role1")
			},
			user: types.User{
				Username: "abcd",
				Email:    "",
				Roles:    []string{"role1"},
				Banners:  nil,
			},
			ok: true,
		},
	}

	for name, tc := range tests {
		tc := tc // capture for the parallel subtest (required before Go 1.22)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			r := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			tc.addReqHeaders(req)

			_, router := gin.CreateTestContext(r)
			// Let gin.Context.Value fall back to the request's context.Context.
			router.ContextWithFallback = true
			router.Use(SaveAuthToContext())

			var (
				user types.User
				ok   bool
			)
			router.Any("/", func(c *gin.Context) {
				user, ok = types.UserFromContext(c)
			})
			router.ServeHTTP(r, req)

			assert.Equal(t, tc.user, user)
			assert.Equal(t, tc.ok, ok)
		})
	}
}

// TestVerifyUserDetailsInContext checks that VerifyUserDetailsInContext lets
// a request through (200) when the user in its context has both a username
// and an email.
func TestVerifyUserDetailsInContext(t *testing.T) {
	t.Parallel()

	r := httptest.NewRecorder()
	_, router := gin.CreateTestContext(r)
	router.ContextWithFallback = true
	router.Use(VerifyUserDetailsInContext())

	user := types.User{
		Username: "username",
		Email:    "email@domain",
	}
	ctx := types.UserIntoContext(context.Background(), user)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
	assert.NoError(t, err)

	router.Any("/", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})
	router.ServeHTTP(r, req)

	assert.Equal(t, http.StatusOK, r.Code)
}

// TestVerifyUserDetailsInContextFail checks that VerifyUserDetailsInContext
// rejects a request with a 403 when the user in its context is missing a
// username and/or an email. It also checks that the reason is logged under the
// "error" key to the fog logger placed in the request context.
func TestVerifyUserDetailsInContextFail(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		user   types.User
		expMsg string
	}{
		"No username": {
			user: types.User{
				Email: "email@domain",
			},
			expMsg: "60001: User Authorization Failure - User not found. Error: username not found in authservice request",
		},
		"No email": {
			user: types.User{
				Username: "username",
			},
			expMsg: "60001: User Authorization Failure - User not found. Error: email not found in authservice request",
		},
		"Neither username nor email": {
			// Both missing-field errors are reported, separated by a newline.
			expMsg: "60001: User Authorization Failure - User not found. Error: username not found in authservice request\nemail not found in authservice request",
		},
	}

	for name, tc := range tests {
		tc := tc // capture for the parallel subtest (required before Go 1.22)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			r := httptest.NewRecorder()
			_, router := gin.CreateTestContext(r)
			router.ContextWithFallback = true
			router.Use(VerifyUserDetailsInContext())

			// Capture log output so the reported error message can be inspected.
			var buf bytes.Buffer
			ctx := fog.IntoContext(context.Background(), fog.New(fog.To(&buf)))
			ctx = types.UserIntoContext(ctx, tc.user)
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
			assert.NoError(t, err)

			router.Any("/", func(c *gin.Context) {
				c.Status(http.StatusOK)
			})
			router.ServeHTTP(r, req)

			assert.Equal(t, http.StatusForbidden, r.Code)

			m := map[string]interface{}{}
			// Stop if the log output is not a single JSON object; otherwise the
			// message assertion below would only report an unhelpful nil.
			if !assert.NoError(t, json.Unmarshal(buf.Bytes(), &m), "log output: %s", buf.String()) {
				return
			}
			assert.Equal(t, tc.expMsg, m["error"])
		})
	}
}