package database

import (
	"context"
	"database/sql"
	"fmt"

	"edge-infra.dev/pkg/lib/fog"
	rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
	datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
)

// ReadRulesForBanner runs datasql.SelectAllRulesForBanner for the banner named
// bannerName and decodes the result rows with scanDefaultRulesRows.
func (ds Dataset) ReadRulesForBanner(ctx context.Context, bannerName string) ([]rulesengine.RuleSegment, error) {
	rows, err := ds.db.QueryContext(ctx, datasql.SelectAllRulesForBanner, bannerName)
	if err != nil {
		return nil, fmt.Errorf("error reading banner rules: %w", err)
	}
	defer rows.Close()

	return scanDefaultRulesRows(rows)
}

// ReadBannerRulesForCommandAndBanner returns the banner rule segments matching
// both bannerName and commandName (datasql.SelectBannerRulesForCommandAndBanner).
func (ds Dataset) ReadBannerRulesForCommandAndBanner(ctx context.Context, bannerName string, commandName string) ([]rulesengine.RuleSegment, error) {
	rows, err := ds.db.QueryContext(ctx, datasql.SelectBannerRulesForCommandAndBanner, bannerName, commandName)
	if err != nil {
		return nil, fmt.Errorf("error reading banner rules: %w", err)
	}
	defer rows.Close()

	return scanBannerRulesRows(rows)
}

// ReadBannerRulesForCommand returns the banner rule segments for commandName
// across all banners (datasql.SelectAllBannerRulesForCommand).
func (ds Dataset) ReadBannerRulesForCommand(ctx context.Context, commandName string) ([]rulesengine.RuleSegment, error) {
	rows, err := ds.db.QueryContext(ctx, datasql.SelectAllBannerRulesForCommand, commandName)
	if err != nil {
		return nil, fmt.Errorf("error reading banner rules: %w", err)
	}
	defer rows.Close()

	return scanBannerRulesRows(rows)
}

// DeletePrivilegeFromBannerRule removes privilegeName from the banner rule
// identified by bannerName and commandName.
//
// If the delete affected rows, or deleteValue already reported errors, the
// result is returned as-is. Otherwise the names are resolved to IDs inside a
// read-only transaction to explain why nothing changed. Each name without an
// ID produces an Unknown{Command,Privilege,Banner} error. If every name
// resolves, an UnknownRule error is reported, because the rule itself does
// not exist. These explanatory errors are returned in the result with a nil
// error.
func (ds Dataset) DeletePrivilegeFromBannerRule(ctx context.Context, bannerName, commandName, privilegeName string) (rulesengine.DeleteResult, error) {
	res, err := ds.deleteValue(ctx, datasql.DeletePrivilegeFromBannerRule, bannerName, commandName, privilegeName)
	if err != nil {
		return res, fmt.Errorf("error deleting privilege from banner: %w", err)
	}
	if res.RowsAffected != 0 || len(res.Errors) != 0 {
		// Completed successfully or with a conflict.
		return res, nil
	}

	// No changes were made and no errors occurred. This may be due to an
	// unknown value or to a rule that does not exist, so find the cause.
	tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
	if err != nil {
		return res, fmt.Errorf("error creating transaction: %w", err)
	}

	segment := rulesengine.RuleSegment{
		Banner:    rulesengine.Banner{BannerName: bannerName},
		Command:   rulesengine.Command{Name: commandName},
		Privilege: rulesengine.Privilege{Name: privilegeName},
	}

	// Find all IDs for the segment.
	segment, err = ds.populateBannerSegmentIDs(ctx, tx, segment)
	if err != nil {
		if txErr := tx.Rollback(); txErr != nil {
			err = fmt.Errorf("error rolling back transaction (%w): %w", txErr, err)
		}
		return res, fmt.Errorf("error discovering ID's: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return res, fmt.Errorf("error committing read transaction: %w", err)
	}

	// Create UnknownType errors for missing IDs.
	if segment.Command.ID == "" {
		res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownCommand, Command: segment.Command.Name})
	}
	if segment.Privilege.ID == "" {
		res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownPrivilege, Privilege: segment.Privilege.Name})
	}
	if segment.Banner.BannerID == "" {
		res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownBanner, Banner: segment.Banner.BannerName})
	}

	// If all IDs are known but the delete made no changes, the segment itself
	// must not exist.
	if len(res.Errors) == 0 {
		res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownRule, Banner: bannerName, Command: commandName, Privilege: privilegeName})
	}
	return res, nil
}

// ReadRulesForAllBanners returns every banner rule segment in the database
// (datasql.SelectAllRulesFarAllBanners).
func (ds Dataset) ReadRulesForAllBanners(ctx context.Context) ([]rulesengine.RuleSegment, error) {
	rows, err := ds.db.QueryContext(ctx, datasql.SelectAllRulesFarAllBanners)
	if err != nil {
		return nil, fmt.Errorf("error querying for all banner rules: %w", err)
	}
	defer rows.Close()

	return scanBannerRulesRows(rows)
}

// scanBannerRulesRows decodes rows into rule segments. Each row is expected
// to contain, in order: banner name, command name, privilege name, banner ID,
// command ID, privilege ID. The caller remains responsible for closing rows.
// A nil slice is returned when there are no rows.
func scanBannerRulesRows(rows *sql.Rows) ([]rulesengine.RuleSegment, error) {
	var res []rulesengine.RuleSegment
	for rows.Next() {
		var b rulesengine.Banner
		var c rulesengine.Command
		var p rulesengine.Privilege
		if err := rows.Scan(&b.BannerName, &c.Name, &p.Name, &b.BannerID, &c.ID, &p.ID); err != nil {
			return nil, fmt.Errorf("error scanning banner row: %w", err)
		}
		res = append(res, rulesengine.RuleSegment{
			Banner:    b,
			Command:   c,
			Privilege: p,
		})
	}
	// rows.Next returning false may hide an iteration error; surface it.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error while reading all banner rules: %w", err)
	}
	return res, nil
}

