package server

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"path"
	"testing"

	"edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
	"edge-infra.dev/pkg/sds/emergencyaccess/eagateway"
	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
	"edge-infra.dev/pkg/sds/emergencyaccess/types"
	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

const (
	// badCommand is the command string the mock auth service handlers treat as
	// invalid (authorizeCommand) or unauthorized (authorizeRequest).
	badCommand = "false"
)

var (
	// defaultBytes is a JSON encoded command output payload.
	defaultBytes = []byte(` { "type": "Output", "exitCode": 0, "output": "hello\n", "timestamp": "01-01-2023 00:00:00", "duration": 0.1 }`)

	// defaultAttrMap is a set of message attributes shared by tests.
	defaultAttrMap = map[string]string{
		"bannerId":             "banner",
		"storeId":              "store",
		"terminalId":           "terminal",
		"sessionId":            "orderingKey",
		"identity":             "identity",
		"version":              "1.0",
		"signature":            "signature",
		"request-message-uuid": "uuid",
	}

	// defaultTarget is the target returned by the default resolveTarget
	// handler of the mock auth service.
	defaultTarget = types.Target{
		Projectid:  "projectID",
		Bannerid:   "bannerID",
		Storeid:    "storeID",
		Terminalid: "terminalID",
	}

	errTestRCLIStartSessionFail = fmt.Errorf("TestRCLIStartSessionFail")
)

// helper function which sets well known auth headers to any request
func setAuthHeaders(req *http.Request) {
	req.Header.Set(eaconst.HeaderAuthKeyUsername, "user")
	req.Header.Set(eaconst.HeaderAuthKeyEmail, "email")
	req.Header.Set(eaconst.HeaderAuthKeyRoles, "role")
	req.Header.Set(eaconst.HeaderAuthKeyBanners, "banner")
}

/*
Mock auth service server
*/

// httpmiddleware wraps an http handler func and returns the wrapped handler.
type httpmiddleware func(next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request)

// authserverOpts holds the handlers (and optional middleware) used to build
// the mock auth service in authserviceServer.
type authserverOpts struct {
	middleware []httpmiddleware

	authorizeCommand func(w http.ResponseWriter, r *http.Request)
	authorizeRequest func(w http.ResponseWriter, r *http.Request)
	resolveTarget    func(w http.ResponseWriter, r *http.Request)
	authorizeTarget  func(w http.ResponseWriter, r *http.Request)
	authorizeUser    func(w http.ResponseWriter, r *http.Request)
}

// Option customises the mock auth service created by authserviceServer.
type Option func(opts *authserverOpts)

// verifyUserAuthHeaders returns middleware which can verify the correct user
// Auth headers have been set, and then passes on the request to the passed in
// handler func. The expected values are the ones written by setAuthHeaders.
func verifyUserAuthHeaders(t *testing.T) httpmiddleware {
	return func(next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
		return func(w http.ResponseWriter, r *http.Request) {
			assert.Equal(t, []string{"user"}, r.Header.Values(eaconst.HeaderAuthKeyUsername))
			assert.Equal(t, []string{"email"}, r.Header.Values(eaconst.HeaderAuthKeyEmail))
			assert.Equal(t, []string{"role"}, r.Header.Values(eaconst.HeaderAuthKeyRoles))
			assert.Equal(t, []string{"banner"}, r.Header.Values(eaconst.HeaderAuthKeyBanners))
			next(w, r)
		}
	}
}

// WithMiddleware allows setting middleware to run on the authorizeCommand,
// authorizeRequest, authorizeTarget and resolveTarget endpoints. The first
// middleware passed in is the last applied middleware to the request
func WithMiddleware(middleware ...httpmiddleware) Option {
	return func(opts *authserverOpts) {
		opts.middleware = middleware
	}
}

// defaultAuthorizeRequest returns the default handler for the authorizeRequest
// endpoint. A non-OK status is written immediately; the handler then parses
// the incoming payload and responds with either an unauthorized api error (when
// the command equals badCommand) or a valid request payload.
func defaultAuthorizeRequest(status int) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		if status != http.StatusOK {
			// If ok we don't want to write it immediately, otherwise any errors
			// would be hidden as you can't write the header multiple times
			w.WriteHeader(status)
		}

		// First read the incoming body and find the command in the body
		body, err := io.ReadAll(r.Body)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		var m map[string]map[string]json.RawMessage
		if err := json.Unmarshal(body, &m); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		var n map[string]string
		if err := json.Unmarshal(m["Request"]["Data"], &n); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		// if the command is false we want to send an unauthorized response
		if n["command"] == badCommand {
			// Although we return statusUnauthorized here, the error code
			// corresponds to status forbidden. This tests the functionality of
			// the apierror handler
			w.WriteHeader(http.StatusUnauthorized)
			_, _ = w.Write([]byte(`{ "errorCode": 62001, "ErrorMessage": "User Authorization Failure - User not permitted to perform this action" }`))
			return
		}

		// Otherwise generate a valid request message and respond with a valid
		// response payload
		req, err := msgdata.NewV1_0Request(defaultCommand)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		d, err := req.Data()
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		data := authRequestResponse{
			Request: struct {
				Data       json.RawMessage
				Attributes map[string]string
			}{
				Data:       json.RawMessage(d),
				Attributes: req.Attributes(),
			},
		}

		resp, err := json.Marshal(data)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		// A failed Write cannot be reported to the client: the header has
		// already been committed by this point, so the error is ignored.
		_, _ = w.Write(resp)
	}
}

// WithAuthorizeCommand optionally overrides the default authorize command
// endpoint function
func WithAuthorizeCommand(authorizeCommand func(w http.ResponseWriter, r *http.Request)) Option {
	return func(opts *authserverOpts) {
		opts.authorizeCommand = authorizeCommand
	}
}

