...

Source file src/github.com/jackc/pgx/v5/internal/pgmock/pgmock.go

Documentation: github.com/jackc/pgx/v5/internal/pgmock

     1  // Package pgmock provides the ability to mock a PostgreSQL server.
     2  package pgmock
     3  
import (
	"errors"
	"fmt"
	"io"
	"reflect"

	"github.com/jackc/pgx/v5/pgproto3"
)
    11  
    12  type Step interface {
    13  	Step(*pgproto3.Backend) error
    14  }
    15  
    16  type Script struct {
    17  	Steps []Step
    18  }
    19  
    20  func (s *Script) Run(backend *pgproto3.Backend) error {
    21  	for _, step := range s.Steps {
    22  		err := step.Step(backend)
    23  		if err != nil {
    24  			return err
    25  		}
    26  	}
    27  
    28  	return nil
    29  }
    30  
    31  func (s *Script) Step(backend *pgproto3.Backend) error {
    32  	return s.Run(backend)
    33  }
    34  
    35  type expectMessageStep struct {
    36  	want pgproto3.FrontendMessage
    37  	any  bool
    38  }
    39  
    40  func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
    41  	msg, err := backend.Receive()
    42  	if err != nil {
    43  		return err
    44  	}
    45  
    46  	if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
    47  		return nil
    48  	}
    49  
    50  	if !reflect.DeepEqual(msg, e.want) {
    51  		return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
    52  	}
    53  
    54  	return nil
    55  }
    56  
    57  type expectStartupMessageStep struct {
    58  	want *pgproto3.StartupMessage
    59  	any  bool
    60  }
    61  
    62  func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
    63  	msg, err := backend.ReceiveStartupMessage()
    64  	if err != nil {
    65  		return err
    66  	}
    67  
    68  	if e.any {
    69  		return nil
    70  	}
    71  
    72  	if !reflect.DeepEqual(msg, e.want) {
    73  		return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
    74  	}
    75  
    76  	return nil
    77  }
    78  
    79  func ExpectMessage(want pgproto3.FrontendMessage) Step {
    80  	return expectMessage(want, false)
    81  }
    82  
    83  func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
    84  	return expectMessage(want, true)
    85  }
    86  
    87  func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
    88  	if want, ok := want.(*pgproto3.StartupMessage); ok {
    89  		return &expectStartupMessageStep{want: want, any: any}
    90  	}
    91  
    92  	return &expectMessageStep{want: want, any: any}
    93  }
    94  
    95  type sendMessageStep struct {
    96  	msg pgproto3.BackendMessage
    97  }
    98  
    99  func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
   100  	backend.Send(e.msg)
   101  	return backend.Flush()
   102  }
   103  
   104  func SendMessage(msg pgproto3.BackendMessage) Step {
   105  	return &sendMessageStep{msg: msg}
   106  }
   107  
   108  type waitForCloseMessageStep struct{}
   109  
   110  func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
   111  	for {
   112  		msg, err := backend.Receive()
   113  		if err == io.EOF {
   114  			return nil
   115  		} else if err != nil {
   116  			return err
   117  		}
   118  
   119  		if _, ok := msg.(*pgproto3.Terminate); ok {
   120  			return nil
   121  		}
   122  	}
   123  }
   124  
   125  func WaitForClose() Step {
   126  	return &waitForCloseMessageStep{}
   127  }
   128  
   129  func AcceptUnauthenticatedConnRequestSteps() []Step {
   130  	return []Step{
   131  		ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
   132  		SendMessage(&pgproto3.AuthenticationOk{}),
   133  		SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
   134  		SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
   135  	}
   136  }
   137  

View as plain text