...

Source file src/github.com/Microsoft/hcsshim/pkg/securitypolicy/securitypolicyenforcer_rego.go

Documentation: github.com/Microsoft/hcsshim/pkg/securitypolicy

     1  //go:build linux && rego
     2  // +build linux,rego
     3  
     4  package securitypolicy
     5  
     6  import (
     7  	"context"
     8  	_ "embed"
     9  	"encoding/base64"
    10  	"encoding/json"
    11  	"fmt"
    12  	"os"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"syscall"
    17  
    18  	"github.com/Microsoft/hcsshim/internal/guest/spec"
    19  	"github.com/Microsoft/hcsshim/internal/guestpath"
    20  	"github.com/Microsoft/hcsshim/internal/log"
    21  	rpi "github.com/Microsoft/hcsshim/internal/regopolicyinterpreter"
    22  	"github.com/opencontainers/runc/libcontainer/user"
    23  	oci "github.com/opencontainers/runtime-spec/specs-go"
    24  	"github.com/pkg/errors"
    25  )
    26  
    27  const regoEnforcerName = "rego"
    28  
    29  func init() {
    30  	registeredEnforcers[regoEnforcerName] = createRegoEnforcer
    31  	// Overriding the value inside init guarantees that this assignment happens
    32  	// after the variable has been initialized in securitypolicy.go and there
    33  	// are no race conditions. When multiple init functions are defined in a
    34  	// single package, the order of their execution is determined by the
    35  	// filename.
    36  	defaultEnforcer = regoEnforcerName
    37  	defaultMarshaller = regoMarshaller
    38  }
    39  
    40  const capabilitiesNilError = "capabilities object provided by the UVM to the policy engine is nil"
    41  const invalidPolicyMessage = "Security policy is not valid. Please check security policy or re-generate with tooling."
    42  const noReasonMessage = "Security policy is either not valid or did not provide a reason for denial. Please check security policy or re-generate with tooling."
    43  const noAPIVersionError = "policy does not define api_version"
    44  
    45  // RegoEnforcer is a stub implementation of a security policy, which will be
    46  // based on [Rego] policy language. The detailed implementation will be
    47  // introduced in the subsequent PRs and documentation updated accordingly.
    48  //
    49  // [Rego]: https://www.openpolicyagent.org/docs/latest/policy-language/
    50  type regoEnforcer struct {
    51  	// Base64 encoded (JSON) policy
    52  	base64policy string
    53  	// Rego interpreter
    54  	rego *rpi.RegoPolicyInterpreter
    55  	// Default mount data
    56  	defaultMounts []oci.Mount
    57  	// Stdio allowed state on a per container id basis
    58  	stdio map[string]bool
    59  	// Maximum error message length
    60  	maxErrorMessageLength int
    61  }
    62  
    63  var _ SecurityPolicyEnforcer = (*regoEnforcer)(nil)
    64  
    65  //nolint:unused
    66  func (sp SecurityPolicy) toInternal() (*securityPolicyInternal, error) {
    67  	policy := new(securityPolicyInternal)
    68  	var err error
    69  	if policy.Containers, err = sp.Containers.toInternal(); err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	return policy, nil
    74  }
    75  
    76  func toStringSet(items []string) stringSet {
    77  	s := make(stringSet)
    78  	for _, item := range items {
    79  		s.add(item)
    80  	}
    81  
    82  	return s
    83  }
    84  
    85  func (s stringSet) toArray() []string {
    86  	a := make([]string, 0, len(s))
    87  	for item := range s {
    88  		a = append(a, item)
    89  	}
    90  
    91  	return a
    92  }
    93  
    94  func (a stringSet) intersect(b stringSet) stringSet {
    95  	s := make(stringSet)
    96  	for item := range a {
    97  		if b.contains(item) {
    98  			s.add(item)
    99  		}
   100  	}
   101  
   102  	return s
   103  }
   104  
   105  type inputData map[string]interface{}
   106  
   107  func createRegoEnforcer(base64EncodedPolicy string,
   108  	defaultMounts []oci.Mount,
   109  	privilegedMounts []oci.Mount,
   110  	maxErrorMessageLength int,
   111  ) (SecurityPolicyEnforcer, error) {
   112  	// base64 decode the incoming policy string
   113  	// It will either be (legacy) JSON or Rego.
   114  	rawPolicy, err := base64.StdEncoding.DecodeString(base64EncodedPolicy)
   115  	if err != nil {
   116  		return nil, fmt.Errorf("unable to decode policy from Base64 format: %w", err)
   117  	}
   118  
   119  	// Try to unmarshal the JSON
   120  	var code string
   121  	securityPolicy := new(SecurityPolicy)
   122  	err = json.Unmarshal(rawPolicy, securityPolicy)
   123  	if err == nil {
   124  		if securityPolicy.AllowAll {
   125  			return createOpenDoorEnforcer(base64EncodedPolicy, defaultMounts, privilegedMounts, maxErrorMessageLength)
   126  		}
   127  
   128  		containers := make([]*Container, securityPolicy.Containers.Length)
   129  
   130  		for i := 0; i < securityPolicy.Containers.Length; i++ {
   131  			index := strconv.Itoa(i)
   132  			cConf, ok := securityPolicy.Containers.Elements[index]
   133  			if !ok {
   134  				return nil, fmt.Errorf("container constraint with index %q not found", index)
   135  			}
   136  			cConf.AllowStdioAccess = true
   137  			cConf.NoNewPrivileges = false
   138  			cConf.User = UserConfig{
   139  				UserIDName:   IDNameConfig{Strategy: IDNameStrategyAny},
   140  				GroupIDNames: []IDNameConfig{{Strategy: IDNameStrategyAny}},
   141  				Umask:        "0022",
   142  			}
   143  			cConf.SeccompProfileSHA256 = ""
   144  			containers[i] = &cConf
   145  		}
   146  
   147  		code, err = marshalRego(
   148  			securityPolicy.AllowAll,
   149  			containers,
   150  			[]ExternalProcessConfig{},
   151  			[]FragmentConfig{},
   152  			true,
   153  			true,
   154  			true,
   155  			false,
   156  			true,
   157  			false,
   158  		)
   159  		if err != nil {
   160  			return nil, fmt.Errorf("error marshaling the policy to Rego: %w", err)
   161  		}
   162  	} else {
   163  		// this is either a Rego policy or malformed JSON
   164  		code = string(rawPolicy)
   165  	}
   166  
   167  	regoPolicy, err := newRegoPolicy(code, defaultMounts, privilegedMounts)
   168  	if err != nil {
   169  		return nil, fmt.Errorf("error creating Rego policy: %w", err)
   170  	}
   171  	regoPolicy.base64policy = base64EncodedPolicy
   172  	regoPolicy.maxErrorMessageLength = maxErrorMessageLength
   173  	return regoPolicy, nil
   174  }
   175  
   176  func (policy *regoEnforcer) enableLogging(path string, logLevel rpi.LogLevel) {
   177  	policy.rego.EnableLogging(path, logLevel)
   178  }
   179  
   180  func newRegoPolicy(code string, defaultMounts []oci.Mount, privilegedMounts []oci.Mount) (policy *regoEnforcer, err error) {
   181  	policy = new(regoEnforcer)
   182  
   183  	policy.defaultMounts = make([]oci.Mount, len(defaultMounts))
   184  	copy(policy.defaultMounts, defaultMounts)
   185  
   186  	defaultMountData := make([]interface{}, 0, len(defaultMounts))
   187  	privilegedMountData := make([]interface{}, 0, len(privilegedMounts))
   188  	data := map[string]interface{}{
   189  		"defaultMounts":                   appendMountData(defaultMountData, defaultMounts),
   190  		"privilegedMounts":                appendMountData(privilegedMountData, privilegedMounts),
   191  		"sandboxPrefix":                   guestpath.SandboxMountPrefix,
   192  		"hugePagesPrefix":                 guestpath.HugePagesMountPrefix,
   193  		"plan9Prefix":                     plan9Prefix,
   194  		"defaultUnprivilegedCapabilities": DefaultUnprivilegedCapabilities(),
   195  		"defaultPrivilegedCapabilities":   DefaultPrivilegedCapabilities(),
   196  	}
   197  
   198  	policy.rego, err = rpi.NewRegoPolicyInterpreter(code, data)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  	policy.stdio = map[string]bool{}
   203  
   204  	policy.base64policy = ""
   205  	policy.rego.AddModule("framework.rego", &rpi.RegoModule{Namespace: "framework", Code: FrameworkCode})
   206  	policy.rego.AddModule("api.rego", &rpi.RegoModule{Namespace: "api", Code: APICode})
   207  
   208  	err = policy.rego.Compile()
   209  	if err != nil {
   210  		return nil, fmt.Errorf("rego compilation failed: %w", err)
   211  	}
   212  
   213  	// by default we do not perform message truncation
   214  	policy.maxErrorMessageLength = 0
   215  
   216  	return policy, nil
   217  }
   218  
   219  func (policy *regoEnforcer) applyDefaults(enforcementPoint string, results rpi.RegoQueryResult) (rpi.RegoQueryResult, error) {
   220  	deny := rpi.RegoQueryResult{"allowed": false}
   221  	info, err := policy.queryEnforcementPoint(enforcementPoint)
   222  	if err != nil {
   223  		return deny, err
   224  	}
   225  
   226  	if results.IsEmpty() && info.availableByPolicyVersion {
   227  		// policy should define this rule but it is missing
   228  		return deny, fmt.Errorf("rule for %s is missing from policy", enforcementPoint)
   229  	}
   230  
   231  	return info.defaultResults.Union(results), nil
   232  }
   233  
   234  type enforcementPointInfo struct {
   235  	availableByPolicyVersion bool
   236  	defaultResults           rpi.RegoQueryResult
   237  }
   238  
   239  func (policy *regoEnforcer) queryEnforcementPoint(enforcementPoint string) (*enforcementPointInfo, error) {
   240  	input := inputData{
   241  		"name": enforcementPoint,
   242  		"rule": enforcementPoint,
   243  	}
   244  	result, err := policy.rego.Query("data.framework.enforcement_point_info", input)
   245  
   246  	if err != nil {
   247  		return nil, fmt.Errorf("error querying enforcement point information: %w", err)
   248  	}
   249  
   250  	unknown, err := result.Bool("unknown")
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	if unknown {
   256  		return nil, fmt.Errorf("enforcement point rule %s does not exist", enforcementPoint)
   257  	}
   258  
   259  	invalid, err := result.Bool("invalid")
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  
   264  	if invalid {
   265  		return nil, fmt.Errorf("enforcement point rule %s is invalid", enforcementPoint)
   266  	}
   267  
   268  	versionMissing, err := result.Bool("version_missing")
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  
   273  	if versionMissing {
   274  		return nil, errors.New(noAPIVersionError)
   275  	}
   276  
   277  	defaultResults, err := result.Object("default_results")
   278  	if err != nil {
   279  		return nil, errors.New("enforcement point result missing defaults")
   280  	}
   281  
   282  	availableByPolicyVersion, err := result.Bool("available")
   283  	if err != nil {
   284  		return nil, errors.New("enforcement point result missing availability info")
   285  	}
   286  
   287  	return &enforcementPointInfo{
   288  		availableByPolicyVersion: availableByPolicyVersion,
   289  		defaultResults:           defaultResults,
   290  	}, nil
   291  }
   292  
   293  func (policy *regoEnforcer) enforce(ctx context.Context, enforcementPoint string, input inputData) (rpi.RegoQueryResult, error) {
   294  	rule := "data.policy." + enforcementPoint
   295  	result, err := policy.rego.Query(rule, input)
   296  	if err != nil {
   297  		return nil, policy.denyWithError(ctx, err, input)
   298  	}
   299  
   300  	result, err = policy.applyDefaults(enforcementPoint, result)
   301  	if err != nil {
   302  		return result, policy.denyWithError(ctx, err, input)
   303  	}
   304  
   305  	allowed, err := result.Bool("allowed")
   306  	if err != nil {
   307  		return nil, policy.denyWithError(ctx, err, input)
   308  	}
   309  
   310  	if !allowed {
   311  		return nil, policy.denyWithReason(ctx, enforcementPoint, input)
   312  	}
   313  
   314  	return result, nil
   315  }
   316  
   317  type decisionTruncator func(map[string]interface{})
   318  
   319  func truncateErrorObjects(decision map[string]interface{}) {
   320  	if rawReason, ok := decision["reason"]; ok {
   321  		// check if it is a framework reason object
   322  		if reason, ok := rawReason.(rpi.RegoQueryResult); ok {
   323  			// check if we can remove error_objects
   324  			if _, ok := reason["error_objects"]; ok {
   325  				decision["truncated"] = append(decision["truncated"].([]string), "reason.error_objects")
   326  				delete(reason, "error_objects")
   327  				decision["reason"] = reason
   328  			}
   329  		}
   330  	}
   331  }
   332  
   333  func truncateInput(decision map[string]interface{}) {
   334  	if _, ok := decision["input"]; ok {
   335  		// remove the input
   336  		decision["truncated"] = append(decision["truncated"].([]string), "input")
   337  		delete(decision, "input")
   338  	}
   339  }
   340  
   341  func truncateReason(decision map[string]interface{}) {
   342  	decision["truncated"] = append(decision["truncated"].([]string), "reason")
   343  	delete(decision, "reason")
   344  }
   345  
   346  func (policy *regoEnforcer) policyDecisionToError(ctx context.Context, decision map[string]interface{}) error {
   347  	decisionJSON, err := json.Marshal(decision)
   348  	if err != nil {
   349  		log.G(ctx).WithError(err).Error("unable to marshal error object")
   350  		decisionJSON = []byte(`"Unable to marshal error object"`)
   351  	}
   352  
   353  	log.G(ctx).WithField("policyDecision", string(decisionJSON))
   354  
   355  	base64EncodedDecisionJSON := base64.StdEncoding.EncodeToString(decisionJSON)
   356  	errorMessage := fmt.Errorf(policyDecisionPattern, base64EncodedDecisionJSON)
   357  	if policy.maxErrorMessageLength == 0 {
   358  		// indicates no message truncation
   359  		return fmt.Errorf(policyDecisionPattern, base64EncodedDecisionJSON)
   360  	}
   361  
   362  	if len(errorMessage.Error()) <= policy.maxErrorMessageLength {
   363  		return errorMessage
   364  	}
   365  
   366  	decision["truncated"] = []string{}
   367  	truncators := []decisionTruncator{truncateErrorObjects, truncateInput, truncateReason}
   368  	for _, truncate := range truncators {
   369  		truncate(decision)
   370  
   371  		decisionJSON, err := json.Marshal(decision)
   372  		if err != nil {
   373  			log.G(ctx).WithError(err).Error("unable to marshal error object")
   374  			decisionJSON = []byte(`"Unable to marshal error object"`)
   375  		}
   376  		base64EncodedDecisionJSON = base64.StdEncoding.EncodeToString(decisionJSON)
   377  		errorMessage = fmt.Errorf(policyDecisionPattern, base64EncodedDecisionJSON)
   378  
   379  		if len(errorMessage.Error()) <= policy.maxErrorMessageLength {
   380  			break
   381  		}
   382  	}
   383  
   384  	return errorMessage
   385  }
   386  
   387  func (policy *regoEnforcer) denyWithError(ctx context.Context, policyError error, input inputData) error {
   388  	input = policy.redactSensitiveData(input)
   389  	input = replaceCapabilitiesWithPlaceholders(input)
   390  	policyDecision := map[string]interface{}{
   391  		"input":       input,
   392  		"decision":    "deny",
   393  		"reason":      invalidPolicyMessage,
   394  		"policyError": policyError.Error(),
   395  	}
   396  
   397  	return policy.policyDecisionToError(ctx, policyDecision)
   398  }
   399  
   400  func (policy *regoEnforcer) denyWithReason(ctx context.Context, enforcementPoint string, input inputData) error {
   401  	cleaned_input := policy.redactSensitiveData(input)
   402  	cleaned_input = replaceCapabilitiesWithPlaceholders(cleaned_input)
   403  	input["rule"] = enforcementPoint
   404  	policyDecision := map[string]interface{}{
   405  		"input":    cleaned_input,
   406  		"decision": "deny",
   407  	}
   408  
   409  	result, err := policy.rego.Query("data.policy.reason", input)
   410  	if err == nil {
   411  		if result.IsEmpty() {
   412  			policyDecision["reason"] = noReasonMessage
   413  		} else {
   414  			policyDecision["reason"] = replaceCapabilitiesWithPlaceholdersInReason(result)
   415  		}
   416  	} else {
   417  		log.G(ctx).WithError(err).Warn("unable to obtain reason for policy decision")
   418  		policyDecision["reason"] = noReasonMessage
   419  	}
   420  
   421  	return policy.policyDecisionToError(ctx, policyDecision)
   422  }
   423  
   424  func areCapsEqual(actual map[string]interface{}, expected map[string][]string) bool {
   425  	for key, caps := range expected {
   426  		values, ok := actual[key].([]interface{})
   427  		if !ok {
   428  			return false
   429  		}
   430  
   431  		if len(values) != len(caps) {
   432  			return false
   433  		}
   434  
   435  		for i, value := range values {
   436  			cap, ok := value.(string)
   437  			if !ok {
   438  				return false
   439  			}
   440  
   441  			if cap != caps[i] {
   442  				return false
   443  			}
   444  		}
   445  	}
   446  
   447  	return true
   448  }
   449  
   450  var privilegedCapabilities = map[string][]string{
   451  	"bounding":    DefaultPrivilegedCapabilities(),
   452  	"effective":   DefaultPrivilegedCapabilities(),
   453  	"inheritable": DefaultPrivilegedCapabilities(),
   454  	"permitted":   DefaultPrivilegedCapabilities(),
   455  	"ambient":     EmptyCapabiltiesSet(),
   456  }
   457  
   458  var unprivilegedCapabilities = map[string][]string{
   459  	"bounding":    DefaultUnprivilegedCapabilities(),
   460  	"effective":   DefaultUnprivilegedCapabilities(),
   461  	"inheritable": EmptyCapabiltiesSet(),
   462  	"permitted":   DefaultUnprivilegedCapabilities(),
   463  	"ambient":     EmptyCapabiltiesSet(),
   464  }
   465  
   466  // as capability lists are repetitive and take up a lot of room in the error
   467  // message, we can replace the defaults with placeholders to save space
   468  func replaceCapabilitiesWithPlaceholders(object map[string]interface{}) map[string]interface{} {
   469  	capabilities, ok := object["capabilities"].(map[string]interface{})
   470  	if !ok {
   471  		return object
   472  	}
   473  
   474  	if areCapsEqual(capabilities, privilegedCapabilities) {
   475  		object["capabilities"] = "[privileged]"
   476  	} else if areCapsEqual(capabilities, unprivilegedCapabilities) {
   477  		object["capabilities"] = "[unprivileged]"
   478  	}
   479  
   480  	return object
   481  }
   482  
   483  func replaceCapabilitiesWithPlaceholdersInReason(reason rpi.RegoQueryResult) rpi.RegoQueryResult {
   484  	errorObjectsRaw, err := reason.Value("error_objects")
   485  	if err != nil {
   486  		return reason
   487  	}
   488  
   489  	errorObjects, ok := errorObjectsRaw.([]interface{})
   490  	if !ok {
   491  		return reason
   492  	}
   493  
   494  	objects := make([]interface{}, len(errorObjects))
   495  	for i, objectRaw := range errorObjects {
   496  		object, ok := objectRaw.(map[string]interface{})
   497  		if !ok {
   498  			objects[i] = objectRaw
   499  			continue
   500  		}
   501  
   502  		objects[i] = replaceCapabilitiesWithPlaceholders(object)
   503  	}
   504  
   505  	reason["error_objects"] = objects
   506  	return reason
   507  }
   508  
   509  func (policy *regoEnforcer) redactSensitiveData(input inputData) inputData {
   510  	if v, k := input["envList"]; k {
   511  		newInput := make(inputData)
   512  		for k, v := range input {
   513  			newInput[k] = v
   514  		}
   515  
   516  		newEnvList := make([]string, 0)
   517  		cast, ok := v.([]string)
   518  		if ok {
   519  			for _, env := range cast {
   520  				parts := strings.Split(env, "=")
   521  				redacted := parts[0] + "=<<redacted>>"
   522  				newEnvList = append(newEnvList, redacted)
   523  			}
   524  		}
   525  
   526  		newInput["envList"] = newEnvList
   527  
   528  		return newInput
   529  	}
   530  
   531  	return input
   532  }
   533  
   534  func (policy *regoEnforcer) EnforceDeviceMountPolicy(ctx context.Context, target string, deviceHash string) error {
   535  	input := inputData{
   536  		"target":     target,
   537  		"deviceHash": deviceHash,
   538  	}
   539  
   540  	_, err := policy.enforce(ctx, "mount_device", input)
   541  	return err
   542  }
   543  
   544  func (policy *regoEnforcer) EnforceOverlayMountPolicy(ctx context.Context, containerID string, layerPaths []string, target string) error {
   545  	input := inputData{
   546  		"containerID": containerID,
   547  		"layerPaths":  layerPaths,
   548  		"target":      target,
   549  	}
   550  
   551  	_, err := policy.enforce(ctx, "mount_overlay", input)
   552  	return err
   553  }
   554  
   555  func (policy *regoEnforcer) EnforceOverlayUnmountPolicy(ctx context.Context, target string) error {
   556  	input := inputData{
   557  		"unmountTarget": target,
   558  	}
   559  
   560  	_, err := policy.enforce(ctx, "unmount_overlay", input)
   561  	return err
   562  }
   563  
   564  func getEnvsToKeep(envList []string, results rpi.RegoQueryResult) ([]string, error) {
   565  	value, err := results.Value("env_list")
   566  	if err != nil || value == nil {
   567  		// policy did not return an 'env_list'. This is interpreted
   568  		// as "proceed with provided env list".
   569  		return envList, nil
   570  	}
   571  
   572  	envsAsInterfaces, ok := value.([]interface{})
   573  
   574  	if !ok {
   575  		return nil, fmt.Errorf("policy returned incorrect type for 'env_list', expected []interface{}, received %T", value)
   576  	}
   577  
   578  	keepSet := make(stringSet)
   579  	for _, envAsInterface := range envsAsInterfaces {
   580  		if env, ok := envAsInterface.(string); ok {
   581  			keepSet.add(env)
   582  		} else {
   583  			return nil, fmt.Errorf("members of env_list from policy must be strings, received %T", envAsInterface)
   584  		}
   585  	}
   586  
   587  	keepSet = keepSet.intersect(toStringSet(envList))
   588  	return keepSet.toArray(), nil
   589  }
   590  
   591  func getCapsToKeep(capsList *oci.LinuxCapabilities, results rpi.RegoQueryResult) (*oci.LinuxCapabilities, error) {
   592  	value, err := results.Value("caps_list")
   593  	if err != nil || value == nil {
   594  		// policy did not return an 'caps_list'. This is interpreted
   595  		// as "proceed with provided caps list".
   596  		return capsList, nil
   597  	}
   598  
   599  	capsMap, ok := value.(map[string]interface{})
   600  
   601  	if !ok {
   602  		return nil, fmt.Errorf("policy returned incorrect type for 'caps_list', expected map[string]interface{}, received %T", value)
   603  	}
   604  
   605  	bounding, err := filterCapabilities(capsList.Bounding, capsMap["bounding"])
   606  	if err != nil {
   607  		return nil, err
   608  	}
   609  	effective, err := filterCapabilities(capsList.Effective, capsMap["effective"])
   610  	if err != nil {
   611  		return nil, err
   612  	}
   613  	inheritable, err := filterCapabilities(capsList.Inheritable, capsMap["inheritable"])
   614  	if err != nil {
   615  		return nil, err
   616  	}
   617  	permitted, err := filterCapabilities(capsList.Permitted, capsMap["permitted"])
   618  	if err != nil {
   619  		return nil, err
   620  	}
   621  	ambient, err := filterCapabilities(capsList.Ambient, capsMap["ambient"])
   622  	if err != nil {
   623  		return nil, err
   624  	}
   625  
   626  	return &oci.LinuxCapabilities{
   627  		Bounding:    bounding,
   628  		Effective:   effective,
   629  		Inheritable: inheritable,
   630  		Permitted:   permitted,
   631  		Ambient:     ambient,
   632  	}, nil
   633  }
   634  
   635  func filterCapabilities(suppliedList []string, fromRegoCapsList interface{}) ([]string, error) {
   636  	keepSet := make(stringSet)
   637  	if capsList, ok := fromRegoCapsList.([]interface{}); ok {
   638  		for _, capAsInterface := range capsList {
   639  			if cap, ok := capAsInterface.(string); ok {
   640  				keepSet.add(cap)
   641  			} else {
   642  				return nil, fmt.Errorf("members of capability sets from policy must be strings, received %T", capAsInterface)
   643  			}
   644  		}
   645  	} else {
   646  		return nil, fmt.Errorf("capability sets of caps_list from policy must be an array of interface{}, received %T", fromRegoCapsList)
   647  	}
   648  
   649  	keepSet = keepSet.intersect(toStringSet(suppliedList))
   650  	return keepSet.toArray(), nil
   651  }
   652  
   653  func (idName IDName) toInput() interface{} {
   654  	return map[string]interface{}{
   655  		"id":   idName.ID,
   656  		"name": idName.Name,
   657  	}
   658  }
   659  
   660  func groupsToInputs(groups []IDName) []interface{} {
   661  	inputs := []interface{}{}
   662  	for _, group := range groups {
   663  		inputs = append(inputs, group.toInput())
   664  	}
   665  	return inputs
   666  }
   667  
   668  func handleNilOrEmptyCaps(caps []string) interface{} {
   669  	if len(caps) > 0 {
   670  		result := make([]interface{}, len(caps))
   671  		for i, cap := range caps {
   672  			result[i] = cap
   673  		}
   674  
   675  		return result
   676  	}
   677  
   678  	// caps is either nil or empty.
   679  	// In either case, we want to return an empty array.
   680  	return make([]interface{}, 0)
   681  }
   682  
   683  func mapifyCapabilities(caps *oci.LinuxCapabilities) map[string]interface{} {
   684  	out := make(map[string]interface{})
   685  
   686  	out["bounding"] = handleNilOrEmptyCaps(caps.Bounding)
   687  	out["effective"] = handleNilOrEmptyCaps(caps.Effective)
   688  	out["inheritable"] = handleNilOrEmptyCaps(caps.Inheritable)
   689  	out["permitted"] = handleNilOrEmptyCaps(caps.Permitted)
   690  	out["ambient"] = handleNilOrEmptyCaps(caps.Ambient)
   691  	return out
   692  }
   693  
   694  func (policy *regoEnforcer) EnforceCreateContainerPolicy(
   695  	ctx context.Context,
   696  	sandboxID string,
   697  	containerID string,
   698  	argList []string,
   699  	envList []string,
   700  	workingDir string,
   701  	mounts []oci.Mount,
   702  	privileged bool,
   703  	noNewPrivileges bool,
   704  	user IDName,
   705  	groups []IDName,
   706  	umask string,
   707  	capabilities *oci.LinuxCapabilities,
   708  	seccompProfileSHA256 string,
   709  ) (envToKeep EnvList,
   710  	capsToKeep *oci.LinuxCapabilities,
   711  	stdioAccessAllowed bool,
   712  	err error) {
   713  	if capabilities == nil {
   714  		return nil, nil, false, errors.New(capabilitiesNilError)
   715  	}
   716  
   717  	input := inputData{
   718  		"containerID":          containerID,
   719  		"argList":              argList,
   720  		"envList":              envList,
   721  		"workingDir":           workingDir,
   722  		"sandboxDir":           spec.SandboxMountsDir(sandboxID),
   723  		"hugePagesDir":         spec.HugePagesMountsDir(sandboxID),
   724  		"mounts":               appendMountData([]interface{}{}, mounts),
   725  		"privileged":           privileged,
   726  		"noNewPrivileges":      noNewPrivileges,
   727  		"user":                 user.toInput(),
   728  		"groups":               groupsToInputs(groups),
   729  		"umask":                umask,
   730  		"capabilities":         mapifyCapabilities(capabilities),
   731  		"seccompProfileSHA256": seccompProfileSHA256,
   732  	}
   733  
   734  	results, err := policy.enforce(ctx, "create_container", input)
   735  	if err != nil {
   736  		return nil, nil, false, err
   737  	}
   738  
   739  	envToKeep, err = getEnvsToKeep(envList, results)
   740  	if err != nil {
   741  		return nil, nil, false, err
   742  	}
   743  
   744  	capsToKeep, err = getCapsToKeep(capabilities, results)
   745  	if err != nil {
   746  		return nil, nil, false, err
   747  	}
   748  
   749  	stdioAccessAllowed, err = results.Bool("allow_stdio_access")
   750  	if err != nil {
   751  		return nil, nil, false, err
   752  	}
   753  
   754  	// Store the result of stdio access allowed for this container so we can use
   755  	// it if we get queried about allowing exec in container access. Stdio access
   756  	// is on a per-container, not per-process basis.
   757  	policy.stdio[containerID] = stdioAccessAllowed
   758  
   759  	return envToKeep, capsToKeep, stdioAccessAllowed, nil
   760  }
   761  
   762  func (policy *regoEnforcer) EnforceDeviceUnmountPolicy(ctx context.Context, unmountTarget string) error {
   763  	input := inputData{
   764  		"unmountTarget": unmountTarget,
   765  	}
   766  
   767  	_, err := policy.enforce(ctx, "unmount_device", input)
   768  	return err
   769  }
   770  
   771  func appendMountData(mountData []interface{}, mounts []oci.Mount) []interface{} {
   772  	for _, mount := range mounts {
   773  		mountData = append(mountData, inputData{
   774  			"destination": mount.Destination,
   775  			"source":      mount.Source,
   776  			"options":     mount.Options,
   777  			"type":        mount.Type,
   778  		})
   779  	}
   780  
   781  	return mountData
   782  }
   783  
   784  func (policy *regoEnforcer) ExtendDefaultMounts(mounts []oci.Mount) error {
   785  	policy.defaultMounts = append(policy.defaultMounts, mounts...)
   786  	defaultMounts := appendMountData([]interface{}{}, policy.defaultMounts)
   787  	return policy.rego.UpdateData("defaultMounts", defaultMounts)
   788  }
   789  
   790  func (policy *regoEnforcer) EncodedSecurityPolicy() string {
   791  	return policy.base64policy
   792  }
   793  
   794  func (policy *regoEnforcer) EnforceExecInContainerPolicy(
   795  	ctx context.Context,
   796  	containerID string,
   797  	argList []string,
   798  	envList []string,
   799  	workingDir string,
   800  	noNewPrivileges bool,
   801  	user IDName,
   802  	groups []IDName,
   803  	umask string,
   804  	capabilities *oci.LinuxCapabilities,
   805  ) (envToKeep EnvList,
   806  	capsToKeep *oci.LinuxCapabilities,
   807  	stdioAccessAllowed bool,
   808  	err error) {
   809  	if capabilities == nil {
   810  		return nil, nil, false, errors.New(capabilitiesNilError)
   811  	}
   812  
   813  	input := inputData{
   814  		"containerID":     containerID,
   815  		"argList":         argList,
   816  		"envList":         envList,
   817  		"workingDir":      workingDir,
   818  		"noNewPrivileges": noNewPrivileges,
   819  		"user":            user.toInput(),
   820  		"groups":          groupsToInputs(groups),
   821  		"umask":           umask,
   822  		"capabilities":    mapifyCapabilities(capabilities),
   823  	}
   824  
   825  	results, err := policy.enforce(ctx, "exec_in_container", input)
   826  	if err != nil {
   827  		return nil, nil, false, err
   828  	}
   829  
   830  	envToKeep, err = getEnvsToKeep(envList, results)
   831  	if err != nil {
   832  		return nil, nil, false, err
   833  	}
   834  
   835  	capsToKeep, err = getCapsToKeep(capabilities, results)
   836  	if err != nil {
   837  		return nil, nil, false, err
   838  	}
   839  
   840  	return envToKeep, capsToKeep, policy.stdio[containerID], nil
   841  }
   842  
   843  func (policy *regoEnforcer) EnforceExecExternalProcessPolicy(ctx context.Context, argList []string, envList []string, workingDir string) (toKeep EnvList, stdioAccessAllowed bool, err error) {
   844  	input := map[string]interface{}{
   845  		"argList":    argList,
   846  		"envList":    envList,
   847  		"workingDir": workingDir,
   848  	}
   849  
   850  	results, err := policy.enforce(ctx, "exec_external", input)
   851  	if err != nil {
   852  		return nil, false, err
   853  	}
   854  
   855  	toKeep, err = getEnvsToKeep(envList, results)
   856  	if err != nil {
   857  		return nil, false, err
   858  	}
   859  
   860  	stdioAccessAllowed, err = results.Bool("allow_stdio_access")
   861  	if err != nil {
   862  		return nil, false, err
   863  	}
   864  
   865  	return toKeep, stdioAccessAllowed, nil
   866  }
   867  
   868  func (policy *regoEnforcer) EnforceShutdownContainerPolicy(ctx context.Context, containerID string) error {
   869  	input := inputData{
   870  		"containerID": containerID,
   871  	}
   872  
   873  	_, err := policy.enforce(ctx, "shutdown_container", input)
   874  	return err
   875  }
   876  
   877  func (policy *regoEnforcer) EnforceSignalContainerProcessPolicy(ctx context.Context, containerID string, signal syscall.Signal, isInitProcess bool, startupArgList []string) error {
   878  	input := inputData{
   879  		"containerID":   containerID,
   880  		"signal":        signal,
   881  		"isInitProcess": isInitProcess,
   882  		"argList":       startupArgList,
   883  	}
   884  
   885  	_, err := policy.enforce(ctx, "signal_container_process", input)
   886  	return err
   887  }
   888  
   889  func (policy *regoEnforcer) EnforcePlan9MountPolicy(ctx context.Context, target string) error {
   890  	mountPathPrefix := strings.Replace(guestpath.LCOWMountPathPrefixFmt, "%d", "[0-9]+", 1)
   891  	input := inputData{
   892  		"rootPrefix":      guestpath.LCOWRootPrefixInUVM,
   893  		"mountPathPrefix": mountPathPrefix,
   894  		"target":          target,
   895  	}
   896  
   897  	_, err := policy.enforce(ctx, "plan9_mount", input)
   898  	return err
   899  }
   900  
   901  func (policy *regoEnforcer) EnforcePlan9UnmountPolicy(ctx context.Context, target string) error {
   902  	input := map[string]interface{}{
   903  		"unmountTarget": target,
   904  	}
   905  
   906  	_, err := policy.enforce(ctx, "plan9_unmount", input)
   907  	return err
   908  }
   909  
   910  func (policy *regoEnforcer) EnforceGetPropertiesPolicy(ctx context.Context) error {
   911  	input := make(inputData)
   912  
   913  	_, err := policy.enforce(ctx, "get_properties", input)
   914  	return err
   915  }
   916  
   917  func (policy *regoEnforcer) EnforceDumpStacksPolicy(ctx context.Context) error {
   918  	input := make(inputData)
   919  
   920  	_, err := policy.enforce(ctx, "dump_stacks", input)
   921  	return err
   922  }
   923  
   924  func (policy *regoEnforcer) EnforceRuntimeLoggingPolicy(ctx context.Context) error {
   925  	input := make(inputData)
   926  	_, err := policy.enforce(ctx, "runtime_logging", input)
   927  	return err
   928  }
   929  
   930  func parseNamespace(rego string) (string, error) {
   931  	lines := strings.Split(rego, "\n")
   932  	parts := strings.Split(lines[0], " ")
   933  	if parts[0] != "package" {
   934  		return "", errors.New("package definition required on first line")
   935  	}
   936  
   937  	return strings.TrimSpace(parts[1]), nil
   938  }
   939  
   940  func (policy *regoEnforcer) LoadFragment(ctx context.Context, issuer string, feed string, rego string) error {
   941  	namespace, err := parseNamespace(rego)
   942  	if err != nil {
   943  		return fmt.Errorf("unable to load fragment: %w", err)
   944  	}
   945  
   946  	fragment := &rpi.RegoModule{
   947  		Issuer:    issuer,
   948  		Feed:      feed,
   949  		Code:      rego,
   950  		Namespace: namespace,
   951  	}
   952  
   953  	policy.rego.AddModule(fragment.ID(), fragment)
   954  
   955  	input := inputData{
   956  		"issuer":    issuer,
   957  		"feed":      feed,
   958  		"namespace": namespace,
   959  	}
   960  
   961  	results, err := policy.enforce(ctx, "load_fragment", input)
   962  
   963  	addModule, _ := results.Bool("add_module")
   964  	if !addModule {
   965  		policy.rego.RemoveModule(fragment.ID())
   966  	}
   967  
   968  	return err
   969  }
   970  
   971  func (policy *regoEnforcer) EnforceScratchMountPolicy(ctx context.Context, scratchPath string, encrypted bool) error {
   972  	input := map[string]interface{}{
   973  		"target":    scratchPath,
   974  		"encrypted": encrypted,
   975  	}
   976  	_, err := policy.enforce(ctx, "scratch_mount", input)
   977  	if err != nil {
   978  		return err
   979  	}
   980  	return nil
   981  }
   982  
   983  func (policy *regoEnforcer) EnforceScratchUnmountPolicy(ctx context.Context, scratchPath string) error {
   984  	input := map[string]interface{}{
   985  		"unmountTarget": scratchPath,
   986  	}
   987  	_, err := policy.enforce(ctx, "scratch_unmount", input)
   988  	if err != nil {
   989  		return err
   990  	}
   991  	return nil
   992  }
   993  
   994  func getUser(passwdPath string, filter func(user.User) bool) (user.User, error) {
   995  	users, err := user.ParsePasswdFileFilter(passwdPath, filter)
   996  	if err != nil {
   997  		return user.User{}, err
   998  	}
   999  	if len(users) != 1 {
  1000  		return user.User{}, errors.Errorf("expected exactly 1 user matched '%d'", len(users))
  1001  	}
  1002  	return users[0], nil
  1003  }
  1004  
  1005  func getGroup(groupPath string, filter func(user.Group) bool) (user.Group, error) {
  1006  	groups, err := user.ParseGroupFileFilter(groupPath, filter)
  1007  	if err != nil {
  1008  		return user.Group{}, err
  1009  	}
  1010  	if len(groups) != 1 {
  1011  		return user.Group{}, errors.Errorf("expected exactly 1 group matched '%d'", len(groups))
  1012  	}
  1013  	return groups[0], nil
  1014  }
  1015  
  1016  func (policy *regoEnforcer) GetUserInfo(containerID string, process *oci.Process) (IDName, []IDName, string, error) {
  1017  	rootPath := filepath.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath)
  1018  	passwdPath := filepath.Join(rootPath, "/etc/passwd")
  1019  	groupPath := filepath.Join(rootPath, "/etc/group")
  1020  
  1021  	if process == nil {
  1022  		return IDName{}, nil, "", errors.New("spec.Process is nil")
  1023  	}
  1024  
  1025  	uid := process.User.UID
  1026  	userIDName := IDName{ID: strconv.FormatUint(uint64(uid), 10), Name: ""}
  1027  	if _, err := os.Stat(passwdPath); err == nil {
  1028  		userInfo, err := getUser(passwdPath, func(user user.User) bool {
  1029  			return uint32(user.Uid) == uid
  1030  		})
  1031  
  1032  		if err != nil {
  1033  			return userIDName, nil, "", err
  1034  		}
  1035  
  1036  		userIDName.Name = userInfo.Name
  1037  	}
  1038  
  1039  	gid := process.User.GID
  1040  	groupIDName := IDName{ID: strconv.FormatUint(uint64(gid), 10), Name: ""}
  1041  
  1042  	checkGroup := true
  1043  	if _, err := os.Stat(groupPath); err == nil {
  1044  		groupInfo, err := getGroup(groupPath, func(group user.Group) bool {
  1045  			return uint32(group.Gid) == gid
  1046  		})
  1047  
  1048  		if err != nil {
  1049  			return userIDName, nil, "", err
  1050  		}
  1051  		groupIDName.Name = groupInfo.Name
  1052  	} else {
  1053  		checkGroup = false
  1054  	}
  1055  
  1056  	groupIDNames := []IDName{groupIDName}
  1057  	additionalGIDs := process.User.AdditionalGids
  1058  	if len(additionalGIDs) > 0 {
  1059  		for _, gid := range additionalGIDs {
  1060  			groupIDName = IDName{ID: strconv.FormatUint(uint64(gid), 10), Name: ""}
  1061  			if checkGroup {
  1062  				groupInfo, err := getGroup(groupPath, func(group user.Group) bool {
  1063  					return uint32(group.Gid) == gid
  1064  				})
  1065  				if err != nil {
  1066  					return userIDName, nil, "", err
  1067  				}
  1068  				groupIDName.Name = groupInfo.Name
  1069  			}
  1070  			groupIDNames = append(groupIDNames, groupIDName)
  1071  		}
  1072  	}
  1073  
  1074  	// this default value is used in the Linux kernel if no umask is specified
  1075  	umask := "0022"
  1076  	if process.User.Umask != nil {
  1077  		umask = fmt.Sprintf("%04o", *process.User.Umask)
  1078  	}
  1079  
  1080  	return userIDName, groupIDNames, umask, nil
  1081  }
  1082  

View as plain text