package services

import (
	"context"
	"database/sql"
	"fmt"

	"github.com/lib/pq"

	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
	"edge-infra.dev/pkg/lib/edgeutils"
)

// LogReplayService reads and writes log replay records and their associated
// log replay jobs.
//
//go:generate mockgen -destination=../mocks/mock_log_replay_service.go -package=mocks edge-infra.dev/pkg/edge/api/services LogReplayService
type LogReplayService interface {
	GetLogReplay(ctx context.Context, logReplayID string) (*model.LogReplay, error)
	GetLogReplays(ctx context.Context, clusterEdgeID string, unexecutedLogReplays *bool) ([]*model.LogReplay, error)
	GetLogReplayJobs(ctx context.Context, clusterEdgeID string) ([]*model.LogReplayJob, error)
	CreateLogReplay(ctx context.Context, clusterEdgeID string, inputPayload model.CreateLogReplayPayload) (bool, error)
	UpdateLogReplay(ctx context.Context, logReplayID string, clusterEdgeID string, inputPayload model.UpdateLogReplayPayload) (bool, error)
	DeleteLogReplay(ctx context.Context, logReplayID string) (bool, error)
}

// logReplayService is the *sql.DB-backed implementation of LogReplayService.
type logReplayService struct {
	SQLDB *sql.DB
}

// GetLogReplay loads a single log replay by ID. StartTime, EndTime and
// UpdatedAt are normalised with edgeutils.ConvertToRFC3339 before returning.
// A missing row surfaces as a wrapped sql.ErrNoRows.
func (l *logReplayService) GetLogReplay(ctx context.Context, logReplayID string) (*model.LogReplay, error) {
	row := l.SQLDB.QueryRowContext(ctx, sqlquery.GetLogReplay, logReplayID)
	lr, err := l.scanLogReplay(row)
	if err != nil {
		return nil, fmt.Errorf("error getting log replay: SQL error %w", err)
	}
	if err := l.formatLogReplayTimes(lr); err != nil {
		return nil, fmt.Errorf("error getting log replay: %w", err)
	}
	return lr, nil
}

// GetLogReplayJobs lists the log replay jobs returned by
// sqlquery.GetLogReplayJobs for clusterEdgeID, with UpdatedAt normalised to
// RFC3339. An empty result yields a nil slice.
func (l *logReplayService) GetLogReplayJobs(ctx context.Context, clusterEdgeID string) ([]*model.LogReplayJob, error) {
	rows, err := l.SQLDB.QueryContext(ctx, sqlquery.GetLogReplayJobs, clusterEdgeID)
	if err != nil {
		return nil, fmt.Errorf("error getting log replay jobs: SQL error %w", err)
	}
	defer rows.Close()

	var logReplays []*model.LogReplayJob
	for rows.Next() {
		var lrj model.LogReplayJob
		// Scan is positional: this order must match the query's select list.
		if err := rows.Scan(&lrj.Jsonpath, &lrj.Value, &lrj.Missing, &lrj.Name, &lrj.Queued, &lrj.Executed, &lrj.Status, &lrj.UpdatedAt, &lrj.LogReplayID); err != nil {
			return nil, fmt.Errorf("error getting log replay jobs not executed: row scan error %w", err)
		}
		updatedFormatted, err := edgeutils.ConvertToRFC3339(lrj.UpdatedAt)
		if err != nil {
			return nil, fmt.Errorf("error getting log replay jobs: %w", err)
		}
		lrj.UpdatedAt = updatedFormatted
		logReplays = append(logReplays, &lrj)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error getting log replay jobs: row iteration error %w", err)
	}
	return logReplays, nil
}

// GetLogReplays lists log replays for clusterEdgeID. When unexecutedLogReplays
// is non-nil and true it uses sqlquery.GetLogReplaysNotExecuted; otherwise it
// uses sqlquery.GetLogReplays.
func (l *logReplayService) GetLogReplays(ctx context.Context, clusterEdgeID string, unexecutedLogReplays *bool) ([]*model.LogReplay, error) {
	if unexecutedLogReplays == nil || !*unexecutedLogReplays {
		return l.getLogReplays(ctx, clusterEdgeID, sqlquery.GetLogReplays)
	}
	return l.getLogReplays(ctx, clusterEdgeID, sqlquery.GetLogReplaysNotExecuted)
}

// DeleteLogReplay executes sqlquery.DeleteLogReplay for logReplayID and
// returns true when the statement succeeds. The number of affected rows is
// not checked.
func (l *logReplayService) DeleteLogReplay(ctx context.Context, logReplayID string) (bool, error) {
	_, err := l.SQLDB.ExecContext(ctx, sqlquery.DeleteLogReplay, logReplayID)
	if err != nil {
		return false, fmt.Errorf("error deleting log replay: %w", err)
	}
	return true, nil
}

// CreateLogReplay validates the payload's time window and inserts a new log
// replay for clusterEdgeID with queued=false, executed=false and status
// LogReplayStatusNotStarted.
func (l *logReplayService) CreateLogReplay(ctx context.Context, clusterEdgeID string, clrp model.CreateLogReplayPayload) (bool, error) {
	if err := l.validateInput(clrp); err != nil {
		return false, fmt.Errorf("error creating log replay: %w", err)
	}
	logLevel := clrp.LogLevel.String()
	_, err := l.SQLDB.ExecContext(ctx, sqlquery.CreateLogReplay, clusterEdgeID, pq.Array(&clrp.Namespaces), logLevel, clrp.StartTime, clrp.EndTime, false, false, model.LogReplayStatusNotStarted.String())
	if err != nil {
		return false, fmt.Errorf("error creating log replay: %w", err)
	}
	return true, nil
}