// defaultAuthorizeCommand returns the default handler for the authorizeCommand
// endpoint. It responds with a JSON encoded eagateway.CommandValidation which
// is valid for every command except badCommand.
func defaultAuthorizeCommand(status int) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		if status != http.StatusOK {
			// If ok we don't want to write it immediately, otherwise any errors
			// below would be hidden as the header can only be written once.
			w.WriteHeader(status)
		}

		body, err := io.ReadAll(r.Body)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		var m map[string]interface{}
		if err := json.Unmarshal(body, &m); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		command, ok := m["Command"].(string)
		if !ok {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		validation := eagateway.CommandValidation{Valid: command != badCommand}

		resp, err := json.Marshal(validation)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		// A failed Write cannot be reported to the client: the header has
		// already been committed by this point, so the error is ignored.
		_, _ = w.Write(resp)
	}
}

// Optionally override the default resolve target endpoint function
func WithResolveTarget(resolveTarget func(w http.ResponseWriter, r *http.Request)) Option {
	return func(opts *authserverOpts) {
		opts.resolveTarget = resolveTarget
	}
}

// defaultResolveTarget returns the default handler for the resolveTarget
// endpoint, which responds with defaultTarget under the "target" key.
func defaultResolveTarget() func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, _ *http.Request) {
		data, err := json.Marshal(map[string]types.Target{
			"target": defaultTarget,
		})
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		_, _ = w.Write(data)
	}
}

// Optionally override the default authorize target endpoint function
func WithAuthorizeTarget(authorizeTarget func(w http.ResponseWriter, r *http.Request)) Option {
	return func(opts *authserverOpts) {
		opts.authorizeTarget = authorizeTarget
	}
}

// defaultAuthorizeTarget returns the default handler for the authorizeTarget
// endpoint, which only writes the given status code.
func defaultAuthorizeTarget(status int) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(status)
	}
}

// WithAuthorizeUser optionally overrides the default authorize user endpoint
// function
func WithAuthorizeUser(authorizeUser func(w http.ResponseWriter, r *http.Request)) Option {
	return func(opts *authserverOpts) {
		opts.authorizeUser = authorizeUser
	}
}

// defaultAuthorizeUser returns the default handler for the authorizeUser
// endpoint, which only writes the given status code.
func defaultAuthorizeUser(status int) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(status)
	}
}

// authserviceServer starts a mock auth service. Every default handler is
// configured with status; individual handlers and middleware can be replaced
// through opts. It returns the running test server (the caller is responsible
// for closing it) and the scheme-less "host:port/authservice" address of the
// service.
func authserviceServer(status int, opts ...Option) (server *httptest.Server, url string) {
	opt := authserverOpts{
		authorizeCommand: defaultAuthorizeCommand(status),
		authorizeRequest: defaultAuthorizeRequest(status),
		authorizeTarget:  defaultAuthorizeTarget(status),
		resolveTarget:    defaultResolveTarget(),
		authorizeUser:    defaultAuthorizeUser(status),
	}

	for _, o := range opts {
		o(&opt)
	}

	// NOTE(review): authorizeUser is not wrapped below, so middleware never
	// runs on /authservice/authorizeUser - confirm whether that is intentional.
	for _, mid := range opt.middleware {
		opt.authorizeCommand = mid(opt.authorizeCommand)
		opt.authorizeRequest = mid(opt.authorizeRequest)
		opt.authorizeTarget = mid(opt.authorizeTarget)
		opt.resolveTarget = mid(opt.resolveTarget)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/authservice/authorizeCommand", opt.authorizeCommand)
	mux.HandleFunc("/authservice/authorizeRequest", opt.authorizeRequest)
	mux.HandleFunc("/authservice/authorizeTarget", opt.authorizeTarget)
	mux.HandleFunc("/authservice/resolveTarget", opt.resolveTarget)
	mux.HandleFunc("/authservice/authorizeUser", opt.authorizeUser)

	server = httptest.NewServer(mux)

	// The listener address is the server URL without its "http://" scheme.
	url = path.Join(server.Listener.Addr().String(), "authservice")

	return server, url
}

/*
End mock
*/

// TestServerStatusEndpoints checks that the /ready and /health endpoints
// respond with 200 and a body of "ok".
func TestServerStatusEndpoints(t *testing.T) {
	tests := map[string]struct {
		query   string
		expRes  string
		expCode int
	}{
		"Ready Returns Ok": {
			query:   "/ready",
			expRes:  `ok`,
			expCode: 200,
		},
		"Health Returns Ok": {
			query:   "/health",
			expRes:  `ok`,
			expCode: 200,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			r := httptest.NewRecorder()

			// Create Gin context in test mode
			gin.SetMode(gin.TestMode)
			_, ginEngine := gin.CreateTestContext(r)

			// Create new GatewayServer. Use nil for rcli and requestservice as
			// this helps guarantee that the health endpoints don't make
			// spurious calls to these components.
			// Choose localhost as this does not require slow dns resolution, and no
			// authservice should be running on localhost during this test
			_, err := New(eagateway.Config{AuthServiceHost: "localhost"}, ginEngine, newLogger(), nil, nil)
			assert.NoError(t, err)

			// Send test query
			req, err := http.NewRequest(http.MethodGet, tc.query, nil)
			assert.NoError(t, err)
			ginEngine.ServeHTTP(r, req)

			// Retrieve response
			res := r.Result()
			defer res.Body.Close()
			assert.Equal(t, tc.expCode, res.StatusCode)

			data, err := io.ReadAll(res.Body)
			assert.NoError(t, err)
			assert.Equal(t, tc.expRes, string(data))
		})
	}
}