...

Source file src/github.com/jackc/pgtype/ext/gofrs-uuid/uuid.go

Documentation: github.com/jackc/pgtype/ext/gofrs-uuid

     1  package uuid
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"github.com/gofrs/uuid"
     9  	"github.com/jackc/pgtype"
    10  )
    11  
    12  var errUndefined = errors.New("cannot encode status undefined")
    13  var errBadStatus = errors.New("invalid status")
    14  
    15  type UUID struct {
    16  	UUID   uuid.UUID
    17  	Status pgtype.Status
    18  }
    19  
    20  func (dst *UUID) Set(src interface{}) error {
    21  	if src == nil {
    22  		*dst = UUID{Status: pgtype.Null}
    23  		return nil
    24  	}
    25  
    26  	if value, ok := src.(interface{ Get() interface{} }); ok {
    27  		value2 := value.Get()
    28  		if value2 != value {
    29  			return dst.Set(value2)
    30  		}
    31  	}
    32  
    33  	switch value := src.(type) {
    34  	case uuid.UUID:
    35  		*dst = UUID{UUID: value, Status: pgtype.Present}
    36  	case [16]byte:
    37  		*dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present}
    38  	case []byte:
    39  		if len(value) != 16 {
    40  			return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value))
    41  		}
    42  		*dst = UUID{Status: pgtype.Present}
    43  		copy(dst.UUID[:], value)
    44  	case string:
    45  		uuid, err := uuid.FromString(value)
    46  		if err != nil {
    47  			return err
    48  		}
    49  		*dst = UUID{UUID: uuid, Status: pgtype.Present}
    50  	default:
    51  		// If all else fails see if pgtype.UUID can handle it. If so, translate through that.
    52  		pgUUID := &pgtype.UUID{}
    53  		if err := pgUUID.Set(value); err != nil {
    54  			return fmt.Errorf("cannot convert %v to UUID", value)
    55  		}
    56  
    57  		*dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status}
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (dst UUID) Get() interface{} {
    64  	switch dst.Status {
    65  	case pgtype.Present:
    66  		return dst.UUID
    67  	case pgtype.Null:
    68  		return nil
    69  	default:
    70  		return dst.Status
    71  	}
    72  }
    73  
    74  func (src *UUID) AssignTo(dst interface{}) error {
    75  	switch src.Status {
    76  	case pgtype.Present:
    77  		switch v := dst.(type) {
    78  		case *uuid.UUID:
    79  			*v = src.UUID
    80  			return nil
    81  		case *[16]byte:
    82  			*v = [16]byte(src.UUID)
    83  			return nil
    84  		case *[]byte:
    85  			*v = make([]byte, 16)
    86  			copy(*v, src.UUID[:])
    87  			return nil
    88  		case *string:
    89  			*v = src.UUID.String()
    90  			return nil
    91  		default:
    92  			if nextDst, retry := pgtype.GetAssignToDstType(v); retry {
    93  				return src.AssignTo(nextDst)
    94  			}
    95  			return fmt.Errorf("unable to assign to %T", dst)
    96  		}
    97  	case pgtype.Null:
    98  		return pgtype.NullAssignTo(dst)
    99  	}
   100  
   101  	return fmt.Errorf("cannot assign %v into %T", src, dst)
   102  }
   103  
   104  func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
   105  	if src == nil {
   106  		*dst = UUID{Status: pgtype.Null}
   107  		return nil
   108  	}
   109  
   110  	u, err := uuid.FromString(string(src))
   111  	if err != nil {
   112  		return err
   113  	}
   114  
   115  	*dst = UUID{UUID: u, Status: pgtype.Present}
   116  	return nil
   117  }
   118  
   119  func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
   120  	if src == nil {
   121  		*dst = UUID{Status: pgtype.Null}
   122  		return nil
   123  	}
   124  
   125  	if len(src) != 16 {
   126  		return fmt.Errorf("invalid length for UUID: %v", len(src))
   127  	}
   128  
   129  	*dst = UUID{Status: pgtype.Present}
   130  	copy(dst.UUID[:], src)
   131  	return nil
   132  }
   133  
   134  func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
   135  	switch src.Status {
   136  	case pgtype.Null:
   137  		return nil, nil
   138  	case pgtype.Undefined:
   139  		return nil, errUndefined
   140  	}
   141  
   142  	return append(buf, src.UUID.String()...), nil
   143  }
   144  
   145  func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
   146  	switch src.Status {
   147  	case pgtype.Null:
   148  		return nil, nil
   149  	case pgtype.Undefined:
   150  		return nil, errUndefined
   151  	}
   152  
   153  	return append(buf, src.UUID[:]...), nil
   154  }
   155  
   156  // Scan implements the database/sql Scanner interface.
   157  func (dst *UUID) Scan(src interface{}) error {
   158  	if src == nil {
   159  		*dst = UUID{Status: pgtype.Null}
   160  		return nil
   161  	}
   162  
   163  	switch src := src.(type) {
   164  	case string:
   165  		return dst.DecodeText(nil, []byte(src))
   166  	case []byte:
   167  		return dst.DecodeText(nil, src)
   168  	}
   169  
   170  	return fmt.Errorf("cannot scan %T", src)
   171  }
   172  
   173  // Value implements the database/sql/driver Valuer interface.
   174  func (src UUID) Value() (driver.Value, error) {
   175  	return pgtype.EncodeValueText(src)
   176  }
   177  
   178  func (src UUID) MarshalJSON() ([]byte, error) {
   179  	switch src.Status {
   180  	case pgtype.Present:
   181  		return []byte(`"` + src.UUID.String() + `"`), nil
   182  	case pgtype.Null:
   183  		return []byte("null"), nil
   184  	case pgtype.Undefined:
   185  		return nil, errUndefined
   186  	}
   187  
   188  	return nil, errBadStatus
   189  }
   190  
   191  func (dst *UUID) UnmarshalJSON(b []byte) error {
   192  	u := uuid.NullUUID{}
   193  	err := u.UnmarshalJSON(b)
   194  	if err != nil {
   195  		return err
   196  	}
   197  
   198  	status := pgtype.Null
   199  	if u.Valid {
   200  		status = pgtype.Present
   201  	}
   202  	*dst = UUID{UUID: u.UUID, Status: status}
   203  
   204  	return nil
   205  }
   206  

View as plain text