...

Source file src/github.com/jackc/pgx/v4/values.go

Documentation: github.com/jackc/pgx/v4

     1  package pgx
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"fmt"
     6  	"math"
     7  	"reflect"
     8  	"time"
     9  
    10  	"github.com/jackc/pgio"
    11  	"github.com/jackc/pgtype"
    12  )
    13  
    14  // PostgreSQL format codes
    15  const (
    16  	TextFormatCode   = 0
    17  	BinaryFormatCode = 1
    18  )
    19  
    20  // SerializationError occurs on failure to encode or decode a value
    21  type SerializationError string
    22  
    23  func (e SerializationError) Error() string {
    24  	return string(e)
    25  }
    26  
    27  func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) {
    28  	if arg == nil {
    29  		return nil, nil
    30  	}
    31  
    32  	refVal := reflect.ValueOf(arg)
    33  	if refVal.Kind() == reflect.Ptr && refVal.IsNil() {
    34  		return nil, nil
    35  	}
    36  
    37  	switch arg := arg.(type) {
    38  
    39  	// https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface
    40  	// []byte to database/sql instead of string. But that caused problems with the
    41  	// simple protocol because the driver.Valuer case got taken before the
    42  	// pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual
    43  	// case because of https://github.com/jackc/pgx/issues/339. So instead we
    44  	// special case JSON and JSONB.
    45  	case *pgtype.JSON:
    46  		buf, err := arg.EncodeText(ci, nil)
    47  		if err != nil {
    48  			return nil, err
    49  		}
    50  		if buf == nil {
    51  			return nil, nil
    52  		}
    53  		return string(buf), nil
    54  	case *pgtype.JSONB:
    55  		buf, err := arg.EncodeText(ci, nil)
    56  		if err != nil {
    57  			return nil, err
    58  		}
    59  		if buf == nil {
    60  			return nil, nil
    61  		}
    62  		return string(buf), nil
    63  
    64  	case driver.Valuer:
    65  		return callValuerValue(arg)
    66  	case pgtype.TextEncoder:
    67  		buf, err := arg.EncodeText(ci, nil)
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  		if buf == nil {
    72  			return nil, nil
    73  		}
    74  		return string(buf), nil
    75  	case float32:
    76  		return float64(arg), nil
    77  	case float64:
    78  		return arg, nil
    79  	case bool:
    80  		return arg, nil
    81  	case time.Duration:
    82  		return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil
    83  	case time.Time:
    84  		return arg, nil
    85  	case string:
    86  		return arg, nil
    87  	case []byte:
    88  		return arg, nil
    89  	case int8:
    90  		return int64(arg), nil
    91  	case int16:
    92  		return int64(arg), nil
    93  	case int32:
    94  		return int64(arg), nil
    95  	case int64:
    96  		return arg, nil
    97  	case int:
    98  		return int64(arg), nil
    99  	case uint8:
   100  		return int64(arg), nil
   101  	case uint16:
   102  		return int64(arg), nil
   103  	case uint32:
   104  		return int64(arg), nil
   105  	case uint64:
   106  		if arg > math.MaxInt64 {
   107  			return nil, fmt.Errorf("arg too big for int64: %v", arg)
   108  		}
   109  		return int64(arg), nil
   110  	case uint:
   111  		if uint64(arg) > math.MaxInt64 {
   112  			return nil, fmt.Errorf("arg too big for int64: %v", arg)
   113  		}
   114  		return int64(arg), nil
   115  	}
   116  
   117  	if dt, found := ci.DataTypeForValue(arg); found {
   118  		v := dt.Value
   119  		err := v.Set(arg)
   120  		if err != nil {
   121  			return nil, err
   122  		}
   123  		buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil)
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  		if buf == nil {
   128  			return nil, nil
   129  		}
   130  		return string(buf), nil
   131  	}
   132  
   133  	if refVal.Kind() == reflect.Ptr {
   134  		arg = refVal.Elem().Interface()
   135  		return convertSimpleArgument(ci, arg)
   136  	}
   137  
   138  	if strippedArg, ok := stripNamedType(&refVal); ok {
   139  		return convertSimpleArgument(ci, strippedArg)
   140  	}
   141  	return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
   142  }
   143  
   144  func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
   145  	if arg == nil {
   146  		return pgio.AppendInt32(buf, -1), nil
   147  	}
   148  
   149  	switch arg := arg.(type) {
   150  	case pgtype.BinaryEncoder:
   151  		sp := len(buf)
   152  		buf = pgio.AppendInt32(buf, -1)
   153  		argBuf, err := arg.EncodeBinary(ci, buf)
   154  		if err != nil {
   155  			return nil, err
   156  		}
   157  		if argBuf != nil {
   158  			buf = argBuf
   159  			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
   160  		}
   161  		return buf, nil
   162  	case pgtype.TextEncoder:
   163  		sp := len(buf)
   164  		buf = pgio.AppendInt32(buf, -1)
   165  		argBuf, err := arg.EncodeText(ci, buf)
   166  		if err != nil {
   167  			return nil, err
   168  		}
   169  		if argBuf != nil {
   170  			buf = argBuf
   171  			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
   172  		}
   173  		return buf, nil
   174  	case string:
   175  		buf = pgio.AppendInt32(buf, int32(len(arg)))
   176  		buf = append(buf, arg...)
   177  		return buf, nil
   178  	}
   179  
   180  	refVal := reflect.ValueOf(arg)
   181  
   182  	if refVal.Kind() == reflect.Ptr {
   183  		if refVal.IsNil() {
   184  			return pgio.AppendInt32(buf, -1), nil
   185  		}
   186  		arg = refVal.Elem().Interface()
   187  		return encodePreparedStatementArgument(ci, buf, oid, arg)
   188  	}
   189  
   190  	if dt, ok := ci.DataTypeForOID(oid); ok {
   191  		value := dt.Value
   192  		err := value.Set(arg)
   193  		if err != nil {
   194  			{
   195  				if arg, ok := arg.(driver.Valuer); ok {
   196  					v, err := callValuerValue(arg)
   197  					if err != nil {
   198  						return nil, err
   199  					}
   200  					return encodePreparedStatementArgument(ci, buf, oid, v)
   201  				}
   202  			}
   203  
   204  			return nil, err
   205  		}
   206  
   207  		sp := len(buf)
   208  		buf = pgio.AppendInt32(buf, -1)
   209  		argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
   210  		if err != nil {
   211  			return nil, err
   212  		}
   213  		if argBuf != nil {
   214  			buf = argBuf
   215  			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
   216  		}
   217  		return buf, nil
   218  	}
   219  
   220  	if strippedArg, ok := stripNamedType(&refVal); ok {
   221  		return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
   222  	}
   223  	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
   224  }
   225  
   226  // chooseParameterFormatCode determines the correct format code for an
   227  // argument to a prepared statement. It defaults to TextFormatCode if no
   228  // determination can be made.
   229  func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 {
   230  	switch arg := arg.(type) {
   231  	case pgtype.ParamFormatPreferrer:
   232  		return arg.PreferredParamFormat()
   233  	case pgtype.BinaryEncoder:
   234  		return BinaryFormatCode
   235  	case string, *string, pgtype.TextEncoder:
   236  		return TextFormatCode
   237  	}
   238  
   239  	return ci.ParamFormatCodeForOID(oid)
   240  }
   241  
   242  func stripNamedType(val *reflect.Value) (interface{}, bool) {
   243  	switch val.Kind() {
   244  	case reflect.Int:
   245  		convVal := int(val.Int())
   246  		return convVal, reflect.TypeOf(convVal) != val.Type()
   247  	case reflect.Int8:
   248  		convVal := int8(val.Int())
   249  		return convVal, reflect.TypeOf(convVal) != val.Type()
   250  	case reflect.Int16:
   251  		convVal := int16(val.Int())
   252  		return convVal, reflect.TypeOf(convVal) != val.Type()
   253  	case reflect.Int32:
   254  		convVal := int32(val.Int())
   255  		return convVal, reflect.TypeOf(convVal) != val.Type()
   256  	case reflect.Int64:
   257  		convVal := int64(val.Int())
   258  		return convVal, reflect.TypeOf(convVal) != val.Type()
   259  	case reflect.Uint:
   260  		convVal := uint(val.Uint())
   261  		return convVal, reflect.TypeOf(convVal) != val.Type()
   262  	case reflect.Uint8:
   263  		convVal := uint8(val.Uint())
   264  		return convVal, reflect.TypeOf(convVal) != val.Type()
   265  	case reflect.Uint16:
   266  		convVal := uint16(val.Uint())
   267  		return convVal, reflect.TypeOf(convVal) != val.Type()
   268  	case reflect.Uint32:
   269  		convVal := uint32(val.Uint())
   270  		return convVal, reflect.TypeOf(convVal) != val.Type()
   271  	case reflect.Uint64:
   272  		convVal := uint64(val.Uint())
   273  		return convVal, reflect.TypeOf(convVal) != val.Type()
   274  	case reflect.String:
   275  		convVal := val.String()
   276  		return convVal, reflect.TypeOf(convVal) != val.Type()
   277  	}
   278  
   279  	return nil, false
   280  }
   281  

View as plain text