package cloudsql import ( "context" "database/sql" "fmt" "net" "strings" "time" "cloud.google.com/go/cloudsqlconn" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" ) const DefaultMaxOpenConns = 20 type EdgePostgres struct { connectionName string dbName string dialer *cloudsqlconn.Dialer host string maxOpenConns int password string port string username string searchPath []string } // starting point for creating a gcp cloud sql postgres connection func GCPPostgresConnection(connectionName string) *EdgePostgres { c := &EdgePostgres{} c.connectionName = connectionName return c } // starting point for creating a regular postgres connection func PostgresConnection(host, port string) *EdgePostgres { c := &EdgePostgres{} c.host = host c.port = port return c } // MaxOpenConns limits the amount of database connections that can be opened by the sql client. // If `count <= 0` then the database client can open an unlimited number of connections. // When MaxOpenConns is not called, cloudsql sets the max open connections to DefaultMaxOpenConns. func (c *EdgePostgres) MaxOpenConns(count int) *EdgePostgres { if count <= 0 { // Any value less than or equal to 0 means unlimited connections. // Setting maxConnections to -1 lets us know the user desires unlimited connections, and not the default value. c.maxOpenConns = -1 } else { c.maxOpenConns = count } return c } func (c *EdgePostgres) Username(username string) *EdgePostgres { c.username = username return c } // Optionally set the search_path for the connection. Uses the default search_path // if not set. func (c *EdgePostgres) SearchPath(searchPath ...string) *EdgePostgres { c.searchPath = searchPath return c } func (c *EdgePostgres) Password(password string) *EdgePostgres { c.password = password return c } func (c *EdgePostgres) DBName(name string) *EdgePostgres { c.dbName = name return c } // SetDialer sets the custom database dialer. func (c *EdgePostgres) SetDialer(dialer *cloudsqlconn.Dialer) *EdgePostgres { c.dialer = dialer return c } // AttachDialer sets the database dialer to be a gcp cloudsql dialer. func (c *EdgePostgres) AttachDialer(ctx context.Context) (*EdgePostgres, error) { dialer, err := cloudsqlconn.NewDialer(ctx, cloudsqlconn.WithIAMAuthN()) if err != nil { return nil, err } c.dialer = dialer return c, nil } // Dial implements the dialer interface. func (c *EdgePostgres) Dial(_, _ string) (net.Conn, error) { return c.dialer.Dial(context.Background(), c.connectionName) } // DialTimeout implements the dialer interface. func (c *EdgePostgres) DialTimeout(_, _ string, timeout time.Duration) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return c.dialer.Dial(ctx, c.connectionName) } func (c *EdgePostgres) Validate() error { if c.connectionName != "" && c.host != "" { return fmt.Errorf("unable to set both connection name and host, use connection name for gcp connection" + " and host for standard postgres connetion") } if c.connectionName == "" && c.host == "" { return fmt.Errorf("must set connection name or host, use connection name for gcp connection" + " and host for standard postgres connetion") } if c.host != "" && c.port == "" { return fmt.Errorf("port must be set for standard db connection") } if c.host != "" && c.password == "" { return fmt.Errorf("password must be set for standard db connection") } if c.username == "" { return fmt.Errorf("must set username") } if c.dbName == "" { return fmt.Errorf("must set db name") } if c.maxOpenConns == 0 { c.maxOpenConns = DefaultMaxOpenConns } return nil } func (c *EdgePostgres) NewConnection() (*sql.DB, error) { if err := c.Validate(); err != nil { return nil, err } config, err := c.CreateConfig() if err != nil { return nil, err } dbURI := stdlib.RegisterConnConfig(config) dbPool, err := sql.Open("pgx", dbURI) if err != nil { return nil, fmt.Errorf("sql.Open: %v", err) } dbPool.SetMaxOpenConns(c.maxOpenConns) return dbPool, nil } func (c *EdgePostgres) CreateConfig() (*pgx.ConnConfig, error) { if c.connectionName != "" { return c.buildGCPConfig() } return c.buildPostgresConfig() } // ConnectionString returns a PG compatible connection string from the database config. func (c *EdgePostgres) ConnectionString(isIAM bool) string { var connString string if isIAM { connString = fmt.Sprintf("host=%s user=%s database=%s sslmode=disable", c.connectionName, c.username, c.dbName) } else { connString = fmt.Sprintf("host=%s user=%s database=%s password=%s port=%s sslmode=disable", c.host, c.username, c.dbName, c.password, c.port) } if len(c.searchPath) != 0 { connString = fmt.Sprintf("%s search_path='%s'", connString, strings.Join(c.searchPath, ", ")) } return connString } func (c *EdgePostgres) buildGCPConfig() (*pgx.ConnConfig, error) { var d *cloudsqlconn.Dialer var dsn string var err error if c.password == "" { dsn = fmt.Sprintf("database=%s user=%s sslmode=disable", c.dbName, c.username) d, err = cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) } else { dsn = fmt.Sprintf("database=%s user=%s sslmode=disable password=%s ", c.dbName, c.username, c.password) d, err = cloudsqlconn.NewDialer(context.Background()) } if err != nil { return nil, err } if len(c.searchPath) != 0 { dsn = fmt.Sprintf("%s search_path='%s'", dsn, strings.Join(c.searchPath, ", ")) } config, err := pgx.ParseConfig(dsn) if err != nil { return nil, err } config.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) { return d.Dial(ctx, c.connectionName) } return config, err } func (c *EdgePostgres) buildPostgresConfig() (*pgx.ConnConfig, error) { dsn := c.ConnectionString(false) return pgx.ParseConfig(dsn) }