...

Source file src/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go

Documentation: github.com/jackc/pgx/v5/internal/sanitize

     1  package sanitize
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  	"unicode/utf8"
    11  )
    12  
// Part is one element of a parsed Query. It is either a string or an int. A
// string is raw SQL that is emitted unchanged. An int is the 1-based number
// of an argument placeholder ($1, $2, ...).
type Part any
    16  
// Query is a SQL statement that has been split into raw SQL fragments and
// argument placeholders.
type Query struct {
	// Parts holds the fragments and placeholders that Sanitize writes out in
	// slice order.
	Parts []Part
}
    20  
// utf8.DecodeRuneInString returns utf8.RuneError for errors. But that is also rune U+FFFD -- the unicode replacement
// character. A decoded utf8.RuneError is not an error when its width is 3, because then the input really did contain
// an encoded U+FFFD.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
    26  
    27  func (q *Query) Sanitize(args ...any) (string, error) {
    28  	argUse := make([]bool, len(args))
    29  	buf := &bytes.Buffer{}
    30  
    31  	for _, part := range q.Parts {
    32  		var str string
    33  		switch part := part.(type) {
    34  		case string:
    35  			str = part
    36  		case int:
    37  			argIdx := part - 1
    38  
    39  			if argIdx < 0 {
    40  				return "", fmt.Errorf("first sql argument must be > 0")
    41  			}
    42  
    43  			if argIdx >= len(args) {
    44  				return "", fmt.Errorf("insufficient arguments")
    45  			}
    46  			arg := args[argIdx]
    47  			switch arg := arg.(type) {
    48  			case nil:
    49  				str = "null"
    50  			case int64:
    51  				str = strconv.FormatInt(arg, 10)
    52  			case float64:
    53  				str = strconv.FormatFloat(arg, 'f', -1, 64)
    54  			case bool:
    55  				str = strconv.FormatBool(arg)
    56  			case []byte:
    57  				str = QuoteBytes(arg)
    58  			case string:
    59  				str = QuoteString(arg)
    60  			case time.Time:
    61  				str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
    62  			default:
    63  				return "", fmt.Errorf("invalid arg type: %T", arg)
    64  			}
    65  			argUse[argIdx] = true
    66  
    67  			// Prevent SQL injection via Line Comment Creation
    68  			// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
    69  			str = " " + str + " "
    70  		default:
    71  			return "", fmt.Errorf("invalid Part type: %T", part)
    72  		}
    73  		buf.WriteString(str)
    74  	}
    75  
    76  	for i, used := range argUse {
    77  		if !used {
    78  			return "", fmt.Errorf("unused argument: %d", i)
    79  		}
    80  	}
    81  	return buf.String(), nil
    82  }
    83  
    84  func NewQuery(sql string) (*Query, error) {
    85  	l := &sqlLexer{
    86  		src:     sql,
    87  		stateFn: rawState,
    88  	}
    89  
    90  	for l.stateFn != nil {
    91  		l.stateFn = l.stateFn(l)
    92  	}
    93  
    94  	query := &Query{Parts: l.parts}
    95  
    96  	return query, nil
    97  }
    98  
    99  func QuoteString(str string) string {
   100  	return "'" + strings.ReplaceAll(str, "'", "''") + "'"
   101  }
   102  
   103  func QuoteBytes(buf []byte) string {
   104  	return `'\x` + hex.EncodeToString(buf) + "'"
   105  }
   106  
// sqlLexer holds the scanning state used to split a SQL string into Parts.
// NewQuery drives it by calling stateFn repeatedly until it becomes nil.
type sqlLexer struct {
	// src is the SQL text being scanned.
	src     string
	// start is the byte offset in src where the not-yet-emitted raw SQL begins.
	start   int
	// pos is the byte offset in src of the next rune to decode.
	pos     int
	nested  int // multiline comment nesting level.
	// stateFn is the state function NewQuery runs next.
	stateFn stateFn
	// parts accumulates the raw SQL strings and placeholder numbers found so far.
	parts   []Part
}
   115  
