...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/rules.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database

     1  package database
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  
     8  	"edge-infra.dev/pkg/lib/fog"
     9  	rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
    10  	datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
    11  )
    12  
    13  func (ds Dataset) AddDefaultRules(ctx context.Context, rules []rulesengine.RuleSegment) (feedback rulesengine.AddRuleResult, err error) {
    14  	tx, err := ds.db.BeginTx(ctx, nil)
    15  	if err != nil {
    16  		return feedback, fmt.Errorf("error beginning transaction: %w", err)
    17  	}
    18  
    19  	log := fog.FromContext(ctx).WithName("dataset")
    20  
    21  	defer func() {
    22  		switch {
    23  		case err != nil:
    24  			if txErr := tx.Rollback(); txErr != nil {
    25  				err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr)
    26  			}
    27  		case len(feedback.Errors) != 0:
    28  			if txErr := tx.Rollback(); txErr != nil {
    29  				err = fmt.Errorf("error rolling back transaction due to errors: %w", txErr)
    30  			}
    31  		default:
    32  			if txErr := tx.Commit(); txErr != nil {
    33  				err = fmt.Errorf("error committing transaction: %w", txErr)
    34  			}
    35  		}
    36  	}()
    37  
    38  	for _, rule := range rules {
    39  		var res sql.Result
    40  		res, err := tx.ExecContext(ctx, datasql.InsertRuleDefault, rule.Command.Name, rule.Privilege.Name)
    41  		if err != nil {
    42  			return feedback, fmt.Errorf("error executing query: %w", err)
    43  		}
    44  
    45  		var rowsAffected int64
    46  		rowsAffected, err = res.RowsAffected()
    47  		if err != nil {
    48  			return feedback, fmt.Errorf("error getting query result: %w", err)
    49  		}
    50  
    51  		// No changes were made to the DB, determine why by finding the ID of each
    52  		// value to enter. If an ID is missing this means the value is invalid and
    53  		// and error should be returned. If no ID is missing this means the rule is
    54  		// already present in the DB
    55  		if rowsAffected == 0 {
    56  			rule, err = populateRuleSegmentIDs(ctx, tx, rule)
    57  			if err != nil {
    58  				return feedback, err
    59  			}
    60  
    61  			missingID := false
    62  
    63  			if rule.Command.ID == "" {
    64  				missingID = true
    65  				feedback.Errors = append(feedback.Errors, rulesengine.Error{
    66  					Command: rule.Command.Name,
    67  					Type:    rulesengine.UnknownCommand,
    68  				})
    69  			}
    70  
    71  			if rule.Privilege.ID == "" {
    72  				missingID = true
    73  				feedback.Errors = append(feedback.Errors, rulesengine.Error{
    74  					Privilege: rule.Privilege.Name,
    75  					Type:      rulesengine.UnknownPrivilege,
    76  				})
    77  			}
    78  
    79  			if missingID {
    80  				// Missing ID is found, move to next segment and rollback at end
    81  				continue
    82  			}
    83  
    84  			log.Info(
    85  				"Default rule already present in database",
    86  				"commandName", rule.Command.Name,
    87  				"privilegeName", rule.Privilege.Name,
    88  			)
    89  		}
    90  	}
    91  	return feedback, err
    92  }
    93  
    94  func populateRuleSegmentIDs(ctx context.Context, tx *sql.Tx, segment rulesengine.RuleSegment) (rulesengine.RuleSegment, error) {
    95  	rows, err := tx.QueryContext(ctx, datasql.GetIDsForRuleSegment, segment.Command.Name, segment.Privilege.Name)
    96  	if err != nil {
    97  		return segment, fmt.Errorf("error finding ID's (command %s, privilege %s): %w", segment.Command.Name, segment.Privilege.Name, err)
    98  	}
    99  	defer rows.Close()
   100  
   101  	for rows.Next() {
   102  		var kind, value string
   103  		err := rows.Scan(&kind, &value)
   104  		if err != nil {
   105  			return segment, fmt.Errorf("error scanning for ID's: %w", err)
   106  		}
   107  
   108  		switch kind {
   109  		case "command":
   110  			segment.Command.ID = value
   111  		case "privilege":
   112  			segment.Privilege.ID = value
   113  		default:
   114  			return segment, fmt.Errorf("unknown kind for segment: %s", value)
   115  		}
   116  	}
   117  
   118  	if err := rows.Err(); err != nil {
   119  		return segment, fmt.Errorf("error finding ID's for segment: %w", err)
   120  	}
   121  
   122  	return segment, nil
   123  }
   124  
   125  func scanDefaultRulesRows(rows *sql.Rows) ([]rulesengine.RuleSegment, error) {
   126  	var res []rulesengine.RuleSegment
   127  	for rows.Next() {
   128  		var privname string
   129  		var privid string
   130  		var comid string
   131  		var comname string
   132  		err := rows.Scan(&comname, &privname, &comid, &privid)
   133  		if err != nil {
   134  			return nil, fmt.Errorf("error in data:scanDefaultRulesRows: %v", err)
   135  		}
   136  		res = append(res, rulesengine.RuleSegment{
   137  			Command: rulesengine.Command{
   138  				Name: comname,
   139  				ID:   comid,
   140  			},
   141  			Privilege: rulesengine.Privilege{
   142  				Name: privname,
   143  				ID:   privid,
   144  			},
   145  		})
   146  	}
   147  	err := rows.Err()
   148  	if err != nil {
   149  		err = fmt.Errorf("error in data:scanDefaultRulesRows on rows.Err: %v", err)
   150  	}
   151  	return res, err
   152  }
   153  
   154  func (ds Dataset) ReadAllDefaultRules(ctx context.Context) ([]rulesengine.RuleSegment, error) {
   155  	rows, err := ds.db.QueryContext(ctx, datasql.SelectAllDefaultRules)
   156  	if err != nil {
   157  		return nil, fmt.Errorf("error in data:ReadAllDefaultRules: %v", err)
   158  	}
   159  	defer rows.Close()
   160  
   161  	return scanDefaultRulesRows(rows)
   162  }
   163  
   164  func assembleRules(ruleSegments []rulesengine.RuleSegment) []rulesengine.Rule {
   165  	rules := []rulesengine.Rule{}
   166  	for _, ruleSegment := range ruleSegments {
   167  		rules = appendRule(rules, ruleSegment)
   168  	}
   169  	return rules
   170  }
   171  
   172  func appendRule(rules []rulesengine.Rule, ruleSegment rulesengine.RuleSegment) []rulesengine.Rule {
   173  	if len(rules) == 0 {
   174  		return []rulesengine.Rule{{Command: ruleSegment.Command, Privileges: []rulesengine.Privilege{ruleSegment.Privilege}}}
   175  	}
   176  	for idx, rule := range rules {
   177  		if rule.Command.ID == ruleSegment.Command.ID {
   178  			// dataset handles duplicate instances for us so we dont need to check here
   179  			rules[idx].Privileges = append(rule.Privileges, ruleSegment.Privilege)
   180  			return rules
   181  		}
   182  	}
   183  	rules = append(rules, rulesengine.Rule{Command: ruleSegment.Command, Privileges: []rulesengine.Privilege{ruleSegment.Privilege}})
   184  	return rules
   185  }
   186  
   187  // ReadDefaultRulesForCommand returns list of BannerRuleSegments.
   188  // ReadDefaultRulesForCommand Will not return the command if it no default rules are associated.
   189  func (ds Dataset) ReadDefaultRulesForCommand(ctx context.Context, commandName string) ([]rulesengine.RuleSegment, error) {
   190  	rows, err := ds.db.QueryContext(ctx, datasql.SelectDefaultRulesByCommandName, commandName)
   191  	if err != nil {
   192  		return []rulesengine.RuleSegment{}, err
   193  	}
   194  	defer rows.Close()
   195  
   196  	return scanDefaultRulesRows(rows)
   197  }
   198  
   199  func (ds Dataset) DeleteDefaultRule(ctx context.Context, commandName, privilegeName string) (res rulesengine.DeleteResult, err error) {
   200  	res, err = ds.deleteValue(ctx, datasql.DeleteDefaultRule, commandName, privilegeName)
   201  	if err != nil {
   202  		return res, fmt.Errorf("error deleting privilege from default rule: %w", err)
   203  	}
   204  	if res.RowsAffected != 0 || len(res.Errors) != 0 {
   205  		// Completed successfully or with a conflict
   206  		return res, nil
   207  	}
   208  
   209  	// No changes were made and no errors occurred, this may be due to an unknown value,
   210  	// or a rule that does not exist, find the cause
   211  
   212  	tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
   213  	if err != nil {
   214  		return res, fmt.Errorf("error creating transaction: %w", err)
   215  	}
   216  	defer func() {
   217  		switch {
   218  		case err != nil:
   219  			if txErr := tx.Rollback(); txErr != nil {
   220  				err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr)
   221  			}
   222  		default:
   223  			if txErr := tx.Commit(); txErr != nil {
   224  				err = fmt.Errorf("error committing transaction: %w", txErr)
   225  			}
   226  		}
   227  	}()
   228  
   229  	segment := rulesengine.RuleSegment{
   230  		Command:   rulesengine.Command{Name: commandName},
   231  		Privilege: rulesengine.Privilege{Name: privilegeName},
   232  	}
   233  	segment, err = populateRuleSegmentIDs(ctx, tx, segment)
   234  	if err != nil {
   235  		return res, fmt.Errorf("error discovering ID's: %w", err)
   236  	}
   237  	res.Errors = checkUnknownDeleteRuleErrs(segment)
   238  
   239  	return res, err
   240  }
   241  
   242  func checkUnknownDeleteRuleErrs(segment rulesengine.RuleSegment) []rulesengine.Error {
   243  	var errs []rulesengine.Error
   244  
   245  	if segment.Command.ID == "" {
   246  		errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownCommand, Command: segment.Command.Name})
   247  	}
   248  	if segment.Privilege.ID == "" {
   249  		errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownPrivilege, Privilege: segment.Privilege.Name})
   250  	}
   251  	if len(errs) == 0 {
   252  		errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownRule, Command: segment.Command.Name, Privilege: segment.Privilege.Name})
   253  	}
   254  
   255  	return errs
   256  }
   257  

View as plain text