1 // Copyright 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package pgxv4 provides a Cloud SQL Postgres driver that uses pgx v4 and works 16 // with the database/sql package. 17 package pgxv4 18 19 import ( 20 "context" 21 "database/sql" 22 "database/sql/driver" 23 "net" 24 "sync" 25 26 "cloud.google.com/go/cloudsqlconn" 27 "github.com/jackc/pgx/v4" 28 "github.com/jackc/pgx/v4/stdlib" 29 ) 30 31 // RegisterDriver registers a Postgres driver that uses the cloudsqlconn.Dialer 32 // configured with the provided options. The choice of name is entirely up to 33 // the caller and may be used to distinguish between multiple registrations of 34 // differently configured Dialers. The driver uses pgx/v4 internally. 35 // RegisterDriver returns a cleanup function that should be called one the 36 // database connection is no longer needed. 37 func RegisterDriver(name string, opts ...cloudsqlconn.Option) (func() error, error) { 38 d, err := cloudsqlconn.NewDialer(context.Background(), opts...) 39 if err != nil { 40 return func() error { return nil }, err 41 } 42 sql.Register(name, &pgDriver{ 43 d: d, 44 dbURIs: make(map[string]string), 45 }) 46 return func() error { return d.Close() }, nil 47 } 48 49 type pgDriver struct { 50 d *cloudsqlconn.Dialer 51 mu sync.RWMutex 52 // dbURIs is a map of DSN to DB URI for registered connection names. 53 dbURIs map[string]string 54 } 55 56 // Open accepts a keyword/value formatted connection string and returns a 57 // connection to the database using cloudsqlconn.Dialer. The Cloud SQL instance 58 // connection name should be specified in the host field. For example: 59 // 60 // "host=my-project:us-central1:my-db-instance user=myuser password=mypass" 61 func (p *pgDriver) Open(name string) (driver.Conn, error) { 62 dbURI, err := p.dbURI(name) 63 if err != nil { 64 return nil, err 65 } 66 return stdlib.GetDefaultDriver().Open(dbURI) 67 68 } 69 70 // dbURI registers a driver using the provided DSN. If the name has already 71 // been registered, dbURI returns the existing registration. 72 func (p *pgDriver) dbURI(name string) (string, error) { 73 p.mu.Lock() 74 defer p.mu.Unlock() 75 dbURI, ok := p.dbURIs[name] 76 if ok { 77 return dbURI, nil 78 } 79 80 config, err := pgx.ParseConfig(name) 81 if err != nil { 82 return "", err 83 } 84 instConnName := config.Config.Host // Extract instance connection name 85 config.Config.Host = "localhost" // Replace it with a default value 86 config.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) { 87 return p.d.Dial(ctx, instConnName) 88 } 89 90 dbURI = stdlib.RegisterConnConfig(config) 91 p.dbURIs[name] = dbURI 92 93 return dbURI, nil 94 } 95