// stateFn is a single lexer state. It scans from the lexer's current position
// and returns the state to run next; NewQuery stops lexing when nil is returned.
type stateFn func(*sqlLexer) stateFn
   117  
   118  func rawState(l *sqlLexer) stateFn {
   119  	for {
   120  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   121  		l.pos += width
   122  
   123  		switch r {
   124  		case 'e', 'E':
   125  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   126  			if nextRune == '\'' {
   127  				l.pos += width
   128  				return escapeStringState
   129  			}
   130  		case '\'':
   131  			return singleQuoteState
   132  		case '"':
   133  			return doubleQuoteState
   134  		case '$':
   135  			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
   136  			if '0' <= nextRune && nextRune <= '9' {
   137  				if l.pos-l.start > 0 {
   138  					l.parts = append(l.parts, l.src[l.start:l.pos-width])
   139  				}
   140  				l.start = l.pos
   141  				return placeholderState
   142  			}
   143  		case '-':
   144  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   145  			if nextRune == '-' {
   146  				l.pos += width
   147  				return oneLineCommentState
   148  			}
   149  		case '/':
   150  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   151  			if nextRune == '*' {
   152  				l.pos += width
   153  				return multilineCommentState
   154  			}
   155  		case utf8.RuneError:
   156  			if width != replacementcharacterwidth {
   157  				if l.pos-l.start > 0 {
   158  					l.parts = append(l.parts, l.src[l.start:l.pos])
   159  					l.start = l.pos
   160  				}
   161  				return nil
   162  			}
   163  		}
   164  	}
   165  }
   166  
   167  func singleQuoteState(l *sqlLexer) stateFn {
   168  	for {
   169  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   170  		l.pos += width
   171  
   172  		switch r {
   173  		case '\'':
   174  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   175  			if nextRune != '\'' {
   176  				return rawState
   177  			}
   178  			l.pos += width
   179  		case utf8.RuneError:
   180  			if width != replacementcharacterwidth {
   181  				if l.pos-l.start > 0 {
   182  					l.parts = append(l.parts, l.src[l.start:l.pos])
   183  					l.start = l.pos
   184  				}
   185  				return nil
   186  			}
   187  		}
   188  	}
   189  }
   190  
   191  func doubleQuoteState(l *sqlLexer) stateFn {
   192  	for {
   193  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   194  		l.pos += width
   195  
   196  		switch r {
   197  		case '"':
   198  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   199  			if nextRune != '"' {
   200  				return rawState
   201  			}
   202  			l.pos += width
   203  		case utf8.RuneError:
   204  			if width != replacementcharacterwidth {
   205  				if l.pos-l.start > 0 {
   206  					l.parts = append(l.parts, l.src[l.start:l.pos])
   207  					l.start = l.pos
   208  				}
   209  				return nil
   210  			}
   211  		}
   212  	}
   213  }
   214  
   215  // placeholderState consumes a placeholder value. The $ must have already has
   216  // already been consumed. The first rune must be a digit.
   217  func placeholderState(l *sqlLexer) stateFn {
   218  	num := 0
   219  
   220  	for {
   221  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   222  		l.pos += width
   223  
   224  		if '0' <= r && r <= '9' {
   225  			num *= 10
   226  			num += int(r - '0')
   227  		} else {
   228  			l.parts = append(l.parts, num)
   229  			l.pos -= width
   230  			l.start = l.pos
   231  			return rawState
   232  		}
   233  	}
   234  }
   235  
   236  func escapeStringState(l *sqlLexer) stateFn {
   237  	for {
   238  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   239  		l.pos += width
   240  
   241  		switch r {
   242  		case '\\':
   243  			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
   244  			l.pos += width
   245  		case '\'':
   246  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   247  			if nextRune != '\'' {
   248  				return rawState
   249  			}
   250  			l.pos += width
   251  		case utf8.RuneError:
   252  			if width != replacementcharacterwidth {
   253  				if l.pos-l.start > 0 {
   254  					l.parts = append(l.parts, l.src[l.start:l.pos])
   255  					l.start = l.pos
   256  				}
   257  				return nil
   258  			}
   259  		}
   260  	}
   261  }
   262  
   263  func oneLineCommentState(l *sqlLexer) stateFn {
   264  	for {
   265  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   266  		l.pos += width
   267  
   268  		switch r {
   269  		case '\\':
   270  			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
   271  			l.pos += width
   272  		case '\n', '\r':
   273  			return rawState
   274  		case utf8.RuneError:
   275  			if width != replacementcharacterwidth {
   276  				if l.pos-l.start > 0 {
   277  					l.parts = append(l.parts, l.src[l.start:l.pos])
   278  					l.start = l.pos
   279  				}
   280  				return nil
   281  			}
   282  		}
   283  	}
   284  }
   285  
   286  func multilineCommentState(l *sqlLexer) stateFn {
   287  	for {
   288  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   289  		l.pos += width
   290  
   291  		switch r {
   292  		case '/':
   293  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   294  			if nextRune == '*' {
   295  				l.pos += width
   296  				l.nested++
   297  			}
   298  		case '*':
   299  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   300  			if nextRune != '/' {
   301  				continue
   302  			}
   303  
   304  			l.pos += width
   305  			if l.nested == 0 {
   306  				return rawState
   307  			}
   308  			l.nested--
   309  
   310  		case utf8.RuneError:
   311  			if width != replacementcharacterwidth {
   312  				if l.pos-l.start > 0 {
   313  					l.parts = append(l.parts, l.src[l.start:l.pos])
   314  					l.start = l.pos
   315  				}
   316  				return nil
   317  			}
   318  		}
   319  	}
   320  }
   321  
   322  // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
   323  // as necessary. This function is only safe when standard_conforming_strings is
   324  // on.
   325  func SanitizeSQL(sql string, args ...any) (string, error) {
   326  	query, err := NewQuery(sql)
   327  	if err != nil {
   328  		return "", err
   329  	}
   330  	return query.Sanitize(args...)
   331  }
   332  

View as plain text