// AddBannerRules inserts every rule in a single transaction.
//
// When an insert changes nothing, the rule's names are resolved to IDs.
// Rules with unresolvable names append to feedback.Errors. Rules that resolve
// fully are assumed to already exist and are only logged.
//
// The deferred handler rolls the transaction back when err is non-nil or any
// feedback errors were collected, and commits otherwise. Rollback and commit
// failures are reported through err.
func (ds Dataset) AddBannerRules(ctx context.Context, rules []rulesengine.RuleSegment) (feedback rulesengine.AddRuleResult, err error) {
	tx, err := ds.db.BeginTx(ctx, nil)
	if err != nil {
		return feedback, fmt.Errorf("error beginning transaction: %w", err)
	}

	log := fog.FromContext(ctx).WithName("dataset")

	// Named results let this deferred func observe and amend the final err.
	defer func() {
		switch {
		case err != nil:
			if txErr := tx.Rollback(); txErr != nil {
				err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr)
			}
		case len(feedback.Errors) != 0:
			if txErr := tx.Rollback(); txErr != nil {
				err = fmt.Errorf("error rolling back transaction due to errors: %w", txErr)
			}
		default:
			if txErr := tx.Commit(); txErr != nil {
				err = fmt.Errorf("error committing transaction: %w", txErr)
			}
		}
	}()

	for _, rule := range rules {
		var res sql.Result
		res, err = tx.ExecContext(ctx, datasql.InsertBannerRule, rule.Banner.BannerName, rule.Command.Name, rule.Privilege.Name)
		if err != nil {
			return feedback, fmt.Errorf("error executing query: %w", err)
		}

		var rowsAffected int64
		rowsAffected, err = res.RowsAffected()
		if err != nil {
			return feedback, fmt.Errorf("error getting query result: %w", err)
		}

		// No changes were made to the DB. Determine why by finding the ID of
		// each value to enter. A missing ID means the value is invalid and an
		// error should be returned. If no ID is missing, the rule is already
		// present in the DB.
		if rowsAffected == 0 {
			rule, err = ds.populateBannerSegmentIDs(ctx, tx, rule)
			if err != nil {
				return feedback, err
			}

			ruleErrs, missingID := createErrors(rule)
			if missingID {
				// Missing ID found: record it, move to the next segment and
				// let the deferred handler roll back at the end.
				feedback.Errors = append(feedback.Errors, ruleErrs...)
				continue
			}

			log.Info(
				"Rule already present in database",
				"bannerName", rule.Banner.BannerName,
				"commandName", rule.Command.Name,
				"privilegeName", rule.Privilege.Name,
			)
		}
	}
	return feedback, nil
}

// populateBannerSegmentIDs fills the ID fields of segment from its Name fields
// by running datasql.GetIDsForBannerSegment inside tx. Each result row is a
// (kind, id) pair where kind is "banner", "command" or "privilege". Names the
// query returns no row for leave their ID unchanged (empty for a freshly built
// segment). Any other kind is an error.
func (ds Dataset) populateBannerSegmentIDs(ctx context.Context, tx *sql.Tx, segment rulesengine.RuleSegment) (rulesengine.RuleSegment, error) {
	rows, err := tx.QueryContext(ctx, datasql.GetIDsForBannerSegment, segment.Banner.BannerName, segment.Command.Name, segment.Privilege.Name)
	if err != nil {
		return segment, fmt.Errorf("error finding ID's (banner %s, command %s, privilege %s): %w", segment.Banner.BannerName, segment.Command.Name, segment.Privilege.Name, err)
	}
	defer rows.Close()

	for rows.Next() {
		var kind, value string
		if err := rows.Scan(&kind, &value); err != nil {
			return segment, fmt.Errorf("error scanning for ID's: %w", err)
		}
		switch kind {
		case "banner":
			segment.Banner.BannerID = value
		case "command":
			segment.Command.ID = value
		case "privilege":
			segment.Privilege.ID = value
		default:
			// Report the unrecognised kind itself, not the ID it carried.
			return segment, fmt.Errorf("unknown kind for segment: %s", kind)
		}
	}
	if err := rows.Err(); err != nil {
		return segment, fmt.Errorf("error finding ID's for segment: %w", err)
	}
	return segment, nil
}

// createErrors returns one Unknown{Banner,Command,Privilege} error for each ID
// missing from rule, in banner, command, privilege order. The bool is true when
// at least one ID is missing; otherwise it returns nil, false.
func createErrors(rule rulesengine.RuleSegment) ([]rulesengine.Error, bool) {
	missingID := false
	var ruleErrs []rulesengine.Error
	if rule.Banner.BannerID == "" {
		missingID = true
		ruleErrs = append(ruleErrs, rulesengine.Error{
			Banner: rule.Banner.BannerName,
			Type:   rulesengine.UnknownBanner,
		})
	}
	if rule.Command.ID == "" {
		missingID = true
		ruleErrs = append(ruleErrs, rulesengine.Error{
			Command: rule.Command.Name,
			Type:    rulesengine.UnknownCommand,
		})
	}
	if rule.Privilege.ID == "" {
		missingID = true
		ruleErrs = append(ruleErrs, rulesengine.Error{
			Privilege: rule.Privilege.Name,
			Type:      rulesengine.UnknownPrivilege,
		})
	}
	return ruleErrs, missingID
}