...

Source file src/github.com/jackc/pgx/v5/pgproto3/bind.go

Documentation: github.com/jackc/pgx/v5/pgproto3

     1  package pgproto3
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"encoding/hex"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"math"
    11  
    12  	"github.com/jackc/pgx/v5/internal/pgio"
    13  )
    14  
    15  type Bind struct {
    16  	DestinationPortal    string
    17  	PreparedStatement    string
    18  	ParameterFormatCodes []int16
    19  	Parameters           [][]byte
    20  	ResultFormatCodes    []int16
    21  }
    22  
    23  // Frontend identifies this message as sendable by a PostgreSQL frontend.
    24  func (*Bind) Frontend() {}
    25  
    26  // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
    27  // type identifier and 4 byte message length.
    28  func (dst *Bind) Decode(src []byte) error {
    29  	*dst = Bind{}
    30  
    31  	idx := bytes.IndexByte(src, 0)
    32  	if idx < 0 {
    33  		return &invalidMessageFormatErr{messageType: "Bind"}
    34  	}
    35  	dst.DestinationPortal = string(src[:idx])
    36  	rp := idx + 1
    37  
    38  	idx = bytes.IndexByte(src[rp:], 0)
    39  	if idx < 0 {
    40  		return &invalidMessageFormatErr{messageType: "Bind"}
    41  	}
    42  	dst.PreparedStatement = string(src[rp : rp+idx])
    43  	rp += idx + 1
    44  
    45  	if len(src[rp:]) < 2 {
    46  		return &invalidMessageFormatErr{messageType: "Bind"}
    47  	}
    48  	parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
    49  	rp += 2
    50  
    51  	if parameterFormatCodeCount > 0 {
    52  		dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
    53  
    54  		if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
    55  			return &invalidMessageFormatErr{messageType: "Bind"}
    56  		}
    57  		for i := 0; i < parameterFormatCodeCount; i++ {
    58  			dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
    59  			rp += 2
    60  		}
    61  	}
    62  
    63  	if len(src[rp:]) < 2 {
    64  		return &invalidMessageFormatErr{messageType: "Bind"}
    65  	}
    66  	parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
    67  	rp += 2
    68  
    69  	if parameterCount > 0 {
    70  		dst.Parameters = make([][]byte, parameterCount)
    71  
    72  		for i := 0; i < parameterCount; i++ {
    73  			if len(src[rp:]) < 4 {
    74  				return &invalidMessageFormatErr{messageType: "Bind"}
    75  			}
    76  
    77  			msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
    78  			rp += 4
    79  
    80  			// null
    81  			if msgSize == -1 {
    82  				continue
    83  			}
    84  
    85  			if len(src[rp:]) < msgSize {
    86  				return &invalidMessageFormatErr{messageType: "Bind"}
    87  			}
    88  
    89  			dst.Parameters[i] = src[rp : rp+msgSize]
    90  			rp += msgSize
    91  		}
    92  	}
    93  
    94  	if len(src[rp:]) < 2 {
    95  		return &invalidMessageFormatErr{messageType: "Bind"}
    96  	}
    97  	resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
    98  	rp += 2
    99  
   100  	dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
   101  	if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
   102  		return &invalidMessageFormatErr{messageType: "Bind"}
   103  	}
   104  	for i := 0; i < resultFormatCodeCount; i++ {
   105  		dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
   106  		rp += 2
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
   113  func (src *Bind) Encode(dst []byte) ([]byte, error) {
   114  	dst, sp := beginMessage(dst, 'B')
   115  
   116  	dst = append(dst, src.DestinationPortal...)
   117  	dst = append(dst, 0)
   118  	dst = append(dst, src.PreparedStatement...)
   119  	dst = append(dst, 0)
   120  
   121  	if len(src.ParameterFormatCodes) > math.MaxUint16 {
   122  		return nil, errors.New("too many parameter format codes")
   123  	}
   124  	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
   125  	for _, fc := range src.ParameterFormatCodes {
   126  		dst = pgio.AppendInt16(dst, fc)
   127  	}
   128  
   129  	if len(src.Parameters) > math.MaxUint16 {
   130  		return nil, errors.New("too many parameters")
   131  	}
   132  	dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
   133  	for _, p := range src.Parameters {
   134  		if p == nil {
   135  			dst = pgio.AppendInt32(dst, -1)
   136  			continue
   137  		}
   138  
   139  		dst = pgio.AppendInt32(dst, int32(len(p)))
   140  		dst = append(dst, p...)
   141  	}
   142  
   143  	if len(src.ResultFormatCodes) > math.MaxUint16 {
   144  		return nil, errors.New("too many result format codes")
   145  	}
   146  	dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
   147  	for _, fc := range src.ResultFormatCodes {
   148  		dst = pgio.AppendInt16(dst, fc)
   149  	}
   150  
   151  	return finishMessage(dst, sp)
   152  }
   153  
   154  // MarshalJSON implements encoding/json.Marshaler.
   155  func (src Bind) MarshalJSON() ([]byte, error) {
   156  	formattedParameters := make([]map[string]string, len(src.Parameters))
   157  	for i, p := range src.Parameters {
   158  		if p == nil {
   159  			continue
   160  		}
   161  
   162  		textFormat := true
   163  		if len(src.ParameterFormatCodes) == 1 {
   164  			textFormat = src.ParameterFormatCodes[0] == 0
   165  		} else if len(src.ParameterFormatCodes) > 1 {
   166  			textFormat = src.ParameterFormatCodes[i] == 0
   167  		}
   168  
   169  		if textFormat {
   170  			formattedParameters[i] = map[string]string{"text": string(p)}
   171  		} else {
   172  			formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
   173  		}
   174  	}
   175  
   176  	return json.Marshal(struct {
   177  		Type                 string
   178  		DestinationPortal    string
   179  		PreparedStatement    string
   180  		ParameterFormatCodes []int16
   181  		Parameters           []map[string]string
   182  		ResultFormatCodes    []int16
   183  	}{
   184  		Type:                 "Bind",
   185  		DestinationPortal:    src.DestinationPortal,
   186  		PreparedStatement:    src.PreparedStatement,
   187  		ParameterFormatCodes: src.ParameterFormatCodes,
   188  		Parameters:           formattedParameters,
   189  		ResultFormatCodes:    src.ResultFormatCodes,
   190  	})
   191  }
   192  
   193  // UnmarshalJSON implements encoding/json.Unmarshaler.
   194  func (dst *Bind) UnmarshalJSON(data []byte) error {
   195  	// Ignore null, like in the main JSON package.
   196  	if string(data) == "null" {
   197  		return nil
   198  	}
   199  
   200  	var msg struct {
   201  		DestinationPortal    string
   202  		PreparedStatement    string
   203  		ParameterFormatCodes []int16
   204  		Parameters           []map[string]string
   205  		ResultFormatCodes    []int16
   206  	}
   207  	err := json.Unmarshal(data, &msg)
   208  	if err != nil {
   209  		return err
   210  	}
   211  	dst.DestinationPortal = msg.DestinationPortal
   212  	dst.PreparedStatement = msg.PreparedStatement
   213  	dst.ParameterFormatCodes = msg.ParameterFormatCodes
   214  	dst.Parameters = make([][]byte, len(msg.Parameters))
   215  	dst.ResultFormatCodes = msg.ResultFormatCodes
   216  	for n, parameter := range msg.Parameters {
   217  		dst.Parameters[n], err = getValueFromJSON(parameter)
   218  		if err != nil {
   219  			return fmt.Errorf("cannot get param %d: %w", n, err)
   220  		}
   221  	}
   222  	return nil
   223  }
   224  

View as plain text