...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/rules/server/rules_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/rules/server

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  
    12  	rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
    13  
    14  	"github.com/gin-gonic/gin"
    15  	"github.com/google/shlex"
    16  	"github.com/stretchr/testify/assert"
    17  )
    18  
    19  type MockRulesEngine struct {
    20  	AddedNames map[string][]string
    21  	Conflict   bool
    22  	RulesEngine
    23  }
    24  
    25  func NewMockReng() MockRulesEngine {
    26  	return MockRulesEngine{AddedNames: map[string][]string{
    27  		"commands": {},
    28  		"privs":    {},
    29  		"rules":    {},
    30  	}}
    31  }
    32  
    33  func (MockRulesEngine) GetEARolesForCommand(_ context.Context, _ rulesengine.Command, _ string) ([]string, error) {
    34  	return []string{}, nil
    35  }
    36  
    37  func (MockRulesEngine) UserHasRoles(_ string, _ []string, _ []string) bool {
    38  	return false
    39  }
    40  
    41  func (mreng *MockRulesEngine) ReadCommands(_ context.Context) ([]rulesengine.Command, error) {
    42  	lst := []rulesengine.Command{}
    43  
    44  	for _, name := range mreng.AddedNames["commands"] {
    45  		lst = append(lst, rulesengine.Command{ID: "test", Name: name})
    46  	}
    47  	return lst, nil
    48  }
    49  
    50  func (mreng *MockRulesEngine) ReadPrivileges(_ context.Context) ([]rulesengine.Privilege, error) {
    51  	lst := []rulesengine.Privilege{}
    52  
    53  	for _, name := range mreng.AddedNames["privs"] {
    54  		lst = append(lst, rulesengine.Privilege{ID: "test", Name: name})
    55  	}
    56  	return lst, nil
    57  }
    58  
    59  func (mreng *MockRulesEngine) ReadCommand(_ context.Context, name string) (rulesengine.Command, error) {
    60  	for _, namein := range mreng.AddedNames["commands"] {
    61  		if name == namein {
    62  			return rulesengine.Command{Name: name, ID: "test"}, nil
    63  		}
    64  	}
    65  	return rulesengine.Command{}, nil
    66  }
    67  
    68  func (mreng MockRulesEngine) ReadPrivilege(_ context.Context, name string) (rulesengine.Privilege, error) {
    69  	for _, namein := range mreng.AddedNames["privs"] {
    70  		if name == namein {
    71  			return rulesengine.Privilege{Name: name, ID: "test"}, nil
    72  		}
    73  	}
    74  	return rulesengine.Privilege{}, nil
    75  }
    76  
    77  func (mreng *MockRulesEngine) ReadAllDefaultRules(_ context.Context) ([]rulesengine.Rule, error) {
    78  	res := []rulesengine.Rule{}
    79  	for _, rule := range mreng.AddedNames["rules"] {
    80  		vals, err := shlex.Split(rule)
    81  		if err != nil {
    82  			return res, err
    83  		}
    84  		res = append(res, rulesengine.Rule{
    85  			Command:    rulesengine.Command{Name: vals[0], ID: "test"},
    86  			Privileges: []rulesengine.Privilege{{Name: vals[1], ID: "test"}},
    87  		})
    88  	}
    89  	return res, nil
    90  }
    91  func (mreng *MockRulesEngine) ReadDefaultRulesForCommand(ctx context.Context, _ string) ([]rulesengine.Rule, error) {
    92  	return mreng.ReadAllDefaultRules(ctx)
    93  }
    94  
    95  func getTestGinContext(r *httptest.ResponseRecorder) (*gin.Context, *gin.Engine) {
    96  	gin.SetMode(gin.TestMode)
    97  	ctx, ginEngine := gin.CreateTestContext(r)
    98  	return ctx, ginEngine
    99  }
   100  
   101  type postDefaultRulesMock struct {
   102  	RulesEngine
   103  
   104  	dataRet rulesengine.AddRuleResult
   105  	errRet  error
   106  
   107  	callCount int
   108  	rules     rulesengine.WriteRules
   109  }
   110  
   111  func (pdrb *postDefaultRulesMock) AddDefaultRules(_ context.Context, rules rulesengine.WriteRules) (rulesengine.AddRuleResult, error) {
   112  	pdrb.callCount = pdrb.callCount + 1
   113  	pdrb.rules = rules
   114  	return pdrb.dataRet, pdrb.errRet
   115  }
   116  
   117  func TestPostDefaultRules(t *testing.T) {
   118  	t.Parallel()
   119  
   120  	tests := map[string]struct {
   121  		reqBody string
   122  
   123  		mockDataRet rulesengine.AddRuleResult
   124  		mockErrRet  error
   125  
   126  		expMockCalledCount int
   127  		expCalledRules     rulesengine.WriteRules
   128  		expCode            int
   129  
   130  		jsonAssert StringAssertionFunc
   131  	}{
   132  		"Ok": {
   133  			reqBody: `[
   134  				{"command": "ls", "privileges": ["read","write"]},
   135  				{"command": "cat", "privileges": ["read","write"]}
   136  			]`,
   137  
   138  			expMockCalledCount: 1,
   139  			expCalledRules: rulesengine.WriteRules{
   140  				{Command: "ls", Privileges: []string{"read", "write"}},
   141  				{Command: "cat", Privileges: []string{"read", "write"}},
   142  			},
   143  
   144  			expCode:    http.StatusOK,
   145  			jsonAssert: JSONEmpty(),
   146  		},
   147  		"Invalid JSON": {
   148  			reqBody: `[{"comm`,
   149  
   150  			expMockCalledCount: 0,
   151  
   152  			expCode:    http.StatusBadRequest,
   153  			jsonAssert: JSONEmpty(),
   154  		},
   155  		"Invalid payload": {
   156  			reqBody: `[{"command": "", "privileges": ["", ""]}]`,
   157  
   158  			expMockCalledCount: 0,
   159  
   160  			expCode:    http.StatusBadRequest,
   161  			jsonAssert: JSONEmpty(),
   162  		},
   163  
   164  		"Rulesengine Error": {
   165  			reqBody: `[
   166  				{"command": "ls", "privileges": ["read","write"]},
   167  				{"command": "cat", "privileges": ["read","write"]}
   168  			]`,
   169  
   170  			mockDataRet: rulesengine.AddRuleResult{},
   171  			mockErrRet:  fmt.Errorf("an error occurred"),
   172  
   173  			expMockCalledCount: 1,
   174  			expCalledRules: rulesengine.WriteRules{
   175  				{Command: "ls", Privileges: []string{"read", "write"}},
   176  				{Command: "cat", Privileges: []string{"read", "write"}},
   177  			},
   178  
   179  			expCode:    http.StatusInternalServerError,
   180  			jsonAssert: JSONEmpty(),
   181  		},
   182  		"Rulesengine Conflict": {
   183  			reqBody: `[
   184  				{"command": "not-here", "privileges": ["read","write"]},
   185  				{"command": "cat", "privileges": ["read","not-here"]}
   186  			]`,
   187  
   188  			mockDataRet: rulesengine.AddRuleResult{Errors: []rulesengine.Error{
   189  				{Command: "not-here", Type: rulesengine.UnknownCommand},
   190  				{Privilege: "not-here", Type: rulesengine.UnknownPrivilege},
   191  			}},
   192  			mockErrRet: nil,
   193  
   194  			expMockCalledCount: 1,
   195  			expCalledRules: rulesengine.WriteRules{
   196  				{Command: "not-here", Privileges: []string{"read", "write"}},
   197  				{Command: "cat", Privileges: []string{"read", "not-here"}},
   198  			},
   199  
   200  			expCode: http.StatusNotFound,
   201  			jsonAssert: JSONEq(`{
   202  				"errors": [
   203  					{"command":"not-here","type":"Unknown Command"},
   204  					{"privilege":"not-here","type":"Unknown Privilege"}
   205  				]
   206  			}`),
   207  		},
   208  	}
   209  
   210  	for name, tc := range tests {
   211  		tc := tc
   212  		t.Run(name, func(t *testing.T) {
   213  			t.Parallel()
   214  
   215  			ruleseng := postDefaultRulesMock{
   216  				dataRet: tc.mockDataRet,
   217  				errRet:  tc.mockErrRet,
   218  			}
   219  
   220  			log := newLogger()
   221  
   222  			r := httptest.NewRecorder()
   223  			_, ginEngine := getTestGinContext(r)
   224  			_, err := New(ginEngine, &ruleseng, log)
   225  			assert.NoError(t, err)
   226  
   227  			req, err := http.NewRequest(http.MethodPost, "/admin/rules/default/commands", strings.NewReader(tc.reqBody))
   228  			assert.NoError(t, err)
   229  
   230  			ginEngine.ServeHTTP(r, req)
   231  
   232  			assert.Equal(t, tc.expCode, r.Result().StatusCode)
   233  
   234  			assert.Equal(t, tc.expMockCalledCount, ruleseng.callCount)
   235  			assert.Equal(t, tc.expCalledRules, ruleseng.rules)
   236  
   237  			tc.jsonAssert(t, r.Body.String())
   238  		})
   239  	}
   240  }
   241  
   242  // nolint:dupl
   243  func TestReadAllDefaultRules(t *testing.T) {
   244  	log := newLogger()
   245  	t.Setenv("RCLI_RES_DATA_DIR", "./testdata")
   246  	ruleseng := MockRulesEngine{AddedNames: map[string][]string{"rules": {"ls basic"}}}
   247  
   248  	r := httptest.NewRecorder()
   249  	_, ginEngine := getTestGinContext(r)
   250  	_, err := New(ginEngine, &ruleseng, log)
   251  	assert.Nil(t, err)
   252  
   253  	req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands", nil)
   254  	assert.NoError(t, err)
   255  	ginEngine.ServeHTTP(r, req)
   256  	response := r.Result()
   257  
   258  	assert.Equal(t, response.StatusCode, http.StatusOK)
   259  	var respData []rulesengine.Rule
   260  	err = json.Unmarshal(r.Body.Bytes(), &respData)
   261  	assert.NoError(t, err)
   262  	assert.Equal(t, []rulesengine.Rule{{Command: rulesengine.Command{Name: "ls", ID: "test"}, Privileges: []rulesengine.Privilege{{Name: "basic", ID: "test"}}}}, respData)
   263  }
   264  
   265  // nolint:dupl
   266  func TestReadDefaultRuleForCommand(t *testing.T) {
   267  	log := newLogger()
   268  	t.Setenv("RCLI_RES_DATA_DIR", "./testdata")
   269  	ruleseng := MockRulesEngine{AddedNames: map[string][]string{"rules": {"ls basic"}}}
   270  
   271  	r := httptest.NewRecorder()
   272  	_, ginEngine := getTestGinContext(r)
   273  	_, err := New(ginEngine, &ruleseng, log)
   274  	assert.Nil(t, err)
   275  
   276  	req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands/ls", nil)
   277  	assert.NoError(t, err)
   278  	ginEngine.ServeHTTP(r, req)
   279  	response := r.Result()
   280  
   281  	assert.Equal(t, response.StatusCode, http.StatusOK)
   282  	var respData rulesengine.Rule
   283  	err = json.Unmarshal(r.Body.Bytes(), &respData)
   284  	assert.NoError(t, err)
   285  	assert.Equal(t, rulesengine.Rule{Command: rulesengine.Command{Name: "ls", ID: "test"}, Privileges: []rulesengine.Privilege{{Name: "basic", ID: "test"}}}, respData)
   286  }
   287  
   288  // nolint:dupl
   289  func TestReadDefaultRulesNoRules(t *testing.T) {
   290  	log := newLogger()
   291  	t.Setenv("RCLI_RES_DATA_DIR", "./testdata")
   292  	ruleseng := MockRulesEngine{}
   293  
   294  	r := httptest.NewRecorder()
   295  	_, ginEngine := getTestGinContext(r)
   296  	_, err := New(ginEngine, &ruleseng, log)
   297  	assert.Nil(t, err)
   298  
   299  	req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands", nil)
   300  	assert.NoError(t, err)
   301  	ginEngine.ServeHTTP(r, req)
   302  	response := r.Result()
   303  
   304  	assert.Equal(t, response.StatusCode, http.StatusOK)
   305  	var respData []rulesengine.Rule
   306  	err = json.Unmarshal(r.Body.Bytes(), &respData)
   307  	assert.NoError(t, err)
   308  	assert.Equal(t, []rulesengine.Rule(nil), respData)
   309  }
   310  
   311  var (
   312  	retVal = rulesengine.RuleWithOverrides{
   313  		Command: rulesengine.Command{
   314  			Name: "testCommand",
   315  			ID:   "testCommandID",
   316  		},
   317  		Banners: []rulesengine.BannerPrivOverrides{{
   318  			Banner: rulesengine.Banner{
   319  				BannerName: "testBannerName",
   320  				BannerID:   "testBannerID",
   321  			},
   322  			Privileges: []rulesengine.Privilege{
   323  				{
   324  					Name: "testPriv1",
   325  					ID:   "testPrivID1",
   326  				},
   327  			},
   328  		}},
   329  		Default: rulesengine.DefaultRule{Privileges: []rulesengine.Privilege{{
   330  			Name: "testPriv2",
   331  			ID:   "testPrivID2",
   332  		}}},
   333  	}
   334  
   335  	retValString = `{
   336  		"command": {
   337  			"id": "testCommandID",
   338  			"name": "testCommand"
   339  		},
   340  		"default": {
   341  			"privileges": [
   342  				{
   343  					"id": "testPrivID2",
   344  					"name": "testPriv2"
   345  				}
   346  			]
   347  		},
   348  		"banners": [
   349  			{
   350  				"banner": {
   351  					"id": "testBannerID",
   352  					"name": "testBannerName"
   353  				},
   354  				"privileges": [
   355  					{
   356  						"id": "testPrivID1",
   357  						"name": "testPriv1"
   358  					}
   359  				]
   360  			}
   361  		]
   362  	}`
   363  	retvalNoBanners = rulesengine.RuleWithOverrides{
   364  		Command: rulesengine.Command{
   365  			Name: "testCommand",
   366  			ID:   "testCommandID",
   367  		},
   368  		Default: rulesengine.DefaultRule{
   369  			Privileges: []rulesengine.Privilege{{
   370  				Name: "testPriv2",
   371  				ID:   "testPrivID2",
   372  			}},
   373  		},
   374  		Banners: []rulesengine.BannerPrivOverrides{},
   375  	}
   376  	retValStringNoBanners = `{
   377  		"command": {
   378  			"id": "testCommandID",
   379  			"name": "testCommand"
   380  		},
   381  		"default": {
   382  			"privileges": [
   383  				{
   384  					"id": "testPrivID2",
   385  					"name": "testPriv2"
   386  				}
   387  			]
   388  		},
   389  		"banners": []
   390  	}`
   391  	retValNoDefaults = rulesengine.RuleWithOverrides{
   392  		Command: rulesengine.Command{
   393  			Name: "testCommand",
   394  			ID:   "testCommandID",
   395  		},
   396  		Banners: []rulesengine.BannerPrivOverrides{{
   397  			Banner: rulesengine.Banner{
   398  				BannerName: "testBannerName",
   399  				BannerID:   "testBannerID",
   400  			},
   401  			Privileges: []rulesengine.Privilege{
   402  				{
   403  					Name: "testPriv1",
   404  					ID:   "testPrivID1",
   405  				},
   406  			},
   407  		}},
   408  		Default: rulesengine.DefaultRule{},
   409  	}
   410  	retValNoDefaultsString = `{
   411  		"command": {
   412  			"id": "testCommandID",
   413  			"name": "testCommand"
   414  		},
   415  		"default": {},
   416  		"banners": [
   417  			{
   418  				"banner": {
   419  					"id": "testBannerID",
   420  					"name": "testBannerName"
   421  				},
   422  				"privileges": [
   423  					{
   424  						"id": "testPrivID1",
   425  						"name": "testPriv1"
   426  					}
   427  				]
   428  			}
   429  		]
   430  	}`
   431  
   432  	retValCommandOnly = rulesengine.RuleWithOverrides{
   433  		Command: rulesengine.Command{
   434  			Name: "testCommand",
   435  			ID:   "testCommandID",
   436  		},
   437  		Banners: []rulesengine.BannerPrivOverrides{},
   438  		Default: rulesengine.DefaultRule{},
   439  	}
   440  	retValCommandOnlyString = `{
   441  		"command": {
   442  			"id": "testCommandID",
   443  			"name": "testCommand"
   444  		},
   445  		"default": {},
   446  		"banners": []
   447  	}`
   448  )
   449  
   450  type getAllRulesMock struct {
   451  	rulesengine.RulesEngine
   452  	retVal rulesengine.RuleWithOverrides
   453  	retErr error
   454  }
   455  
   456  func (gm getAllRulesMock) ReadAllRulesForCommand(_ context.Context, _ string) (rulesengine.RuleWithOverrides, error) {
   457  	return gm.retVal, gm.retErr
   458  }
   459  func TestReadAllRulesForCommand(t *testing.T) {
   460  	t.Parallel()
   461  	tests := map[string]struct {
   462  		// request details
   463  		url string
   464  		// defined rules engine that we will be using
   465  		mreng getAllRulesMock
   466  
   467  		// expected results from http
   468  		expStatus int
   469  		expOut    string
   470  	}{
   471  		"Nominal": {
   472  			url: "/admin/rules/commands/testCommand",
   473  			mreng: getAllRulesMock{
   474  				retVal: retVal,
   475  				retErr: nil,
   476  			},
   477  			expStatus: 200,
   478  			expOut:    retValString,
   479  		},
   480  		"Internal Server Error": {
   481  			url: "/admin/rules/commands/testCommand",
   482  			mreng: getAllRulesMock{
   483  				retVal: rulesengine.RuleWithOverrides{},
   484  				retErr: fmt.Errorf("something went wrong"),
   485  			},
   486  			expStatus: 500,
   487  		},
   488  		"Command Not Listed": {
   489  			url: "/admin/rules/commands/testCommand",
   490  			mreng: getAllRulesMock{
   491  				retVal: rulesengine.RuleWithOverrides{},
   492  				retErr: nil,
   493  			},
   494  			expStatus: 200,
   495  			expOut:    "null",
   496  		},
   497  		"No Banners": {
   498  			url: "/admin/rules/commands/testCommand",
   499  			mreng: getAllRulesMock{
   500  				retVal: retvalNoBanners,
   501  				retErr: nil,
   502  			},
   503  			expStatus: 200,
   504  			expOut:    retValStringNoBanners,
   505  		},
   506  		"No Default Privileges": {
   507  			url: "/admin/rules/commands/testCommand",
   508  			mreng: getAllRulesMock{
   509  				retErr: nil,
   510  				retVal: retValNoDefaults,
   511  			},
   512  			expStatus: 200,
   513  			expOut:    retValNoDefaultsString,
   514  		},
   515  		"No Rules": {
   516  			url: "/admin/rules/commands/testCommand",
   517  			mreng: getAllRulesMock{
   518  				retErr: nil,
   519  				retVal: retValCommandOnly,
   520  			},
   521  			expStatus: 200,
   522  			expOut:    retValCommandOnlyString,
   523  		},
   524  	}
   525  	for name, tc := range tests {
   526  		tc := tc
   527  		t.Run(name, func(t *testing.T) {
   528  			t.Parallel()
   529  			log := newLogger()
   530  
   531  			r := httptest.NewRecorder()
   532  			_, ginEngine := getTestGinContext(r)
   533  			_, err := New(ginEngine, &tc.mreng, log)
   534  			assert.NoError(t, err)
   535  
   536  			req, err := http.NewRequest(http.MethodGet, tc.url, nil)
   537  			assert.NoError(t, err)
   538  
   539  			ginEngine.ServeHTTP(r, req)
   540  
   541  			assert.Equal(t, tc.expStatus, r.Result().StatusCode)
   542  			if tc.expStatus == 200 && r.Result().StatusCode == 200 {
   543  				assert.JSONEq(t, tc.expOut, r.Body.String())
   544  			}
   545  		})
   546  	}
   547  }
   548  

View as plain text