...

Source file src/github.com/jackc/pgtype/pgxtype/pgxtype.go

Documentation: github.com/jackc/pgtype/pgxtype

     1  package pgxtype
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  
     7  	"github.com/jackc/pgconn"
     8  	"github.com/jackc/pgtype"
     9  	"github.com/jackc/pgx/v4"
    10  )
    11  
    12  type Querier interface {
    13  	Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
    14  	Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
    15  	QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
    16  }
    17  
    18  // LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for
    19  // registration on ci.
    20  func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) {
    21  	var oid uint32
    22  
    23  	err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
    24  	if err != nil {
    25  		return pgtype.DataType{}, err
    26  	}
    27  
    28  	var typtype string
    29  
    30  	err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
    31  	if err != nil {
    32  		return pgtype.DataType{}, err
    33  	}
    34  
    35  	switch typtype {
    36  	case "b": // array
    37  		elementOID, err := GetArrayElementOID(ctx, conn, oid)
    38  		if err != nil {
    39  			return pgtype.DataType{}, err
    40  		}
    41  
    42  		var element pgtype.ValueTranscoder
    43  		if dt, ok := ci.DataTypeForOID(elementOID); ok {
    44  			if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok {
    45  				return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder")
    46  			}
    47  		}
    48  
    49  		newElement := func() pgtype.ValueTranscoder {
    50  			return pgtype.NewValue(element).(pgtype.ValueTranscoder)
    51  		}
    52  
    53  		at := pgtype.NewArrayType(typeName, elementOID, newElement)
    54  		return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil
    55  	case "c": // composite
    56  		fields, err := GetCompositeFields(ctx, conn, oid)
    57  		if err != nil {
    58  			return pgtype.DataType{}, err
    59  		}
    60  		ct, err := pgtype.NewCompositeType(typeName, fields, ci)
    61  		if err != nil {
    62  			return pgtype.DataType{}, err
    63  		}
    64  		return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil
    65  	case "e": // enum
    66  		members, err := GetEnumMembers(ctx, conn, oid)
    67  		if err != nil {
    68  			return pgtype.DataType{}, err
    69  		}
    70  		return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil
    71  	default:
    72  		return pgtype.DataType{}, errors.New("unknown typtype")
    73  	}
    74  }
    75  
    76  func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) {
    77  	var typelem uint32
    78  
    79  	err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
    80  	if err != nil {
    81  		return 0, err
    82  	}
    83  
    84  	return typelem, nil
    85  }
    86  
    87  // GetCompositeFields gets the fields of a composite type.
    88  func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) {
    89  	var typrelid uint32
    90  
    91  	err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	var fields []pgtype.CompositeTypeField
    97  
    98  	rows, err := conn.Query(ctx, `select attname, atttypid
    99  from pg_attribute
   100  where attrelid=$1
   101  order by attnum`, typrelid)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	for rows.Next() {
   107  		var f pgtype.CompositeTypeField
   108  		err := rows.Scan(&f.Name, &f.OID)
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		fields = append(fields, f)
   113  	}
   114  
   115  	if rows.Err() != nil {
   116  		return nil, rows.Err()
   117  	}
   118  
   119  	return fields, nil
   120  }
   121  
   122  // GetEnumMembers gets the possible values of the enum by oid.
   123  func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) {
   124  	members := []string{}
   125  
   126  	rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	for rows.Next() {
   132  		var m string
   133  		err := rows.Scan(&m)
   134  		if err != nil {
   135  			return nil, err
   136  		}
   137  		members = append(members, m)
   138  	}
   139  
   140  	if rows.Err() != nil {
   141  		return nil, rows.Err()
   142  	}
   143  
   144  	return members, nil
   145  }
   146  

View as plain text