...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/banner_rules.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database

     1  package database
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  
     8  	"edge-infra.dev/pkg/lib/fog"
     9  	rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
    10  	datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
    11  )
    12  
    13  func (ds Dataset) ReadRulesForBanner(ctx context.Context, bannerName string) ([]rulesengine.RuleSegment, error) {
    14  	rows, err := ds.db.QueryContext(ctx, datasql.SelectAllRulesForBanner, bannerName)
    15  	if err != nil {
    16  		return nil, fmt.Errorf("error reading banner rules: %w", err)
    17  	}
    18  	defer rows.Close()
    19  
    20  	return scanDefaultRulesRows(rows)
    21  }
    22  
    23  func (ds Dataset) ReadBannerRulesForCommandAndBanner(ctx context.Context, bannerName string, commandName string) ([]rulesengine.RuleSegment, error) {
    24  	rows, err := ds.db.QueryContext(ctx, datasql.SelectBannerRulesForCommandAndBanner, bannerName, commandName)
    25  	if err != nil {
    26  		return nil, fmt.Errorf("error reading banner rules: %w", err)
    27  	}
    28  	defer rows.Close()
    29  
    30  	return scanBannerRulesRows(rows)
    31  }
    32  
    33  func (ds Dataset) ReadBannerRulesForCommand(ctx context.Context, commandName string) ([]rulesengine.RuleSegment, error) {
    34  	rows, err := ds.db.QueryContext(ctx, datasql.SelectAllBannerRulesForCommand, commandName)
    35  	if err != nil {
    36  		return nil, fmt.Errorf("error reading banner rules: %w", err)
    37  	}
    38  	defer rows.Close()
    39  
    40  	return scanBannerRulesRows(rows)
    41  }
    42  
    43  func (ds Dataset) DeletePrivilegeFromBannerRule(ctx context.Context, bannerName, commandName, privilegeName string) (rulesengine.DeleteResult, error) {
    44  	res, err := ds.deleteValue(ctx, datasql.DeletePrivilegeFromBannerRule, bannerName, commandName, privilegeName)
    45  	if err != nil {
    46  		return res, fmt.Errorf("error deleting privilege from banner: %w", err)
    47  	}
    48  
    49  	if res.RowsAffected != 0 || len(res.Errors) != 0 {
    50  		// Completed successfully or with a conflict
    51  		return res, nil
    52  	}
    53  
    54  	// No changes were made and no errors occurred, this may be due to an unknown value,
    55  	// or a rule that does not exist, find the cause
    56  
    57  	tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
    58  	if err != nil {
    59  		return res, fmt.Errorf("error creating transaction: %w", err)
    60  	}
    61  
    62  	segment := rulesengine.RuleSegment{
    63  		Banner:    rulesengine.Banner{BannerName: bannerName},
    64  		Command:   rulesengine.Command{Name: commandName},
    65  		Privilege: rulesengine.Privilege{Name: privilegeName},
    66  	}
    67  
    68  	// Find All ID's for the segment
    69  	segment, err = ds.populateBannerSegmentIDs(ctx, tx, segment)
    70  	if err != nil {
    71  		if txErr := tx.Rollback(); txErr != nil {
    72  			err = fmt.Errorf("error rolling back transaction (%w): %w", txErr, err)
    73  		}
    74  		return res, fmt.Errorf("error discovering ID's: %w", err)
    75  	}
    76  
    77  	if err := tx.Commit(); err != nil {
    78  		return res, fmt.Errorf("error committing read transaction: %w", err)
    79  	}
    80  
    81  	// Create UnknownType errors for missing ID's
    82  	if segment.Command.ID == "" {
    83  		res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownCommand, Command: segment.Command.Name})
    84  	}
    85  	if segment.Privilege.ID == "" {
    86  		res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownPrivilege, Privilege: segment.Privilege.Name})
    87  	}
    88  	if segment.Banner.BannerID == "" {
    89  		res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownBanner, Banner: segment.Banner.BannerName})
    90  	}
    91  
    92  	// If all IDs are known but the delete made no changes this must be due to a non-existent segment
    93  	if len(res.Errors) == 0 {
    94  		res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownRule, Banner: bannerName, Command: commandName, Privilege: privilegeName})
    95  	}
    96  
    97  	return res, nil
    98  }
    99  
   100  func (ds Dataset) ReadRulesForAllBanners(ctx context.Context) ([]rulesengine.RuleSegment, error) {
   101  	rows, err := ds.db.QueryContext(ctx, datasql.SelectAllRulesFarAllBanners)
   102  	if err != nil {
   103  		return nil, fmt.Errorf("error querying for all banner rules: %w", err)
   104  	}
   105  	defer rows.Close()
   106  
   107  	return scanBannerRulesRows(rows)
   108  }
   109  
   110  func scanBannerRulesRows(rows *sql.Rows) ([]rulesengine.RuleSegment, error) {
   111  	var res []rulesengine.RuleSegment
   112  	for rows.Next() {
   113  		var b rulesengine.Banner
   114  		var c rulesengine.Command
   115  		var p rulesengine.Privilege
   116  
   117  		err := rows.Scan(&b.BannerName, &c.Name, &p.Name, &b.BannerID, &c.ID, &p.ID)
   118  		if err != nil {
   119  			return nil, fmt.Errorf("error scanning banner row: %w", err)
   120  		}
   121  
   122  		res = append(res, rulesengine.RuleSegment{
   123  			Banner:    b,
   124  			Command:   c,
   125  			Privilege: p,
   126  		})
   127  	}
   128  
   129  	if err := rows.Err(); err != nil {
   130  		return nil, fmt.Errorf("error while reading all banner rules: %w", err)
   131  	}
   132  
   133  	return res, nil
   134  }
   135  
   136  func (ds Dataset) AddBannerRules(ctx context.Context, rules []rulesengine.RuleSegment) (feedback rulesengine.AddRuleResult, err error) {
   137  	tx, err := ds.db.BeginTx(ctx, nil)
   138  	if err != nil {
   139  		return feedback, fmt.Errorf("error beginning transaction: %w", err)
   140  	}
   141  
   142  	log := fog.FromContext(ctx).WithName("dataset")
   143  
   144  	defer func() {
   145  		switch {
   146  		case err != nil:
   147  			if txErr := tx.Rollback(); txErr != nil {
   148  				err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr)
   149  			}
   150  		case len(feedback.Errors) != 0:
   151  			if txErr := tx.Rollback(); txErr != nil {
   152  				err = fmt.Errorf("error rolling back transaction due to errors: %w", txErr)
   153  			}
   154  		default:
   155  			if txErr := tx.Commit(); txErr != nil {
   156  				err = fmt.Errorf("error committing transaction: %w", txErr)
   157  			}
   158  		}
   159  	}()
   160  
   161  	for _, rule := range rules {
   162  		var res sql.Result
   163  		res, err = tx.ExecContext(ctx, datasql.InsertBannerRule, rule.Banner.BannerName, rule.Command.Name, rule.Privilege.Name)
   164  		if err != nil {
   165  			return feedback, fmt.Errorf("error executing query: %w", err)
   166  		}
   167  
   168  		var rowsAffected int64
   169  		rowsAffected, err = res.RowsAffected()
   170  		if err != nil {
   171  			return feedback, fmt.Errorf("error getting query result: %w", err)
   172  		}
   173  
   174  		// No changes were made to the DB, determine why by finding the ID of each
   175  		// value to enter. If an ID is missing this means the value is invalid and
   176  		// and error should be returned. If no ID is missing this means the rule is
   177  		// already present in the DB
   178  		if rowsAffected == 0 {
   179  			rule, err = ds.populateBannerSegmentIDs(ctx, tx, rule)
   180  			if err != nil {
   181  				return feedback, err
   182  			}
   183  
   184  			errors, missingID := createErrors(rule)
   185  
   186  			if missingID {
   187  				// Missing ID is found, move to next segment and rollback at end
   188  				feedback.Errors = append(feedback.Errors, errors...)
   189  				continue
   190  			}
   191  
   192  			log.Info(
   193  				"Rule already present in database",
   194  				"bannerName", rule.Banner.BannerName,
   195  				"commandName", rule.Command.Name,
   196  				"privilegeName", rule.Privilege.Name,
   197  			)
   198  		}
   199  	}
   200  
   201  	return feedback, nil
   202  }
   203  
   204  // Populates the ID fields of the passed in segment from the Name fields
   205  func (ds Dataset) populateBannerSegmentIDs(ctx context.Context, tx *sql.Tx, segment rulesengine.RuleSegment) (rulesengine.RuleSegment, error) {
   206  	rows, err := tx.QueryContext(ctx, datasql.GetIDsForBannerSegment, segment.Banner.BannerName, segment.Command.Name, segment.Privilege.Name)
   207  	if err != nil {
   208  		return segment, fmt.Errorf("error finding ID's (banner %s, command %s, privilege %s): %w", segment.Banner.BannerName, segment.Command.Name, segment.Privilege.Name, err)
   209  	}
   210  	defer rows.Close()
   211  
   212  	for rows.Next() {
   213  		var kind, value string
   214  		err := rows.Scan(&kind, &value)
   215  		if err != nil {
   216  			return segment, fmt.Errorf("error scanning for ID's: %w", err)
   217  		}
   218  
   219  		switch kind {
   220  		case "banner":
   221  			segment.Banner.BannerID = value
   222  		case "command":
   223  			segment.Command.ID = value
   224  		case "privilege":
   225  			segment.Privilege.ID = value
   226  		default:
   227  			return segment, fmt.Errorf("unknown kind for segment: %s", value)
   228  		}
   229  	}
   230  
   231  	if err := rows.Err(); err != nil {
   232  		return segment, fmt.Errorf("error finding ID's for sgement: %w", err)
   233  	}
   234  
   235  	return segment, nil
   236  }
   237  
   238  // Returns a slice of error values corresponding to the given rule segment
   239  // and true when a missing ID is found, else returns nil, false
   240  func createErrors(rule rulesengine.RuleSegment) ([]rulesengine.Error, bool) {
   241  	missingID := false
   242  	var errors []rulesengine.Error
   243  
   244  	if rule.Banner.BannerID == "" {
   245  		missingID = true
   246  		errors = append(errors, rulesengine.Error{
   247  			Banner: rule.Banner.BannerName,
   248  			Type:   rulesengine.UnknownBanner,
   249  		})
   250  	}
   251  
   252  	if rule.Command.ID == "" {
   253  		missingID = true
   254  		errors = append(errors, rulesengine.Error{
   255  			Command: rule.Command.Name,
   256  			Type:    rulesengine.UnknownCommand,
   257  		})
   258  	}
   259  
   260  	if rule.Privilege.ID == "" {
   261  		missingID = true
   262  		errors = append(errors, rulesengine.Error{
   263  			Privilege: rule.Privilege.Name,
   264  			Type:      rulesengine.UnknownPrivilege,
   265  		})
   266  	}
   267  
   268  	return errors, missingID
   269  }
   270  

View as plain text