...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/ea_integration/v2/helpers.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/ea_integration/v2

     1  package integrationv2
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"testing"
    12  
    13  	"cloud.google.com/go/pubsub"
    14  	"sigs.k8s.io/kustomize/kyaml/kio"
    15  
    16  	"edge-infra.dev/pkg/k8s/testing/kmp"
    17  	"edge-infra.dev/test/f2"
    18  	"edge-infra.dev/test/f2/x/ktest"
    19  	"edge-infra.dev/test/f2/x/postgres"
    20  
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
    25  const (
    26  	httpDomain = "http://"
    27  )
    28  
    29  type rulesEnginePortForwardKey struct{}
    30  
    31  type urlElems []string
    32  
    33  var (
    34  	addCommandsURLElems     = urlElems{"admin", "commands"}
    35  	addPrivilegesURLElems   = urlElems{"admin", "privileges"}
    36  	addDefaultRulesURLElems = urlElems{"admin", "rules", "default", "commands"}
    37  )
    38  
    39  // waits on deployment defined by the manifest
    40  func waitOn(tManifest string) f2.StepFn {
    41  	return func(ctx f2.Context, t *testing.T) f2.Context {
    42  		k := ktest.FromContextT(ctx, t)
    43  
    44  		manifests, err := processManifests(tManifest, k.Namespace)
    45  		assert.NoError(t, err)
    46  
    47  		for _, manifest := range manifests {
    48  			if manifest.GetKind() != "Deployment" {
    49  				continue
    50  			}
    51  			k.WaitOn(t, k.Check(manifest, kmp.IsCurrent()))
    52  		}
    53  		return ctx
    54  	}
    55  }
    56  
    57  // deploys a service to an f2 cluster using ktest
    58  func createService(tManifest string, filters ...kio.Filter) f2.StepFn {
    59  	return func(ctx f2.Context, t *testing.T) f2.Context {
    60  		k := ktest.FromContextT(ctx, t)
    61  		manifests, err := processManifests(tManifest, k.Namespace, filters...)
    62  		assert.NoError(t, err)
    63  
    64  		for _, manifest := range manifests {
    65  			if manifest.GetKind() == "Namespace" {
    66  				continue
    67  			}
    68  			err = k.Client.Create(ctx, manifest)
    69  			assert.NoError(t, err)
    70  		}
    71  
    72  		return ctx
    73  	}
    74  }
    75  
    76  type addNamePayload struct {
    77  	Name string `json:"name"`
    78  }
    79  
    80  func addNames(ctx f2.Context, t *testing.T, names []string, urlElems urlElems) (*http.Response, error) {
    81  	var payload []addNamePayload
    82  	for _, name := range names {
    83  		payload = append(payload, addNamePayload{name})
    84  	}
    85  	data, err := json.Marshal(payload)
    86  	if err != nil {
    87  		return nil, fmt.Errorf("failed to unmarshal payload: %w", err)
    88  	}
    89  	requestURL, err := createURLFromPortForward(ctx, t, urlElems)
    90  	if err != nil {
    91  		return nil, fmt.Errorf("failed to create request URL: %w", err)
    92  	}
    93  	resp, err := sendPostRequest(ctx, t, data, requestURL.String())
    94  	if err != nil {
    95  		return nil, fmt.Errorf("failed to send request: %w", err)
    96  	}
    97  	return resp, nil
    98  }
    99  
   100  func addCommands(commands []string) f2.StepFn {
   101  	return func(ctx f2.Context, t *testing.T) f2.Context {
   102  		resp, err := addNames(ctx, t, commands, addCommandsURLElems)
   103  		require.NoError(t, err, "addCommands failed")
   104  		require.Equal(t, http.StatusOK, resp.StatusCode, "addCommands API returned non-ok status: %d", resp.StatusCode)
   105  		return ctx
   106  	}
   107  }
   108  
   109  func addPrivileges(privileges []string) f2.StepFn {
   110  	return func(ctx f2.Context, t *testing.T) f2.Context {
   111  		resp, err := addNames(ctx, t, privileges, addPrivilegesURLElems)
   112  		require.NoError(t, err, "addPrivileges failed")
   113  		require.Equal(t, http.StatusOK, resp.StatusCode, "addPrivileges API returned non-ok status: %d", resp.StatusCode)
   114  		return ctx
   115  	}
   116  }
   117  
   118  type defaultRule struct {
   119  	command string
   120  	privs   []string
   121  }
   122  
   123  func addDefaultRules(rules []defaultRule) f2.StepFn {
   124  	type Payload struct {
   125  		Command    string   `json:"command"`
   126  		Privileges []string `json:"privileges"`
   127  	}
   128  	return func(ctx f2.Context, t *testing.T) f2.Context {
   129  		var payload []Payload
   130  		for _, rule := range rules {
   131  			payload = append(payload, Payload{rule.command, rule.privs})
   132  		}
   133  		data, err := json.Marshal(payload)
   134  		require.NoError(t, err, "failed to unmarshal AddDefaultRules payload")
   135  
   136  		requestURL, err := createURLFromPortForward(ctx, t, addDefaultRulesURLElems)
   137  		require.NoError(t, err, "failed to create valid AddDefaultRules request url")
   138  
   139  		resp, err := sendPostRequest(ctx, t, data, requestURL.String())
   140  		require.NoError(t, err, "AddDefaultRules request fail")
   141  		require.Equal(t, http.StatusOK, resp.StatusCode, "AddDefaultRules API returned non-ok status: %d", resp.StatusCode)
   142  
   143  		return ctx
   144  	}
   145  }
   146  
   147  func deleteDefaultRule(command, privilege string) f2.StepFn {
   148  	return func(ctx f2.Context, t *testing.T) f2.Context {
   149  		portForward, ok := ctx.Context.Value(rulesEnginePortForwardKey{}).(ktest.PortForward)
   150  		require.True(t, ok, "failed to retrieve portForward struct from ctx")
   151  		domain := httpDomain + portForward.Retrieve(t)
   152  		requestURL, err := url.JoinPath(domain, "admin", "rules", "default", "commands", command, "privileges", privilege)
   153  		require.NoError(t, err, "failed to create valid DeleteDefaultRule request url")
   154  
   155  		resp, err := sendPostRequest(ctx, t, nil, requestURL)
   156  		require.NoError(t, err, "DeleteDefaultRule request fail")
   157  		require.Equal(t, http.StatusOK, resp.StatusCode, "DeleteDefaultRule API returned non-ok status: %d", resp.StatusCode)
   158  
   159  		return ctx
   160  	}
   161  }
   162  
   163  func sendPostRequest(ctx f2.Context, t *testing.T, data []byte, requestURL string) (*http.Response, error) {
   164  	resp, err := sendRequest(ctx, t, data, requestURL, "POST")
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	_, err = readResponseBody(t, resp)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	return resp, err
   173  }
   174  
   175  // this is here in order to close the body as well as printing for debugging purposes.
   176  func readResponseBody(t *testing.T, resp *http.Response) ([]byte, error) {
   177  	bytes, err := io.ReadAll(resp.Body)
   178  	if err != nil {
   179  		return nil, fmt.Errorf("failed to read response body: %w", err)
   180  	}
   181  	if len(bytes) != 0 {
   182  		t.Logf("Response body: %s", string(bytes))
   183  	}
   184  	defer resp.Body.Close()
   185  	return bytes, nil
   186  }
   187  
   188  func sendRequest(ctx f2.Context, t *testing.T, data []byte, requestURL string, method string) (*http.Response, error) {
   189  	req, err := http.NewRequestWithContext(ctx, method, requestURL, bytes.NewBuffer(data))
   190  	if err != nil {
   191  		return nil, fmt.Errorf("failed to create request: %w", err)
   192  	}
   193  
   194  	t.Logf("Invoking RCLI services: %s", requestURL)
   195  	resp, err := http.DefaultClient.Do(req)
   196  	if err != nil {
   197  		return nil, fmt.Errorf("failed to send request")
   198  	}
   199  	return resp, nil
   200  }
   201  
   202  func createURLFromPortForward(ctx f2.Context, t *testing.T, urlElems urlElems) (*url.URL, error) {
   203  	portForward, ok := ctx.Context.Value(rulesEnginePortForwardKey{}).(ktest.PortForward)
   204  	if !ok {
   205  		return nil, errors.New("failed to retrieve portForward struct from ctx")
   206  	}
   207  	domain, err := url.Parse(httpDomain + portForward.Retrieve(t))
   208  	if err != nil {
   209  		return nil, fmt.Errorf("failed to parse url domain: %w", err)
   210  	}
   211  	requestURL := domain.JoinPath(urlElems...)
   212  	return requestURL, nil
   213  }
   214  
   215  // Takes a map of BslRole to human readable eaprivs (slice).
   216  // Uses the postgres extension to connect to the database and run the appropriate query to
   217  // insert the data in the table.
   218  // Will fail on conflict (duplicate value in table).
   219  // Will fail on missing privilege in ea_rules_privileges.
   220  func addRoleMapping(roleEaMap map[string][]string) f2.StepFn {
   221  	return func(ctx f2.Context, t *testing.T) f2.Context {
   222  		// insert the privs into oi_role_privileges
   223  		// we could optimise this by using a concatenated query instead but since this is a test db it isn't necessary.
   224  		pg := postgres.FromContextT(ctx, t)
   225  		db := pg.DB()
   226  		for bslRole, eaPrivs := range roleEaMap {
   227  			for _, priv := range eaPrivs {
   228  				_, err := db.ExecContext(ctx, InsertIntoEaRulesPrivileges, bslRole, priv)
   229  				require.NoError(t, err, "issue when inserting roles into ea_roles_privileges (role: %s, privilege: %s)", bslRole, priv)
   230  			}
   231  		}
   232  		return ctx
   233  	}
   234  }
   235  
   236  const (
   237  	// InsertIntoEaRulesPrivileges throws error on no privilegeID found (insert NULL error)
   238  	InsertIntoEaRulesPrivileges = `
   239  	INSERT INTO oi_role_privileges (role_name,privilege_id)
   240  		select $1, privilege_id from
   241  			unnest(ARRAY[$2]) as privileges(name)
   242  		left join ea_rules_privileges
   243  			on privileges.name = ea_rules_privileges.name
   244  		;`
   245  )
   246  
   247  func assertPubSubMessageAttributesEqual(t *testing.T, msg *pubsub.Message, expectedAttributes map[string]string) {
   248  	for key, expectedValue := range expectedAttributes {
   249  		actualValue, exists := msg.Attributes[key]
   250  		assert.True(t, exists, "expected attribute %s to exist", key)
   251  		assert.Equal(t, expectedValue, actualValue, "expected attribute %s to have value %s, but got %s", key, expectedValue, actualValue)
   252  	}
   253  }
   254  
   255  func assertPubSubMessageAttributesNotNil(t *testing.T, msg *pubsub.Message, expectedAttributes []string) {
   256  	for _, key := range expectedAttributes {
   257  		value, exists := msg.Attributes[key]
   258  		assert.True(t, exists, "expected attribute %s to exist", key)
   259  		assert.NotNil(t, value, "expected attribute %s to be not nil", key)
   260  	}
   261  }
   262  

View as plain text