...
1
2 package pgmock
3
import (
	"errors"
	"fmt"
	"io"
	"reflect"

	"github.com/jackc/pgx/v5/pgproto3"
)
11
12 type Step interface {
13 Step(*pgproto3.Backend) error
14 }
15
16 type Script struct {
17 Steps []Step
18 }
19
20 func (s *Script) Run(backend *pgproto3.Backend) error {
21 for _, step := range s.Steps {
22 err := step.Step(backend)
23 if err != nil {
24 return err
25 }
26 }
27
28 return nil
29 }
30
31 func (s *Script) Step(backend *pgproto3.Backend) error {
32 return s.Run(backend)
33 }
34
35 type expectMessageStep struct {
36 want pgproto3.FrontendMessage
37 any bool
38 }
39
40 func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
41 msg, err := backend.Receive()
42 if err != nil {
43 return err
44 }
45
46 if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
47 return nil
48 }
49
50 if !reflect.DeepEqual(msg, e.want) {
51 return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
52 }
53
54 return nil
55 }
56
57 type expectStartupMessageStep struct {
58 want *pgproto3.StartupMessage
59 any bool
60 }
61
62 func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
63 msg, err := backend.ReceiveStartupMessage()
64 if err != nil {
65 return err
66 }
67
68 if e.any {
69 return nil
70 }
71
72 if !reflect.DeepEqual(msg, e.want) {
73 return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
74 }
75
76 return nil
77 }
78
79 func ExpectMessage(want pgproto3.FrontendMessage) Step {
80 return expectMessage(want, false)
81 }
82
83 func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
84 return expectMessage(want, true)
85 }
86
87 func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
88 if want, ok := want.(*pgproto3.StartupMessage); ok {
89 return &expectStartupMessageStep{want: want, any: any}
90 }
91
92 return &expectMessageStep{want: want, any: any}
93 }
94
95 type sendMessageStep struct {
96 msg pgproto3.BackendMessage
97 }
98
99 func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
100 backend.Send(e.msg)
101 return backend.Flush()
102 }
103
104 func SendMessage(msg pgproto3.BackendMessage) Step {
105 return &sendMessageStep{msg: msg}
106 }
107
// waitForCloseMessageStep discards frontend messages until the client ends the
// session. It carries no state.
type waitForCloseMessageStep struct{}
109
110 func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
111 for {
112 msg, err := backend.Receive()
113 if err == io.EOF {
114 return nil
115 } else if err != nil {
116 return err
117 }
118
119 if _, ok := msg.(*pgproto3.Terminate); ok {
120 return nil
121 }
122 }
123 }
124
125 func WaitForClose() Step {
126 return &waitForCloseMessageStep{}
127 }
128
129 func AcceptUnauthenticatedConnRequestSteps() []Step {
130 return []Step{
131 ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
132 SendMessage(&pgproto3.AuthenticationOk{}),
133 SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
134 SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
135 }
136 }
137
View as plain text