package host_test import ( "bytes" "encoding/json" "fmt" "net/http" "net/http/httptest" "testing" "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "edge-infra.dev/pkg/sds/interlock/internal/errors" "edge-infra.dev/pkg/sds/interlock/topic/host" ) type TimeAssertionFunc func(t assert.TestingT, actual time.Time, msgAndArgs ...interface{}) bool func TimeEqual(expected time.Time) TimeAssertionFunc { return func(t assert.TestingT, actual time.Time, msgAndArgs ...interface{}) bool { return assert.Equal(t, expected, actual, msgAndArgs...) } } func WithinRange(start, end time.Time) TimeAssertionFunc { return func(t assert.TestingT, actual time.Time, msgAndArgs ...interface{}) bool { return assert.WithinRange(t, actual, start, end, msgAndArgs...) } } func vncStateAssertion(expected interface{}, timeAssert TimeAssertionFunc) func(*testing.T, interface{}) { return func(t *testing.T, actual interface{}) { expectedVNC := expected.(*host.State) actualVNC := actual.(*host.State) for i := range expectedVNC.VNC { actualTime, err := time.Parse(time.RFC3339, actualVNC.VNC[i].TimeStamp) assert.NoError(t, err) timeAssert(t, actualTime) // Now that we have tested actualTime, add it to expectedVNC so we can // do a full equal assertion. expectedVNC.VNC[i].TimeStamp = actualVNC.VNC[i].TimeStamp assert.Equal(t, expectedVNC.VNC[i], actualVNC.VNC[i]) } } } func genericAssertion(expected interface{}) func(*testing.T, interface{}) { return func(t *testing.T, actual interface{}) { assert.Equal(t, expected, actual) } } type payload struct { RequestID string `json:"requestId"` Status string `json:"status"` Connections int `json:"connections"` TimeStamp time.Time `json:"timestamp,omitempty" time_format:"2006-01-02T15:04:05Z07:00"` } func TestVNCPost(t *testing.T) { h, err := setupTestHostTopic(t, testHostname, "uid") require.NoError(t, err) r := gin.Default() h.RegisterEndpoints(r) // We need to round testTime to the nearest second because otherwise we will be testing // against some extra values that get lost in time.RFC3339 format, like milliseconds. // Timezone is set to UTC to avoid pipeline errors. testTime := time.Now().Add(48 * time.Hour).Round(time.Second).UTC() tests := map[string]struct { input payload expectedStatus int response interface{} assertEqual func(*testing.T, interface{}) }{ "UpdateVNC VNC state": { input: payload{ RequestID: "1", Status: string(host.Accepted), }, expectedStatus: http.StatusAccepted, response: &host.State{}, assertEqual: vncStateAssertion(&host.State{ Hostname: testHostname, VNC: host.VNCStates{ { RequestID: "1", Status: host.Accepted, }, }, NodeUID: "uid", }, WithinRange(time.Now().Add(-(10*time.Second)), time.Now())), }, "UpdateVNC VNC state with timestamp field": { input: payload{ RequestID: "1", Status: string(host.Accepted), TimeStamp: testTime, }, expectedStatus: http.StatusAccepted, response: &host.State{}, assertEqual: vncStateAssertion(&host.State{ Hostname: testHostname, VNC: host.VNCStates{ { RequestID: "1", Status: host.Accepted, TimeStamp: testTime.Format(time.RFC3339), }, }, NodeUID: "uid", }, TimeEqual(testTime)), }, "UpdateVNC VNC state to CONNECTED": { input: payload{ RequestID: "1", Status: string(host.Connected), Connections: 2, TimeStamp: testTime, }, expectedStatus: http.StatusAccepted, response: &host.State{}, assertEqual: vncStateAssertion(&host.State{ Hostname: testHostname, VNC: host.VNCStates{ { RequestID: "1", Status: host.Connected, Connections: 2, TimeStamp: testTime.Format(time.RFC3339), }, }, NodeUID: "uid", }, TimeEqual(testTime)), }, "Failure nil VNC Status": { input: payload{ RequestID: "1", Status: "", }, expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: "Status is required", }, }, }), }, "Failure INVALID VNC Status": { input: payload{ RequestID: "1", Status: "INVALID", }, expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: "Key: 'postVNCPayload.Status' Error:Field validation for 'Status' failed on the 'oneof' tag", }, }, }), }, "Failure CONNECTED status with no connections": { input: payload{ RequestID: "1", Status: "CONNECTED", }, expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: `Key: 'postVNCPayload.Connections' Error:Field validation for 'Connections' failed on the 'is_zero_xor_is_connected' tag`, }, }, }), }, "Failure Non-CONNECTED status with connections": { input: payload{ RequestID: "1", Status: "REQUESTED", Connections: 1, }, expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: `Key: 'postVNCPayload.Connections' Error:Field validation for 'Connections' failed on the 'is_zero_xor_is_connected' tag`, }, }, }), }, "Failure Negative Connections": { input: payload{ RequestID: "1", Status: "CONNECTED", Connections: -1, }, expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: `Key: 'postVNCPayload.Connections' Error:Field validation for 'Connections' failed on the 'gte' tag`, }, }, }), }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { out, err := json.Marshal(tc.input) require.NoError(t, err) req, err := http.NewRequest(http.MethodPost, host.Path+"/vnc", bytes.NewReader(out)) require.NoError(t, err) w := httptest.NewRecorder() r.ServeHTTP(w, req) require.Equal(t, tc.expectedStatus, w.Code) require.NoError(t, json.Unmarshal(w.Body.Bytes(), tc.response)) tc.assertEqual(t, tc.response) }) } } type putPayload []struct { RequestID string `json:"requestId"` Status string `json:"status"` TimeStamp time.Time `json:"timestamp"` } func TestVNCPut(t *testing.T) { h, err := setupTestHostTopic(t, testHostname, "a-uid") require.NoError(t, err) r := gin.Default() h.RegisterEndpoints(r) testTime := time.Now().Add(48 * time.Hour).Round(time.Second).UTC() tests := map[string]struct { setup putPayload input []byte expectedStatus int response interface{} assertEqual func(*testing.T, interface{}) }{ "No Payload": { expectedStatus: http.StatusInternalServerError, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: "Internal Server Error", }, }, }), }, "No Setup: Empty": { input: []byte(`[]`), expectedStatus: http.StatusAccepted, response: &host.State{}, assertEqual: vncStateAssertion(&host.State{ Hostname: testHostname, VNC: host.VNCStates{}, NodeUID: "uid", }, nil), }, "No Setup: Add Payload": { input: []byte(fmt.Sprintf(` [ { "requestID": "request1", "status": "ACCEPTED", "timestamp": "%[1]s" }, { "requestID": "request2", "status": "REQUESTED", "timestamp": "%[1]s" }, { "requestID": "request3", "status": "CONNECTED", "connections": 1, "timestamp": "%[1]s" } ]`, testTime.Format(time.RFC3339))), expectedStatus: http.StatusAccepted, response: &host.State{}, assertEqual: vncStateAssertion(&host.State{ Hostname: testHostname, VNC: host.VNCStates{ { RequestID: "request1", Status: host.Accepted, TimeStamp: testTime.Format(time.RFC3339), }, { RequestID: "request2", Status: host.Requested, TimeStamp: testTime.Format(time.RFC3339), }, { RequestID: "request3", Status: host.Connected, Connections: 1, TimeStamp: testTime.Format(time.RFC3339), }, }, NodeUID: "uid", }, TimeEqual(testTime)), }, "Setup: Empty": { setup: putPayload{ { RequestID: "request1", Status: string(host.Accepted), TimeStamp: testTime, }, { RequestID: "request2", Status: string(host.Requested), TimeStamp: testTime, }, }, input: []byte(`[]`), expectedStatus: http.StatusAccepted, response: &host.State{}, assertEqual: vncStateAssertion(&host.State{ Hostname: testHostname, VNC: host.VNCStates{}, NodeUID: "uid", }, TimeEqual(testTime)), }, "Setup: Add Payload": { setup: putPayload{ { RequestID: "request_to_be_updated", Status: string(host.Requested), TimeStamp: time.Now(), }, { RequestID: "request_to_be_removed", Status: string(host.Requested), TimeStamp: time.Now(), }, }, input: []byte(fmt.Sprintf(` [ { "requestID": "request_to_be_updated", "status": "ACCEPTED", "timestamp": "%[1]s" }, { "requestID": "request_to_be_added_1", "status": "ACCEPTED", "timestamp": "%[1]s" }, { "requestID": "request_to_be_added_2", "status": "CONNECTED", "connections": 1, "timestamp": "%[1]s" } ]`, testTime.Format(time.RFC3339))), expectedStatus: http.StatusAccepted, response: &host.State{}, assertEqual: vncStateAssertion(&host.State{ Hostname: testHostname, VNC: host.VNCStates{ { RequestID: "request_to_be_updated", Status: host.Accepted, TimeStamp: testTime.Format(time.RFC3339), }, { RequestID: "request_to_be_added_1", Status: host.Accepted, TimeStamp: testTime.Format(time.RFC3339), }, { RequestID: "request_to_be_added_2", Status: host.Connected, Connections: 1, TimeStamp: testTime.Format(time.RFC3339), }, }, NodeUID: "uid", }, TimeEqual(testTime)), }, "Fail: No Timestamp": { input: []byte(`[ { "requestID": "1", "status": "ACCEPTED" }]`), expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: "Key: 'putVNCPayload.TimeStamp' Error:Field validation for 'TimeStamp' failed on the 'required' tag", }, }, }, ), }, "Fail: Unsupported Status": { input: []byte(fmt.Sprintf(`[ { "requestID": "1", "status": "DROPPED", "timestamp": "%[1]s" }]`, testTime.Format(time.RFC3339))), expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: "Key: 'putVNCPayload.Status' Error:Field validation for 'Status' failed on the 'oneof' tag", }, }, }), }, "Fail: Multiple Binding Errors": { input: []byte(fmt.Sprintf(`[ { "requestID": "1", "status": "ACCEPTED" }, { "requestID": "2", "status": "INVALID", "timestamp": "%[1]s" }, { "requestID": "3", "status": "INVALID" }]`, testTime.Format(time.RFC3339))), expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: "Key: 'putVNCPayload.TimeStamp' Error:Field validation for 'TimeStamp' failed on the 'required' tag", }, { Detail: "Key: 'putVNCPayload.Status' Error:Field validation for 'Status' failed on the 'oneof' tag", }, { Detail: "Key: 'putVNCPayload.Status' Error:Field validation for 'Status' failed on the 'oneof' tag\nKey: 'putVNCPayload.TimeStamp' Error:Field validation for 'TimeStamp' failed on the 'required' tag", }, }, }), }, "Fail: RequestID Validation Errors": { input: []byte(fmt.Sprintf(`[ { "requestID": "1", "status": "ACCEPTED", "timestamp": "%[1]s" }, { "requestID": "1", "status": "ACCEPTED", "timestamp": "%[1]s" }, { "requestID": "1", "status": "ACCEPTED", "timestamp": "%[1]s" }]`, testTime.Format(time.RFC3339))), expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: "Key: 'putVNCPayload.RequestID' Error: Duplicate request id \"1\"", }, { Detail: "Key: 'putVNCPayload.RequestID' Error: Duplicate request id \"1\"", }, }, }), }, "Fail: Connections Validation Errors": { input: []byte(fmt.Sprintf(` [ { "requestID": "1", "status": "CONNECTED", "timeStamp": "%[1]s" }, { "requestID": "2", "status": "REQUESTED", "connections": 1, "timeStamp": "%[1]s" }, { "requestID": "3", "status": "CONNECTED", "connections": -1, "timeStamp": "%[1]s" } ]`, testTime.Format(time.RFC3339))), expectedStatus: http.StatusBadRequest, response: &Errors{}, assertEqual: genericAssertion(&Errors{ Errors: []*errors.Error{ { Detail: "Key: 'putVNCPayload.Connections' Error:Field validation for 'Connections' failed on the 'is_zero_xor_is_connected' tag", }, { Detail: "Key: 'putVNCPayload.Connections' Error:Field validation for 'Connections' failed on the 'is_zero_xor_is_connected' tag", }, { Detail: "Key: 'putVNCPayload.Connections' Error:Field validation for 'Connections' failed on the 'gte' tag", }, }, }), }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { // Set up a pre-defined state before test for _, setupPayload := range tc.setup { out, err := json.Marshal(setupPayload) require.NoError(t, err) req, err := http.NewRequest(http.MethodPost, host.Path+"/vnc", bytes.NewReader(out)) require.NoError(t, err) w := httptest.NewRecorder() r.ServeHTTP(w, req) require.Equal(t, http.StatusAccepted, w.Code) } req, err := http.NewRequest(http.MethodPut, host.Path+"/vnc", bytes.NewReader(tc.input)) require.NoError(t, err) w := httptest.NewRecorder() r.ServeHTTP(w, req) assert.Equal(t, tc.expectedStatus, w.Code) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), tc.response)) tc.assertEqual(t, tc.response) }) } }