// UpdateLogReplay applies any non-nil Queued/Executed/Status values from
// inputPayload to the stored log replay. It only issues the UPDATE when at
// least one value actually differs, and reports whether it did.
//
// NOTE(review): the read and the write are separate statements, not a
// transaction, so concurrent updates can interleave.
func (l *logReplayService) UpdateLogReplay(ctx context.Context, logReplayID string, clusterEdgeID string, inputPayload model.UpdateLogReplayPayload) (bool, error) {
	lr, err := l.GetLogReplay(ctx, logReplayID)
	if err != nil {
		return false, fmt.Errorf("error updating log replay: %w", err)
	}

	q := lr.Queued
	e := lr.Executed
	status := lr.Status

	// Skip no-op writes so UpdatedAt keeps reflecting the last real change.
	// Without this, the reconcile loop (every ~3 min) would bump it on every pass.
	change := false
	if inputPayload.Queued != nil && q != *inputPayload.Queued {
		change = true
		q = *inputPayload.Queued
	}
	if inputPayload.Executed != nil && e != *inputPayload.Executed {
		change = true
		e = *inputPayload.Executed
	}
	if inputPayload.Status != nil && status != inputPayload.Status.String() {
		change = true
		status = inputPayload.Status.String()
	}

	if change {
		_, err = l.SQLDB.ExecContext(ctx, sqlquery.UpdateLogReplay, logReplayID, clusterEdgeID, q, e, status)
		if err != nil {
			return false, fmt.Errorf("error updating log replay: %w", err)
		}
	}
	return change, nil
}

// validateInput checks that StartTime and EndTime are valid timestamps per
// edgeutils.IsValidTimestamp and are correctly ordered per
// edgeutils.TimeSequenceCheck.
func (l *logReplayService) validateInput(inputPayload model.CreateLogReplayPayload) error {
	if err := edgeutils.IsValidTimestamp(inputPayload.StartTime); err != nil {
		return fmt.Errorf("error validating start time format: %w", err)
	}
	if err := edgeutils.IsValidTimestamp(inputPayload.EndTime); err != nil {
		return fmt.Errorf("error validating end time format: %w", err)
	}
	if err := edgeutils.TimeSequenceCheck(inputPayload.StartTime, inputPayload.EndTime); err != nil {
		return fmt.Errorf("error validating time orderings: %w", err)
	}
	return nil
}

// getLogReplays runs query (which must take clusterEdgeID as its only
// argument and select the log replay columns in scanLogReplay's order) and
// returns the normalised results. An empty result yields a nil slice.
func (l *logReplayService) getLogReplays(ctx context.Context, clusterEdgeID string, query string) ([]*model.LogReplay, error) {
	rows, err := l.SQLDB.QueryContext(ctx, query, clusterEdgeID)
	if err != nil {
		return nil, fmt.Errorf("error getting log replays: SQL error %w", err)
	}
	defer rows.Close()

	var logReplays []*model.LogReplay
	for rows.Next() {
		lr, err := l.scanLogReplay(rows)
		if err != nil {
			return nil, fmt.Errorf("error getting log replays not executed: row scan error %w", err)
		}
		if err := l.formatLogReplayTimes(lr); err != nil {
			return nil, fmt.Errorf("error getting log replay: %w", err)
		}
		logReplays = append(logReplays, lr)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error getting log replays: row iteration error %w", err)
	}
	return logReplays, nil
}

// scanLogReplay reads one log replay from s, which may be a *sql.Row or
// *sql.Rows. Scan is positional, so this destination order must match the
// select list of every query passed through it. The raw Scan error is
// returned unwrapped so callers can add their own context.
func (l *logReplayService) scanLogReplay(s interface{ Scan(dest ...any) error }) (*model.LogReplay, error) {
	lr := &model.LogReplay{Namespaces: []string{}}
	if err := s.Scan(&lr.LogReplayID, &lr.ClusterEdgeID, pq.Array(&lr.Namespaces), &lr.LogLevel, &lr.StartTime, &lr.EndTime, &lr.Queued, &lr.Executed, &lr.Status, &lr.UpdatedAt); err != nil {
		return nil, err
	}
	return lr, nil
}

// formatLogReplayTimes rewrites lr's StartTime, EndTime and UpdatedAt using
// edgeutils.ConvertToRFC3339. lr is only modified if all three conversions
// succeed; the first conversion error is returned unwrapped.
func (l *logReplayService) formatLogReplayTimes(lr *model.LogReplay) error {
	startFormatted, err := edgeutils.ConvertToRFC3339(lr.StartTime)
	if err != nil {
		return err
	}
	endFormatted, err := edgeutils.ConvertToRFC3339(lr.EndTime)
	if err != nil {
		return err
	}
	updatedFormatted, err := edgeutils.ConvertToRFC3339(lr.UpdatedAt)
	if err != nil {
		return err
	}
	lr.StartTime = startFormatted
	lr.EndTime = endFormatted
	lr.UpdatedAt = updatedFormatted
	return nil
}

// NewLogReplayService returns a log replay service that issues its queries
// against sqlDB.
//
// nolint
func NewLogReplayService(sqlDB *sql.DB) *logReplayService {
	return &logReplayService{
		SQLDB: sqlDB,
	}
}