1 package cloudsql
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "net"
8 "strings"
9 "time"
10
11 "cloud.google.com/go/cloudsqlconn"
12 "github.com/jackc/pgx/v4"
13 "github.com/jackc/pgx/v4/stdlib"
14 )
15
const (
	// DefaultMaxOpenConns is the connection-pool limit that Validate applies
	// when MaxOpenConns was never called on the builder.
	DefaultMaxOpenConns = 20
)
17
18 type EdgePostgres struct {
19 connectionName string
20 dbName string
21 dialer *cloudsqlconn.Dialer
22 host string
23 maxOpenConns int
24 password string
25 port string
26 username string
27 searchPath []string
28 }
29
30
31 func GCPPostgresConnection(connectionName string) *EdgePostgres {
32 c := &EdgePostgres{}
33 c.connectionName = connectionName
34 return c
35 }
36
37
38 func PostgresConnection(host, port string) *EdgePostgres {
39 c := &EdgePostgres{}
40 c.host = host
41 c.port = port
42 return c
43 }
44
45
46
47
48 func (c *EdgePostgres) MaxOpenConns(count int) *EdgePostgres {
49 if count <= 0 {
50
51
52 c.maxOpenConns = -1
53 } else {
54 c.maxOpenConns = count
55 }
56 return c
57 }
58
59 func (c *EdgePostgres) Username(username string) *EdgePostgres {
60 c.username = username
61 return c
62 }
63
64
65
66 func (c *EdgePostgres) SearchPath(searchPath ...string) *EdgePostgres {
67 c.searchPath = searchPath
68 return c
69 }
70
71 func (c *EdgePostgres) Password(password string) *EdgePostgres {
72 c.password = password
73 return c
74 }
75
76 func (c *EdgePostgres) DBName(name string) *EdgePostgres {
77 c.dbName = name
78 return c
79 }
80
81
82 func (c *EdgePostgres) SetDialer(dialer *cloudsqlconn.Dialer) *EdgePostgres {
83 c.dialer = dialer
84 return c
85 }
86
87
88 func (c *EdgePostgres) AttachDialer(ctx context.Context) (*EdgePostgres, error) {
89 dialer, err := cloudsqlconn.NewDialer(ctx, cloudsqlconn.WithIAMAuthN())
90 if err != nil {
91 return nil, err
92 }
93 c.dialer = dialer
94 return c, nil
95 }
96
97
98 func (c *EdgePostgres) Dial(_, _ string) (net.Conn, error) {
99 return c.dialer.Dial(context.Background(), c.connectionName)
100 }
101
102
103 func (c *EdgePostgres) DialTimeout(_, _ string, timeout time.Duration) (net.Conn, error) {
104 ctx, cancel := context.WithTimeout(context.Background(), timeout)
105 defer cancel()
106 return c.dialer.Dial(ctx, c.connectionName)
107 }
108
109 func (c *EdgePostgres) Validate() error {
110 if c.connectionName != "" && c.host != "" {
111 return fmt.Errorf("unable to set both connection name and host, use connection name for gcp connection" +
112 " and host for standard postgres connetion")
113 }
114 if c.connectionName == "" && c.host == "" {
115 return fmt.Errorf("must set connection name or host, use connection name for gcp connection" +
116 " and host for standard postgres connetion")
117 }
118 if c.host != "" && c.port == "" {
119 return fmt.Errorf("port must be set for standard db connection")
120 }
121 if c.host != "" && c.password == "" {
122 return fmt.Errorf("password must be set for standard db connection")
123 }
124 if c.username == "" {
125 return fmt.Errorf("must set username")
126 }
127 if c.dbName == "" {
128 return fmt.Errorf("must set db name")
129 }
130 if c.maxOpenConns == 0 {
131 c.maxOpenConns = DefaultMaxOpenConns
132 }
133 return nil
134 }
135
136 func (c *EdgePostgres) NewConnection() (*sql.DB, error) {
137 if err := c.Validate(); err != nil {
138 return nil, err
139 }
140 config, err := c.CreateConfig()
141 if err != nil {
142 return nil, err
143 }
144 dbURI := stdlib.RegisterConnConfig(config)
145 dbPool, err := sql.Open("pgx", dbURI)
146 if err != nil {
147 return nil, fmt.Errorf("sql.Open: %v", err)
148 }
149 dbPool.SetMaxOpenConns(c.maxOpenConns)
150 return dbPool, nil
151 }
152
153 func (c *EdgePostgres) CreateConfig() (*pgx.ConnConfig, error) {
154 if c.connectionName != "" {
155 return c.buildGCPConfig()
156 }
157 return c.buildPostgresConfig()
158 }
159
160
161 func (c *EdgePostgres) ConnectionString(isIAM bool) string {
162 var connString string
163 if isIAM {
164 connString = fmt.Sprintf("host=%s user=%s database=%s sslmode=disable", c.connectionName, c.username, c.dbName)
165 } else {
166 connString = fmt.Sprintf("host=%s user=%s database=%s password=%s port=%s sslmode=disable", c.host, c.username, c.dbName, c.password, c.port)
167 }
168
169 if len(c.searchPath) != 0 {
170 connString = fmt.Sprintf("%s search_path='%s'", connString, strings.Join(c.searchPath, ", "))
171 }
172
173 return connString
174 }
175
176 func (c *EdgePostgres) buildGCPConfig() (*pgx.ConnConfig, error) {
177 var d *cloudsqlconn.Dialer
178 var dsn string
179 var err error
180
181 if c.password == "" {
182 dsn = fmt.Sprintf("database=%s user=%s sslmode=disable", c.dbName, c.username)
183 d, err = cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN())
184 } else {
185 dsn = fmt.Sprintf("database=%s user=%s sslmode=disable password=%s ", c.dbName, c.username, c.password)
186 d, err = cloudsqlconn.NewDialer(context.Background())
187 }
188 if err != nil {
189 return nil, err
190 }
191
192 if len(c.searchPath) != 0 {
193 dsn = fmt.Sprintf("%s search_path='%s'", dsn, strings.Join(c.searchPath, ", "))
194 }
195
196 config, err := pgx.ParseConfig(dsn)
197 if err != nil {
198 return nil, err
199 }
200 config.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) {
201 return d.Dial(ctx, c.connectionName)
202 }
203 return config, err
204 }
205
206 func (c *EdgePostgres) buildPostgresConfig() (*pgx.ConnConfig, error) {
207 dsn := c.ConnectionString(false)
208 return pgx.ParseConfig(dsn)
209 }
210
View as plain text