...

Source file src/edge-infra.dev/pkg/lib/gcp/cloudsql/cloudsql.go

Documentation: edge-infra.dev/pkg/lib/gcp/cloudsql

     1  package cloudsql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"net"
     8  	"strings"
     9  	"time"
    10  
    11  	"cloud.google.com/go/cloudsqlconn"
    12  	"github.com/jackc/pgx/v4"
    13  	"github.com/jackc/pgx/v4/stdlib"
    14  )
    15  
    16  const DefaultMaxOpenConns = 20
    17  
    18  type EdgePostgres struct {
    19  	connectionName string
    20  	dbName         string
    21  	dialer         *cloudsqlconn.Dialer
    22  	host           string
    23  	maxOpenConns   int
    24  	password       string
    25  	port           string
    26  	username       string
    27  	searchPath     []string
    28  }
    29  
    30  // starting point for creating a gcp cloud sql postgres connection
    31  func GCPPostgresConnection(connectionName string) *EdgePostgres {
    32  	c := &EdgePostgres{}
    33  	c.connectionName = connectionName
    34  	return c
    35  }
    36  
    37  // starting point for creating a regular postgres connection
    38  func PostgresConnection(host, port string) *EdgePostgres {
    39  	c := &EdgePostgres{}
    40  	c.host = host
    41  	c.port = port
    42  	return c
    43  }
    44  
    45  // MaxOpenConns limits the amount of database connections that can be opened by the sql client.
    46  // If `count <= 0` then the database client can open an unlimited number of connections.
    47  // When MaxOpenConns is not called, cloudsql sets the max open connections to DefaultMaxOpenConns.
    48  func (c *EdgePostgres) MaxOpenConns(count int) *EdgePostgres {
    49  	if count <= 0 {
    50  		// Any value less than or equal to 0 means unlimited connections.
    51  		// Setting maxConnections to -1 lets us know the user desires unlimited connections, and not the default value.
    52  		c.maxOpenConns = -1
    53  	} else {
    54  		c.maxOpenConns = count
    55  	}
    56  	return c
    57  }
    58  
    59  func (c *EdgePostgres) Username(username string) *EdgePostgres {
    60  	c.username = username
    61  	return c
    62  }
    63  
    64  // Optionally set the search_path for the connection. Uses the default search_path
    65  // if not set.
    66  func (c *EdgePostgres) SearchPath(searchPath ...string) *EdgePostgres {
    67  	c.searchPath = searchPath
    68  	return c
    69  }
    70  
    71  func (c *EdgePostgres) Password(password string) *EdgePostgres {
    72  	c.password = password
    73  	return c
    74  }
    75  
    76  func (c *EdgePostgres) DBName(name string) *EdgePostgres {
    77  	c.dbName = name
    78  	return c
    79  }
    80  
    81  // SetDialer sets the custom database dialer.
    82  func (c *EdgePostgres) SetDialer(dialer *cloudsqlconn.Dialer) *EdgePostgres {
    83  	c.dialer = dialer
    84  	return c
    85  }
    86  
    87  // AttachDialer sets the database dialer to be a gcp cloudsql dialer.
    88  func (c *EdgePostgres) AttachDialer(ctx context.Context) (*EdgePostgres, error) {
    89  	dialer, err := cloudsqlconn.NewDialer(ctx, cloudsqlconn.WithIAMAuthN())
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  	c.dialer = dialer
    94  	return c, nil
    95  }
    96  
    97  // Dial implements the dialer interface.
    98  func (c *EdgePostgres) Dial(_, _ string) (net.Conn, error) {
    99  	return c.dialer.Dial(context.Background(), c.connectionName)
   100  }
   101  
   102  // DialTimeout implements the dialer interface.
   103  func (c *EdgePostgres) DialTimeout(_, _ string, timeout time.Duration) (net.Conn, error) {
   104  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   105  	defer cancel()
   106  	return c.dialer.Dial(ctx, c.connectionName)
   107  }
   108  
   109  func (c *EdgePostgres) Validate() error {
   110  	if c.connectionName != "" && c.host != "" {
   111  		return fmt.Errorf("unable to set both connection name and host, use connection name for gcp connection" +
   112  			" and host for standard postgres connetion")
   113  	}
   114  	if c.connectionName == "" && c.host == "" {
   115  		return fmt.Errorf("must set connection name or host, use connection name for gcp connection" +
   116  			" and host for standard postgres connetion")
   117  	}
   118  	if c.host != "" && c.port == "" {
   119  		return fmt.Errorf("port must be set for standard db connection")
   120  	}
   121  	if c.host != "" && c.password == "" {
   122  		return fmt.Errorf("password must be set for standard db connection")
   123  	}
   124  	if c.username == "" {
   125  		return fmt.Errorf("must set username")
   126  	}
   127  	if c.dbName == "" {
   128  		return fmt.Errorf("must set db name")
   129  	}
   130  	if c.maxOpenConns == 0 {
   131  		c.maxOpenConns = DefaultMaxOpenConns
   132  	}
   133  	return nil
   134  }
   135  
   136  func (c *EdgePostgres) NewConnection() (*sql.DB, error) {
   137  	if err := c.Validate(); err != nil {
   138  		return nil, err
   139  	}
   140  	config, err := c.CreateConfig()
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	dbURI := stdlib.RegisterConnConfig(config)
   145  	dbPool, err := sql.Open("pgx", dbURI)
   146  	if err != nil {
   147  		return nil, fmt.Errorf("sql.Open: %v", err)
   148  	}
   149  	dbPool.SetMaxOpenConns(c.maxOpenConns)
   150  	return dbPool, nil
   151  }
   152  
   153  func (c *EdgePostgres) CreateConfig() (*pgx.ConnConfig, error) {
   154  	if c.connectionName != "" {
   155  		return c.buildGCPConfig()
   156  	}
   157  	return c.buildPostgresConfig()
   158  }
   159  
   160  // ConnectionString returns a PG compatible connection string from the database config.
   161  func (c *EdgePostgres) ConnectionString(isIAM bool) string {
   162  	var connString string
   163  	if isIAM {
   164  		connString = fmt.Sprintf("host=%s user=%s database=%s sslmode=disable", c.connectionName, c.username, c.dbName)
   165  	} else {
   166  		connString = fmt.Sprintf("host=%s user=%s database=%s password=%s port=%s sslmode=disable", c.host, c.username, c.dbName, c.password, c.port)
   167  	}
   168  
   169  	if len(c.searchPath) != 0 {
   170  		connString = fmt.Sprintf("%s search_path='%s'", connString, strings.Join(c.searchPath, ", "))
   171  	}
   172  
   173  	return connString
   174  }
   175  
   176  func (c *EdgePostgres) buildGCPConfig() (*pgx.ConnConfig, error) {
   177  	var d *cloudsqlconn.Dialer
   178  	var dsn string
   179  	var err error
   180  
   181  	if c.password == "" {
   182  		dsn = fmt.Sprintf("database=%s user=%s sslmode=disable", c.dbName, c.username)
   183  		d, err = cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN())
   184  	} else {
   185  		dsn = fmt.Sprintf("database=%s user=%s sslmode=disable password=%s ", c.dbName, c.username, c.password)
   186  		d, err = cloudsqlconn.NewDialer(context.Background())
   187  	}
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  
   192  	if len(c.searchPath) != 0 {
   193  		dsn = fmt.Sprintf("%s search_path='%s'", dsn, strings.Join(c.searchPath, ", "))
   194  	}
   195  
   196  	config, err := pgx.ParseConfig(dsn)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  	config.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) {
   201  		return d.Dial(ctx, c.connectionName)
   202  	}
   203  	return config, err
   204  }
   205  
   206  func (c *EdgePostgres) buildPostgresConfig() (*pgx.ConnConfig, error) {
   207  	dsn := c.ConnectionString(false)
   208  	return pgx.ParseConfig(dsn)
   209  }
   210  

View as plain text