1
2 package tracelog
3
4 import (
5 "context"
6 "encoding/hex"
7 "errors"
8 "fmt"
9 "time"
10 "unicode/utf8"
11
12 "github.com/jackc/pgx/v5"
13 )
14
15
16
// LogLevel is the severity of an entry produced by TraceLog. Larger values are
// more verbose: a TraceLog configured with a given level emits entries at that
// level and at every smaller (more severe) level.
type LogLevel int

// The defined log levels, from most verbose (LogLevelTrace) down to silent
// (LogLevelNone). Their numeric values are ordered and are compared with
// <= / >= when deciding whether to log.
const (
	LogLevelTrace LogLevel = 6
	LogLevelDebug LogLevel = 5
	LogLevelInfo  LogLevel = 4
	LogLevelWarn  LogLevel = 3
	LogLevelError LogLevel = 2
	LogLevelNone  LogLevel = 1
)

// String returns the lower-case name of ll ("trace", "debug", "info", "warn",
// "error" or "none"), or "invalid level N" for any value outside the range
// LogLevelNone..LogLevelTrace.
func (ll LogLevel) String() string {
	names := [...]string{
		LogLevelNone:  "none",
		LogLevelError: "error",
		LogLevelWarn:  "warn",
		LogLevelInfo:  "info",
		LogLevelDebug: "debug",
		LogLevelTrace: "trace",
	}
	if ll < LogLevelNone || ll > LogLevelTrace {
		// %d prints the integer value; fmt does not call String for %d.
		return fmt.Sprintf("invalid level %d", ll)
	}
	return names[ll]
}
48
49
50 type Logger interface {
51
52 Log(ctx context.Context, level LogLevel, msg string, data map[string]any)
53 }
54
55
56 type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
57
58
59 func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) {
60 f(ctx, level, msg, data)
61 }
62
63
64
65
66
67
68
69
70
71
72
73 func LogLevelFromString(s string) (LogLevel, error) {
74 switch s {
75 case "trace":
76 return LogLevelTrace, nil
77 case "debug":
78 return LogLevelDebug, nil
79 case "info":
80 return LogLevelInfo, nil
81 case "warn":
82 return LogLevelWarn, nil
83 case "error":
84 return LogLevelError, nil
85 case "none":
86 return LogLevelNone, nil
87 default:
88 return 0, errors.New("invalid log level")
89 }
90 }
91
// logQueryArgs returns a new slice holding a log-friendly form of each
// element of args; args itself is not modified.
//
//   - []byte values are hex-encoded. Slices of up to 64 bytes are encoded in
//     full; longer ones have only their first 64 bytes encoded, followed by
//     " (truncated N bytes)".
//   - string values longer than 64 bytes are cut at the first rune boundary at
//     or after byte 64 (so a multi-byte rune is never split), followed by
//     " (truncated N bytes)".
//   - All other values (including nil) are passed through unchanged.
func logQueryArgs(args []any) []any {
	const maxLen = 64

	logArgs := make([]any, 0, len(args))
	for _, a := range args {
		switch v := a.(type) {
		case []byte:
			// <= rather than <: an exactly-64-byte slice used to be reported as
			// "(truncated 0 bytes)".
			if len(v) <= maxLen {
				a = hex.EncodeToString(v)
			} else {
				a = fmt.Sprintf("%x (truncated %d bytes)", v[:maxLen], len(v)-maxLen)
			}
		case string:
			if len(v) > maxLen {
				// Walk rune by rune so the cut lands on a rune boundary. Invalid
				// UTF-8 decodes with width 1, so n never runs past len(v).
				n := 0
				for n < maxLen {
					_, w := utf8.DecodeRuneInString(v[n:])
					n += w
				}
				if n < len(v) {
					a = fmt.Sprintf("%s (truncated %d bytes)", v[:n], len(v)-n)
				}
			}
		}
		logArgs = append(logArgs, a)
	}
	return logArgs
}
119
120
121
// TraceLog turns pgx trace hooks into structured log entries sent to Logger.
//
// Its Trace* methods are hooks for queries, batches, CopyFrom, connection
// establishment and statement preparation (NOTE(review): confirm which pgx
// tracer interfaces it is registered as). Each *Start method stores data in the
// context it returns, and the matching *End method reads it back with an
// unchecked type assertion. An *End call must therefore receive the context
// returned by its *Start call, or it panics.
type TraceLog struct {
	// Logger receives every emitted entry. It must be non-nil whenever an
	// entry passes the LogLevel filter, since it is called unconditionally.
	Logger Logger
	// LogLevel is the most verbose level that is emitted: an entry at lvl is
	// logged only when lvl <= LogLevel. The zero value logs nothing.
	LogLevel LogLevel
}
126
// ctxKey is the type of the context keys TraceLog uses to hand data from a
// Trace*Start hook to the matching Trace*End hook. Because the type is
// unexported, these keys cannot compare equal to context keys from other
// packages.
type ctxKey int

