...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/middleware/auth_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/middleware

     1  package middleware
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"testing"
    10  
    11  	"github.com/gin-gonic/gin"
    12  	"github.com/stretchr/testify/assert"
    13  
    14  	"edge-infra.dev/pkg/lib/fog"
    15  	"edge-infra.dev/pkg/sds/emergencyaccess/types"
    16  )
    17  
    18  func TestSaveAuth(t *testing.T) {
    19  	t.Parallel()
    20  
    21  	tests := map[string]struct {
    22  		addReqHeaders func(*http.Request)
    23  		user          types.User
    24  		ok            bool
    25  	}{
    26  		"Comma separated headers": {
    27  			// Comma separated headers are interpreted as a single value
    28  			addReqHeaders: func(r *http.Request) {
    29  				r.Header.Add("X-Auth-Username", "abcd")
    30  				r.Header.Set("X-Auth-Email", "abcd@efgh.xyz")
    31  				r.Header.Set("X-Auth-Roles", "role1,role2, role3")
    32  				r.Header.Set("X-Auth-Banners", "banner1,banner2")
    33  			},
    34  			user: types.User{
    35  				Username: "abcd",
    36  				Email:    "abcd@efgh.xyz",
    37  				Roles:    []string{"role1,role2, role3"},
    38  				Banners:  []string{"banner1,banner2"},
    39  			},
    40  			ok: true,
    41  		},
    42  		"Multiple headers": {
    43  			// Multiple role and banner headers are interpreted as multiple
    44  			// entries in the slice.
    45  			// Only the first username and email header is read
    46  			addReqHeaders: func(r *http.Request) {
    47  				r.Header.Add("X-Auth-Username", "abcd")
    48  				r.Header.Add("X-Auth-Username", "efgh")
    49  				r.Header.Set("X-Auth-Email", "abcd@efgh.xyz")
    50  				r.Header.Add("X-Auth-Email", "efgh@ijkl.xyz")
    51  				r.Header.Set("X-Auth-Roles", "role1")
    52  				r.Header.Add("X-Auth-Roles", "role2")
    53  				r.Header.Add("X-Auth-Banners", "banner1")
    54  				r.Header.Add("X-Auth-Banners", "banner2")
    55  			},
    56  			user: types.User{
    57  				Username: "abcd",
    58  				Email:    "abcd@efgh.xyz",
    59  				Roles:    []string{"role1", "role2"},
    60  				Banners:  []string{"banner1", "banner2"},
    61  			},
    62  			ok: true,
    63  		},
    64  		"Missing Headers": {
    65  			// Missing headers stores an empty User in the context
    66  			addReqHeaders: func(_ *http.Request) {},
    67  			user:          types.User{},
    68  			ok:            true,
    69  		},
    70  		"Only Some Headers": {
    71  			// Currently if only some headers have been set in the incoming
    72  			// request we do a best effort of setting up the User and saving it
    73  			// to context
    74  			addReqHeaders: func(r *http.Request) {
    75  				r.Header.Add("X-Auth-Username", "abcd")
    76  				r.Header.Set("X-Auth-Roles", "role1")
    77  			},
    78  			user: types.User{
    79  				Username: "abcd",
    80  				Email:    "",
    81  				Roles:    []string{"role1"},
    82  				Banners:  nil,
    83  			},
    84  			ok: true,
    85  		},
    86  	}
    87  
    88  	for name, tc := range tests {
    89  		tc := tc
    90  		t.Run(name, func(t *testing.T) {
    91  			t.Parallel()
    92  
    93  			r := httptest.NewRecorder()
    94  			req := httptest.NewRequest(http.MethodGet, "/", nil)
    95  
    96  			tc.addReqHeaders(req)
    97  
    98  			_, router := gin.CreateTestContext(r)
    99  			router.ContextWithFallback = true
   100  			router.Use(SaveAuthToContext())
   101  
   102  			var user types.User
   103  			var ok bool
   104  			router.Any("/", func(ctx *gin.Context) {
   105  				user, ok = types.UserFromContext(ctx)
   106  			})
   107  
   108  			router.ServeHTTP(r, req)
   109  
   110  			assert.Equal(t, tc.user, user)
   111  			assert.Equal(t, tc.ok, ok)
   112  		})
   113  	}
   114  }
   115  
   116  func TestVerifyUserDetailsInContext(t *testing.T) {
   117  	t.Parallel()
   118  
   119  	r := httptest.NewRecorder()
   120  	_, router := gin.CreateTestContext(r)
   121  	router.ContextWithFallback = true
   122  	router.Use(VerifyUserDetailsInContext())
   123  
   124  	user := types.User{
   125  		Username: "username",
   126  		Email:    "email@domain",
   127  	}
   128  	ctx := types.UserIntoContext(context.Background(), user)
   129  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
   130  	assert.NoError(t, err)
   131  
   132  	router.Any("/", func(ctx *gin.Context) {
   133  		ctx.Status(http.StatusOK)
   134  	})
   135  	router.ServeHTTP(r, req)
   136  
   137  	assert.Equal(t, http.StatusOK, r.Result().StatusCode)
   138  }
   139  
   140  func TestVerifyUserDetailsInContextFail(t *testing.T) {
   141  	t.Parallel()
   142  
   143  	tests := map[string]struct {
   144  		user   types.User
   145  		expMsg string
   146  	}{
   147  		"No username": {
   148  			user: types.User{
   149  				Email: "email@domain",
   150  			},
   151  			expMsg: "60001: User Authorization Failure - User not found. Error: username not found in authservice request",
   152  		},
   153  		"No email": {
   154  			user: types.User{
   155  				Username: "username",
   156  			},
   157  			expMsg: "60001: User Authorization Failure - User not found. Error: email not found in authservice request",
   158  		},
   159  		"Neither username nor email": {
   160  			expMsg: "60001: User Authorization Failure - User not found. Error: username not found in authservice request\nemail not found in authservice request",
   161  		},
   162  	}
   163  
   164  	for name, tc := range tests {
   165  		tc := tc
   166  		t.Run(name, func(t *testing.T) {
   167  			t.Parallel()
   168  
   169  			r := httptest.NewRecorder()
   170  			_, router := gin.CreateTestContext(r)
   171  			router.ContextWithFallback = true
   172  			router.Use(VerifyUserDetailsInContext())
   173  
   174  			buf := bytes.Buffer{}
   175  			ctx := fog.IntoContext(context.Background(), fog.New(fog.To(&buf)))
   176  			ctx = types.UserIntoContext(ctx, tc.user)
   177  			req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
   178  			assert.NoError(t, err)
   179  
   180  			router.Any("/", func(ctx *gin.Context) {
   181  				ctx.Status(http.StatusOK)
   182  			})
   183  
   184  			router.ServeHTTP(r, req)
   185  
   186  			m := map[string]interface{}{}
   187  			err = json.Unmarshal(buf.Bytes(), &m)
   188  			assert.NoError(t, err)
   189  
   190  			assert.Equal(t, http.StatusForbidden, r.Result().StatusCode)
   191  			assert.Equal(t, tc.expMsg, m["error"])
   192  		})
   193  	}
   194  }
   195  

View as plain text