...
1 package main
2
3 import (
4 "fmt"
5 "net"
6
7 "github.com/jackc/pgproto3/v2"
8 )
9
10 type PgFortuneBackend struct {
11 backend *pgproto3.Backend
12 conn net.Conn
13 responder func() ([]byte, error)
14 }
15
16 func NewPgFortuneBackend(conn net.Conn, responder func() ([]byte, error)) *PgFortuneBackend {
17 backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
18
19 connHandler := &PgFortuneBackend{
20 backend: backend,
21 conn: conn,
22 responder: responder,
23 }
24
25 return connHandler
26 }
27
28 func (p *PgFortuneBackend) Run() error {
29 defer p.Close()
30
31 err := p.handleStartup()
32 if err != nil {
33 return err
34 }
35
36 for {
37 msg, err := p.backend.Receive()
38 if err != nil {
39 return fmt.Errorf("error receiving message: %w", err)
40 }
41
42 switch msg.(type) {
43 case *pgproto3.Query:
44 response, err := p.responder()
45 if err != nil {
46 return fmt.Errorf("error generating query response: %w", err)
47 }
48
49 buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
50 {
51 Name: []byte("fortune"),
52 TableOID: 0,
53 TableAttributeNumber: 0,
54 DataTypeOID: 25,
55 DataTypeSize: -1,
56 TypeModifier: -1,
57 Format: 0,
58 },
59 }}).Encode(nil))
60 buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
61 buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
62 buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
63 _, err = p.conn.Write(buf)
64 if err != nil {
65 return fmt.Errorf("error writing query response: %w", err)
66 }
67 case *pgproto3.Terminate:
68 return nil
69 default:
70 return fmt.Errorf("received message other than Query from client: %#v", msg)
71 }
72 }
73 }
74
75 func (p *PgFortuneBackend) handleStartup() error {
76 startupMessage, err := p.backend.ReceiveStartupMessage()
77 if err != nil {
78 return fmt.Errorf("error receiving startup message: %w", err)
79 }
80
81 switch startupMessage.(type) {
82 case *pgproto3.StartupMessage:
83 buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
84 buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
85 _, err = p.conn.Write(buf)
86 if err != nil {
87 return fmt.Errorf("error sending ready for query: %w", err)
88 }
89 case *pgproto3.SSLRequest:
90 _, err = p.conn.Write([]byte("N"))
91 if err != nil {
92 return fmt.Errorf("error sending deny SSL request: %w", err)
93 }
94 return p.handleStartup()
95 default:
96 return fmt.Errorf("unknown startup message: %#v", startupMessage)
97 }
98
99 return nil
100 }
101
102 func (p *PgFortuneBackend) Close() error {
103 return p.conn.Close()
104 }
105
// mustEncode unwraps a (buffer, error) pair such as the one returned by an
// Encode call. It returns buf when err is nil and panics with err otherwise.
// That makes an encoding failure a programming error rather than something
// callers handle.
func mustEncode(buf []byte, err error) []byte {
	if err == nil {
		return buf
	}
	panic(err)
}
112
View as plain text