package middleware

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"edge-infra.dev/pkg/edge/client"
	"edge-infra.dev/pkg/lib/runtime/version"
	"github.com/gin-gonic/gin"
	assertAPI "github.com/stretchr/testify/assert"
)

// init puts gin into release mode so the test routers do not emit
// debug-mode output.
func init() {
	gin.SetMode(gin.ReleaseMode)
}

var v = version.New()

// serverVersion is the semantic version reported by this binary; the tests
// below expect CheckVersion's incompatibility message to quote it.
var serverVersion = v.SemVer

const (
	// token is the credential sent in the Authorization header of every
	// test request.
	token = "test-token"
	// bearer is the auth scheme the downstream handler expects to see
	// prefixed to token in the Authorization header.
	bearer = "bearer"
)

// doCheckVersionRequest serves a gin router consisting of CheckVersion
// followed by a handler that asserts the Authorization and EdgeVersion
// headers were forwarded unchanged. It then issues one GET through a token
// client that sends edgeVersion in the client.EdgeVersion header.
//
// It returns the response status code and body. The test server is shut
// down and the response body closed before returning. ok is false when the
// request or body read failed; that failure has already been reported on t,
// so callers should simply return.
func doCheckVersionRequest(t *testing.T, edgeVersion string) (status int, body string, ok bool) {
	t.Helper()
	assert := assertAPI.New(t)

	router := gin.Default()
	router.Use(CheckVersion())
	router.Use(func(c *gin.Context) {
		assert.Equal(bearer+" "+token, c.GetHeader(client.Authorization))
		assert.Equal(edgeVersion, c.GetHeader(client.EdgeVersion))
		c.Status(http.StatusOK)
	})

	server := httptest.NewServer(router)
	defer server.Close()

	headers := map[string]string{
		client.Authorization: fmt.Sprintf("%s %s", client.BearerToken, token),
		client.EdgeVersion:   edgeVersion,
	}
	res, err := client.NewTokenClient(headers).Get(server.URL)
	// Bail out rather than dereference a nil response on transport errors.
	if !assert.NoError(err) {
		return 0, "", false
	}
	// Runs before server.Close (LIFO), so Close does not wait on an open body.
	defer res.Body.Close()

	raw, err := io.ReadAll(res.Body)
	if !assert.NoError(err) {
		return res.StatusCode, "", false
	}
	return res.StatusCode, string(raw), true
}

// TestCheckApiVersionOlder checks that CheckVersion rejects client version
// 104.3.6 with 400 Bad Request and an incompatibility message naming both
// versions. NOTE(review): the name suggests the server API is older than
// this client version; confirm against CheckVersion's comparison rules.
func TestCheckApiVersionOlder(t *testing.T) {
	const edgeVersion = "104.3.6"

	status, body, ok := doCheckVersionRequest(t, edgeVersion)
	if !ok {
		return
	}

	assert := assertAPI.New(t)
	assert.Equal(http.StatusBadRequest, status)
	want := fmt.Sprintf("client version %s incompatible with server version %s", edgeVersion, serverVersion)
	assert.Truef(strings.Contains(body, want), "response body %q does not contain %q", body, want)
}

// TestCheckApiVersionNewer checks that CheckVersion accepts client version
// 0.3.0 and lets the request reach the downstream handler (200 OK).
func TestCheckApiVersionNewer(t *testing.T) {
	status, _, ok := doCheckVersionRequest(t, "0.3.0")
	if !ok {
		return
	}
	assertAPI.Equal(t, http.StatusOK, status)
}

// TestCheckVersionMatch checks that a client reporting exactly the server's
// own version is accepted (200 OK).
func TestCheckVersionMatch(t *testing.T) {
	status, _, ok := doCheckVersionRequest(t, serverVersion)
	if !ok {
		return
	}
	assertAPI.Equal(t, http.StatusOK, status)
}