// One context key per traced operation. Values start at 1; the zero ctxKey is
// never used as a key.
const (
	tracelogQueryCtxKey ctxKey = iota + 1
	tracelogBatchCtxKey
	tracelogCopyFromCtxKey
	tracelogConnectCtxKey
	tracelogPrepareCtxKey
)
137
// traceQueryData is the per-query state that TraceQueryStart stores in the
// context under tracelogQueryCtxKey and TraceQueryEnd reads back.
type traceQueryData struct {
	startTime time.Time // when TraceQueryStart ran; used for the "time" field
	sql       string    // query text, logged as "sql"
	args      []any     // query arguments, logged via logQueryArgs as "args"
}
143
144 func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
145 return context.WithValue(ctx, tracelogQueryCtxKey, &traceQueryData{
146 startTime: time.Now(),
147 sql: data.SQL,
148 args: data.Args,
149 })
150 }
151
152 func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
153 queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData)
154
155 endTime := time.Now()
156 interval := endTime.Sub(queryData.startTime)
157
158 if data.Err != nil {
159 if tl.shouldLog(LogLevelError) {
160 tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval})
161 }
162 return
163 }
164
165 if tl.shouldLog(LogLevelInfo) {
166 tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()})
167 }
168 }
169
// traceBatchData is the per-batch state that TraceBatchStart stores in the
// context under tracelogBatchCtxKey and TraceBatchEnd reads back.
type traceBatchData struct {
	startTime time.Time // when TraceBatchStart ran; used for the "time" field
}
173
174 func (tl *TraceLog) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
175 return context.WithValue(ctx, tracelogBatchCtxKey, &traceBatchData{
176 startTime: time.Now(),
177 })
178 }
179
180 func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
181 if data.Err != nil {
182 if tl.shouldLog(LogLevelError) {
183 tl.log(ctx, conn, LogLevelError, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "err": data.Err})
184 }
185 return
186 }
187
188 if tl.shouldLog(LogLevelInfo) {
189 tl.log(ctx, conn, LogLevelInfo, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "commandTag": data.CommandTag.String()})
190 }
191 }
192
193 func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
194 queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData)
195
196 endTime := time.Now()
197 interval := endTime.Sub(queryData.startTime)
198
199 if data.Err != nil {
200 if tl.shouldLog(LogLevelError) {
201 tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval})
202 }
203 return
204 }
205
206 if tl.shouldLog(LogLevelInfo) {
207 tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval})
208 }
209 }
210
// traceCopyFromData is the per-CopyFrom state that TraceCopyFromStart stores in
// the context under tracelogCopyFromCtxKey and TraceCopyFromEnd reads back.
// (Two fields have exported names even though the type itself is unexported.)
type traceCopyFromData struct {
	startTime   time.Time      // when TraceCopyFromStart ran; used for "time"
	TableName   pgx.Identifier // target table, logged as "tableName"
	ColumnNames []string       // target columns, logged as "columnNames"
}
216
217 func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
218 return context.WithValue(ctx, tracelogCopyFromCtxKey, &traceCopyFromData{
219 startTime: time.Now(),
220 TableName: data.TableName,
221 ColumnNames: data.ColumnNames,
222 })
223 }
224
225 func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
226 copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData)
227
228 endTime := time.Now()
229 interval := endTime.Sub(copyFromData.startTime)
230
231 if data.Err != nil {
232 if tl.shouldLog(LogLevelError) {
233 tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval})
234 }
235 return
236 }
237
238 if tl.shouldLog(LogLevelInfo) {
239 tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()})
240 }
241 }
242
// traceConnectData is the per-connection-attempt state that TraceConnectStart
// stores in the context under tracelogConnectCtxKey and TraceConnectEnd reads
// back.
type traceConnectData struct {
	startTime  time.Time       // when TraceConnectStart ran; used for "time"
	connConfig *pgx.ConnConfig // source of the logged host, port and database
}
247
248 func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
249 return context.WithValue(ctx, tracelogConnectCtxKey, &traceConnectData{
250 startTime: time.Now(),
251 connConfig: data.ConnConfig,
252 })
253 }
254
255 func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
256 connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData)
257
258 endTime := time.Now()
259 interval := endTime.Sub(connectData.startTime)
260
261 if data.Err != nil {
262 if tl.shouldLog(LogLevelError) {
263 tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{
264 "host": connectData.connConfig.Host,
265 "port": connectData.connConfig.Port,
266 "database": connectData.connConfig.Database,
267 "time": interval,
268 "err": data.Err,
269 })
270 }
271 return
272 }
273
274 if data.Conn != nil {
275 if tl.shouldLog(LogLevelInfo) {
276 tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{
277 "host": connectData.connConfig.Host,
278 "port": connectData.connConfig.Port,
279 "database": connectData.connConfig.Database,
280 "time": interval,
281 })
282 }
283 }
284 }
285
// tracePrepareData is the per-prepare state that TracePrepareStart stores in
// the context under tracelogPrepareCtxKey and TracePrepareEnd reads back.
type tracePrepareData struct {
	startTime time.Time // when TracePrepareStart ran; used for "time"
	name      string    // statement name, logged as "name"
	sql       string    // statement text, logged as "sql"
}
291
292 func (tl *TraceLog) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
293 return context.WithValue(ctx, tracelogPrepareCtxKey, &tracePrepareData{
294 startTime: time.Now(),
295 name: data.Name,
296 sql: data.SQL,
297 })
298 }
299
300 func (tl *TraceLog) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
301 prepareData := ctx.Value(tracelogPrepareCtxKey).(*tracePrepareData)
302
303 endTime := time.Now()
304 interval := endTime.Sub(prepareData.startTime)
305
306 if data.Err != nil {
307 if tl.shouldLog(LogLevelError) {
308 tl.log(ctx, conn, LogLevelError, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "err": data.Err, "time": interval})
309 }
310 return
311 }
312
313 if tl.shouldLog(LogLevelInfo) {
314 tl.log(ctx, conn, LogLevelInfo, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "time": interval, "alreadyPrepared": data.AlreadyPrepared})
315 }
316 }
317
318 func (tl *TraceLog) shouldLog(lvl LogLevel) bool {
319 return tl.LogLevel >= lvl
320 }
321
322 func (tl *TraceLog) log(ctx context.Context, conn *pgx.Conn, lvl LogLevel, msg string, data map[string]any) {
323 if data == nil {
324 data = map[string]any{}
325 }
326
327 pgConn := conn.PgConn()
328 if pgConn != nil {
329 pid := pgConn.PID()
330 if pid != 0 {
331 data["pid"] = pid
332 }
333 }
334
335 tl.Logger.Log(ctx, lvl, msg, data)
336 }
337
View as plain text