1 package postgres
2
3 import (
4 "context"
5 "fmt"
6 "net/url"
7 "regexp"
8 "strings"
9
10 "github.com/go-logr/logr"
11 "github.com/jackc/pgx/v5"
12 "github.com/jackc/pgx/v5/pgconn"
13 "github.com/jackc/pgx/v5/pgxpool"
14 )
15
// SQL keywords used to validate statements before they are sent to the
// database. They are lowercase because statements are lowercased before the
// prefix comparison.
const (
	// commandInsert is the prefix a statement must carry to be accepted by Insert.
	commandInsert = "insert"
	// commandSelect is the prefix a statement must carry to be accepted by Query.
	commandSelect = "select"
)
20
21 type defaultClient struct {
22 ctx context.Context
23 logger logr.Logger
24
25 pool *pgxpool.Pool
26 dialFunc pgconn.DialFunc
27 }
28
29
30 var _ Client = &defaultClient{}
31
32 type defaultRows struct {
33 pgx.Rows
34 }
35
36
37 var _ Rows = &defaultRows{}
38
39 func (r *defaultRows) Close() error {
40 r.Rows.Close()
41 return nil
42 }
43
44
45
46
47
48 func New(ctx context.Context, logger logr.Logger, opts DSNOptions) (Client, error) {
49 l := logger.WithValues("sql", "postgres")
50
51 dsn := opts.ToString()
52 config, err := pgxpool.ParseConfig(dsn)
53 if err != nil {
54 return nil, err
55 }
56 l.Info("using dsn config", "connString", redactPW(config.ConnString()))
57
58 pool, err := pgxpool.NewWithConfig(ctx, config)
59 if err != nil {
60 return nil, err
61 }
62
63 return &defaultClient{
64 ctx: ctx,
65 logger: l,
66 pool: pool,
67 }, nil
68 }
69
70
71
72
73 func NewWithDialer(ctx context.Context, logger logr.Logger, opts DSNOptions, dialFunc pgconn.DialFunc) (Client, error) {
74 l := logger.WithValues("sql", "postgres")
75
76 dsn := opts.ToString()
77 config, err := pgxpool.ParseConfig(dsn)
78 if err != nil {
79 return nil, err
80 }
81 l.Info("using dsn config", "connString", redactPW(config.ConnString()))
82
83
84 config.ConnConfig.DialFunc = dialFunc
85
86 pool, err := pgxpool.NewWithConfig(ctx, config)
87 if err != nil {
88 return nil, err
89 }
90
91 return &defaultClient{
92 ctx: ctx,
93 logger: l,
94 pool: pool,
95 dialFunc: dialFunc,
96 }, nil
97 }
98
99
100 func (c *defaultClient) Connect() error {
101 return c.pool.Ping(c.ctx)
102 }
103
104
105
106 func (c *defaultClient) Close() error {
107 c.pool.Close()
108 return nil
109 }
110
111
112 func (c *defaultClient) IsConnected() bool {
113 if c.pool == nil {
114 return false
115 }
116 return c.pool.Ping(c.ctx) == nil
117 }
118
119
120 func (c *defaultClient) Insert(statement string, args ...interface{}) error {
121 s := strings.TrimSpace(strings.ToLower(statement))
122 if !strings.HasPrefix(s, commandInsert) {
123 return fmt.Errorf("invalid sql command")
124 }
125 if _, err := c.pool.Exec(c.ctx, statement, args...); err != nil {
126 return err
127 }
128 return nil
129 }
130
131
132 func (c *defaultClient) Query(statement string, args ...interface{}) (Rows, error) {
133 s := strings.TrimSpace(strings.ToLower(statement))
134 if !strings.HasPrefix(s, commandSelect) {
135 return nil, fmt.Errorf("invalid sql command")
136 }
137 rows, err := c.pool.Query(c.ctx, statement, args...)
138 if err != nil {
139 return nil, err
140 }
141 return &defaultRows{rows}, nil
142 }
143
// Patterns used by redactPW on connection strings that are not parseable
// URLs. They are compiled once at package initialisation instead of on every
// call.
var (
	// quotedDSNPasswordRe matches password='...' in a keyword/value DSN. A
	// backslash escapes the following character, so \' does not end the value.
	quotedDSNPasswordRe = regexp.MustCompile(`password='(?:\\.|[^'\\])*'`)
	// plainDSNPasswordRe matches an unquoted password=value up to the next
	// unescaped space.
	plainDSNPasswordRe = regexp.MustCompile(`password=(?:\\.|[^ \\])*`)
	// brokenURLPasswordRe matches the ":password@" section of a URL-style
	// string that url.Parse rejected.
	brokenURLPasswordRe = regexp.MustCompile(`:[^:@]+?@`)
)

// redactPW returns connString with any password replaced by a placeholder so
// the result can be logged.
//
// Strings starting with "postgres://" or "postgresql://" that parse as URLs
// are handled by redactURL. Everything else, including URL-looking strings
// that fail to parse, is scrubbed with the regular expressions above: quoted
// and unquoted password=... pairs become "password=xxxxx" and a ":secret@"
// section becomes ":xxxxxx@".
func redactPW(connString string) string {
	if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
		if u, err := url.Parse(connString); err == nil {
			return redactURL(u)
		}
	}
	// Quoted values go first so that a quoted password containing spaces is
	// removed as a whole before the unquoted pattern runs.
	connString = quotedDSNPasswordRe.ReplaceAllLiteralString(connString, "password=xxxxx")
	connString = plainDSNPasswordRe.ReplaceAllLiteralString(connString, "password=xxxxx")
	return brokenURLPasswordRe.ReplaceAllLiteralString(connString, ":xxxxxx@")
}

// redactURL returns the string form of u with secrets replaced by "xxxxx": the
// userinfo password, if one is set, and the values of the "password" and
// "sslpassword" query parameters, if present. It returns "" for a nil URL.
//
// u itself is never modified; the redaction is applied to a copy. When a query
// parameter is redacted the query string is re-encoded, which sorts its keys.
func redactURL(u *url.URL) string {
	if u == nil {
		return ""
	}

	// Work on a shallow copy so the caller's URL keeps its credentials.
	redacted := *u
	if _, pwSet := u.User.Password(); pwSet {
		redacted.User = url.UserPassword(u.User.Username(), "xxxxx")
	}

	if redacted.RawQuery != "" {
		query := redacted.Query()
		changed := false
		for _, key := range []string{"password", "sslpassword"} {
			if _, present := query[key]; present {
				query.Set(key, "xxxxx")
				changed = true
			}
		}
		// Leave the query string byte-for-byte intact unless something was redacted.
		if changed {
			redacted.RawQuery = query.Encode()
		}
	}
	return redacted.String()
}
168
View as plain text