...
1 package pgproto3
2
import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"

	"github.com/jackc/pgio"
)
12
// ProtocolVersionNumber is the only protocol version StartupMessage.Decode
// accepts. Its value is 3<<16 (196608): major version 3 in the high 16 bits
// and minor version 0 in the low 16 bits.
const ProtocolVersionNumber = 3 << 16
14
// StartupMessage is a frontend message (see its Frontend method) that
// carries a protocol version and a set of name/value parameters.
type StartupMessage struct {
	// ProtocolVersion is the version number read from or written to the
	// first four bytes of the message body, big-endian.
	ProtocolVersion uint32

	// Parameters holds the name/value pairs of the message. On the wire
	// each name and each value is a NUL-terminated string.
	Parameters map[string]string
}
19
20
21 func (*StartupMessage) Frontend() {}
22
23
24
25 func (dst *StartupMessage) Decode(src []byte) error {
26 if len(src) < 4 {
27 return errors.New("startup message too short")
28 }
29
30 dst.ProtocolVersion = binary.BigEndian.Uint32(src)
31 rp := 4
32
33 if dst.ProtocolVersion != ProtocolVersionNumber {
34 return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
35 }
36
37 dst.Parameters = make(map[string]string)
38 for {
39 idx := bytes.IndexByte(src[rp:], 0)
40 if idx < 0 {
41 return &invalidMessageFormatErr{messageType: "StartupMesage"}
42 }
43 key := string(src[rp : rp+idx])
44 rp += idx + 1
45
46 idx = bytes.IndexByte(src[rp:], 0)
47 if idx < 0 {
48 return &invalidMessageFormatErr{messageType: "StartupMesage"}
49 }
50 value := string(src[rp : rp+idx])
51 rp += idx + 1
52
53 dst.Parameters[key] = value
54
55 if len(src[rp:]) == 1 {
56 if src[rp] != 0 {
57 return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
58 }
59 break
60 }
61 }
62
63 return nil
64 }
65
66
67 func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
68 sp := len(dst)
69 dst = pgio.AppendInt32(dst, -1)
70
71 dst = pgio.AppendUint32(dst, src.ProtocolVersion)
72 for k, v := range src.Parameters {
73 dst = append(dst, k...)
74 dst = append(dst, 0)
75 dst = append(dst, v...)
76 dst = append(dst, 0)
77 }
78 dst = append(dst, 0)
79
80 return finishMessage(dst, sp)
81 }
82
83
84 func (src StartupMessage) MarshalJSON() ([]byte, error) {
85 return json.Marshal(struct {
86 Type string
87 ProtocolVersion uint32
88 Parameters map[string]string
89 }{
90 Type: "StartupMessage",
91 ProtocolVersion: src.ProtocolVersion,
92 Parameters: src.Parameters,
93 })
94 }
95
View as plain text