package database import ( "context" "database/sql" "fmt" "edge-infra.dev/pkg/lib/fog" rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules" datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql" ) func (ds Dataset) AddDefaultRules(ctx context.Context, rules []rulesengine.RuleSegment) (feedback rulesengine.AddRuleResult, err error) { tx, err := ds.db.BeginTx(ctx, nil) if err != nil { return feedback, fmt.Errorf("error beginning transaction: %w", err) } log := fog.FromContext(ctx).WithName("dataset") defer func() { switch { case err != nil: if txErr := tx.Rollback(); txErr != nil { err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr) } case len(feedback.Errors) != 0: if txErr := tx.Rollback(); txErr != nil { err = fmt.Errorf("error rolling back transaction due to errors: %w", txErr) } default: if txErr := tx.Commit(); txErr != nil { err = fmt.Errorf("error committing transaction: %w", txErr) } } }() for _, rule := range rules { var res sql.Result res, err := tx.ExecContext(ctx, datasql.InsertRuleDefault, rule.Command.Name, rule.Privilege.Name) if err != nil { return feedback, fmt.Errorf("error executing query: %w", err) } var rowsAffected int64 rowsAffected, err = res.RowsAffected() if err != nil { return feedback, fmt.Errorf("error getting query result: %w", err) } // No changes were made to the DB, determine why by finding the ID of each // value to enter. If an ID is missing this means the value is invalid and // and error should be returned. If no ID is missing this means the rule is // already present in the DB if rowsAffected == 0 { rule, err = populateRuleSegmentIDs(ctx, tx, rule) if err != nil { return feedback, err } missingID := false if rule.Command.ID == "" { missingID = true feedback.Errors = append(feedback.Errors, rulesengine.Error{ Command: rule.Command.Name, Type: rulesengine.UnknownCommand, }) } if rule.Privilege.ID == "" { missingID = true feedback.Errors = append(feedback.Errors, rulesengine.Error{ Privilege: rule.Privilege.Name, Type: rulesengine.UnknownPrivilege, }) } if missingID { // Missing ID is found, move to next segment and rollback at end continue } log.Info( "Default rule already present in database", "commandName", rule.Command.Name, "privilegeName", rule.Privilege.Name, ) } } return feedback, err } func populateRuleSegmentIDs(ctx context.Context, tx *sql.Tx, segment rulesengine.RuleSegment) (rulesengine.RuleSegment, error) { rows, err := tx.QueryContext(ctx, datasql.GetIDsForRuleSegment, segment.Command.Name, segment.Privilege.Name) if err != nil { return segment, fmt.Errorf("error finding ID's (command %s, privilege %s): %w", segment.Command.Name, segment.Privilege.Name, err) } defer rows.Close() for rows.Next() { var kind, value string err := rows.Scan(&kind, &value) if err != nil { return segment, fmt.Errorf("error scanning for ID's: %w", err) } switch kind { case "command": segment.Command.ID = value case "privilege": segment.Privilege.ID = value default: return segment, fmt.Errorf("unknown kind for segment: %s", value) } } if err := rows.Err(); err != nil { return segment, fmt.Errorf("error finding ID's for segment: %w", err) } return segment, nil } func scanDefaultRulesRows(rows *sql.Rows) ([]rulesengine.RuleSegment, error) { var res []rulesengine.RuleSegment for rows.Next() { var privname string var privid string var comid string var comname string err := rows.Scan(&comname, &privname, &comid, &privid) if err != nil { return nil, fmt.Errorf("error in data:scanDefaultRulesRows: %v", err) } res = append(res, rulesengine.RuleSegment{ Command: rulesengine.Command{ Name: comname, ID: comid, }, Privilege: rulesengine.Privilege{ Name: privname, ID: privid, }, }) } err := rows.Err() if err != nil { err = fmt.Errorf("error in data:scanDefaultRulesRows on rows.Err: %v", err) } return res, err } func (ds Dataset) ReadAllDefaultRules(ctx context.Context) ([]rulesengine.RuleSegment, error) { rows, err := ds.db.QueryContext(ctx, datasql.SelectAllDefaultRules) if err != nil { return nil, fmt.Errorf("error in data:ReadAllDefaultRules: %v", err) } defer rows.Close() return scanDefaultRulesRows(rows) } func assembleRules(ruleSegments []rulesengine.RuleSegment) []rulesengine.Rule { rules := []rulesengine.Rule{} for _, ruleSegment := range ruleSegments { rules = appendRule(rules, ruleSegment) } return rules } func appendRule(rules []rulesengine.Rule, ruleSegment rulesengine.RuleSegment) []rulesengine.Rule { if len(rules) == 0 { return []rulesengine.Rule{{Command: ruleSegment.Command, Privileges: []rulesengine.Privilege{ruleSegment.Privilege}}} } for idx, rule := range rules { if rule.Command.ID == ruleSegment.Command.ID { // dataset handles duplicate instances for us so we dont need to check here rules[idx].Privileges = append(rule.Privileges, ruleSegment.Privilege) return rules } } rules = append(rules, rulesengine.Rule{Command: ruleSegment.Command, Privileges: []rulesengine.Privilege{ruleSegment.Privilege}}) return rules } // ReadDefaultRulesForCommand returns list of BannerRuleSegments. // ReadDefaultRulesForCommand Will not return the command if it no default rules are associated. func (ds Dataset) ReadDefaultRulesForCommand(ctx context.Context, commandName string) ([]rulesengine.RuleSegment, error) { rows, err := ds.db.QueryContext(ctx, datasql.SelectDefaultRulesByCommandName, commandName) if err != nil { return []rulesengine.RuleSegment{}, err } defer rows.Close() return scanDefaultRulesRows(rows) } func (ds Dataset) DeleteDefaultRule(ctx context.Context, commandName, privilegeName string) (res rulesengine.DeleteResult, err error) { res, err = ds.deleteValue(ctx, datasql.DeleteDefaultRule, commandName, privilegeName) if err != nil { return res, fmt.Errorf("error deleting privilege from default rule: %w", err) } if res.RowsAffected != 0 || len(res.Errors) != 0 { // Completed successfully or with a conflict return res, nil } // No changes were made and no errors occurred, this may be due to an unknown value, // or a rule that does not exist, find the cause tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) if err != nil { return res, fmt.Errorf("error creating transaction: %w", err) } defer func() { switch { case err != nil: if txErr := tx.Rollback(); txErr != nil { err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr) } default: if txErr := tx.Commit(); txErr != nil { err = fmt.Errorf("error committing transaction: %w", txErr) } } }() segment := rulesengine.RuleSegment{ Command: rulesengine.Command{Name: commandName}, Privilege: rulesengine.Privilege{Name: privilegeName}, } segment, err = populateRuleSegmentIDs(ctx, tx, segment) if err != nil { return res, fmt.Errorf("error discovering ID's: %w", err) } res.Errors = checkUnknownDeleteRuleErrs(segment) return res, err } func checkUnknownDeleteRuleErrs(segment rulesengine.RuleSegment) []rulesengine.Error { var errs []rulesengine.Error if segment.Command.ID == "" { errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownCommand, Command: segment.Command.Name}) } if segment.Privilege.ID == "" { errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownPrivilege, Privilege: segment.Privilege.Name}) } if len(errs) == 0 { errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownRule, Command: segment.Command.Name, Privilege: segment.Privilege.Name}) } return errs }