...

Source file src/github.com/jackc/pgx/v5/named_args.go

Documentation: github.com/jackc/pgx/v5

     1  package pgx
     2  
     3  import (
     4  	"context"
     5  	"strconv"
     6  	"strings"
     7  	"unicode/utf8"
     8  )
     9  
    10  // NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
    11  // ordinal placeholder and construct the appropriate arguments.
    12  //
    13  // For example, the following two queries are equivalent:
    14  //
    15  //	conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
    16  //	conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
    17  //
    18  // Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
    19  // letters, numbers, or underscores.
    20  type NamedArgs map[string]any
    21  
    22  // RewriteQuery implements the QueryRewriter interface.
    23  func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
    24  	l := &sqlLexer{
    25  		src:           sql,
    26  		stateFn:       rawState,
    27  		nameToOrdinal: make(map[namedArg]int, len(na)),
    28  	}
    29  
    30  	for l.stateFn != nil {
    31  		l.stateFn = l.stateFn(l)
    32  	}
    33  
    34  	sb := strings.Builder{}
    35  	for _, p := range l.parts {
    36  		switch p := p.(type) {
    37  		case string:
    38  			sb.WriteString(p)
    39  		case namedArg:
    40  			sb.WriteRune('$')
    41  			sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
    42  		}
    43  	}
    44  
    45  	newArgs = make([]any, len(l.nameToOrdinal))
    46  	for name, ordinal := range l.nameToOrdinal {
    47  		newArgs[ordinal-1] = na[string(name)]
    48  	}
    49  
    50  	return sb.String(), newArgs, nil
    51  }
    52  
    53  type namedArg string
    54  
    55  type sqlLexer struct {
    56  	src     string
    57  	start   int
    58  	pos     int
    59  	nested  int // multiline comment nesting level.
    60  	stateFn stateFn
    61  	parts   []any
    62  
    63  	nameToOrdinal map[namedArg]int
    64  }
    65  
    66  type stateFn func(*sqlLexer) stateFn
    67  
    68  func rawState(l *sqlLexer) stateFn {
    69  	for {
    70  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    71  		l.pos += width
    72  
    73  		switch r {
    74  		case 'e', 'E':
    75  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    76  			if nextRune == '\'' {
    77  				l.pos += width
    78  				return escapeStringState
    79  			}
    80  		case '\'':
    81  			return singleQuoteState
    82  		case '"':
    83  			return doubleQuoteState
    84  		case '@':
    85  			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
    86  			if isLetter(nextRune) || nextRune == '_' {
    87  				if l.pos-l.start > 0 {
    88  					l.parts = append(l.parts, l.src[l.start:l.pos-width])
    89  				}
    90  				l.start = l.pos
    91  				return namedArgState
    92  			}
    93  		case '-':
    94  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    95  			if nextRune == '-' {
    96  				l.pos += width
    97  				return oneLineCommentState
    98  			}
    99  		case '/':
   100  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   101  			if nextRune == '*' {
   102  				l.pos += width
   103  				return multilineCommentState
   104  			}
   105  		case utf8.RuneError:
   106  			if l.pos-l.start > 0 {
   107  				l.parts = append(l.parts, l.src[l.start:l.pos])
   108  				l.start = l.pos
   109  			}
   110  			return nil
   111  		}
   112  	}
   113  }
   114  
   115  func isLetter(r rune) bool {
   116  	return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
   117  }
   118  
   119  func namedArgState(l *sqlLexer) stateFn {
   120  	for {
   121  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   122  		l.pos += width
   123  
   124  		if r == utf8.RuneError {
   125  			if l.pos-l.start > 0 {
   126  				na := namedArg(l.src[l.start:l.pos])
   127  				if _, found := l.nameToOrdinal[na]; !found {
   128  					l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
   129  				}
   130  				l.parts = append(l.parts, na)
   131  				l.start = l.pos
   132  			}
   133  			return nil
   134  		} else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') {
   135  			l.pos -= width
   136  			na := namedArg(l.src[l.start:l.pos])
   137  			if _, found := l.nameToOrdinal[na]; !found {
   138  				l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
   139  			}
   140  			l.parts = append(l.parts, namedArg(na))
   141  			l.start = l.pos
   142  			return rawState
   143  		}
   144  	}
   145  }
   146  
   147  func singleQuoteState(l *sqlLexer) stateFn {
   148  	for {
   149  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   150  		l.pos += width
   151  
   152  		switch r {
   153  		case '\'':
   154  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   155  			if nextRune != '\'' {
   156  				return rawState
   157  			}
   158  			l.pos += width
   159  		case utf8.RuneError:
   160  			if l.pos-l.start > 0 {
   161  				l.parts = append(l.parts, l.src[l.start:l.pos])
   162  				l.start = l.pos
   163  			}
   164  			return nil
   165  		}
   166  	}
   167  }
   168  
   169  func doubleQuoteState(l *sqlLexer) stateFn {
   170  	for {
   171  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   172  		l.pos += width
   173  
   174  		switch r {
   175  		case '"':
   176  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   177  			if nextRune != '"' {
   178  				return rawState
   179  			}
   180  			l.pos += width
   181  		case utf8.RuneError:
   182  			if l.pos-l.start > 0 {
   183  				l.parts = append(l.parts, l.src[l.start:l.pos])
   184  				l.start = l.pos
   185  			}
   186  			return nil
   187  		}
   188  	}
   189  }
   190  
   191  func escapeStringState(l *sqlLexer) stateFn {
   192  	for {
   193  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   194  		l.pos += width
   195  
   196  		switch r {
   197  		case '\\':
   198  			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
   199  			l.pos += width
   200  		case '\'':
   201  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   202  			if nextRune != '\'' {
   203  				return rawState
   204  			}
   205  			l.pos += width
   206  		case utf8.RuneError:
   207  			if l.pos-l.start > 0 {
   208  				l.parts = append(l.parts, l.src[l.start:l.pos])
   209  				l.start = l.pos
   210  			}
   211  			return nil
   212  		}
   213  	}
   214  }
   215  
   216  func oneLineCommentState(l *sqlLexer) stateFn {
   217  	for {
   218  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   219  		l.pos += width
   220  
   221  		switch r {
   222  		case '\\':
   223  			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
   224  			l.pos += width
   225  		case '\n', '\r':
   226  			return rawState
   227  		case utf8.RuneError:
   228  			if l.pos-l.start > 0 {
   229  				l.parts = append(l.parts, l.src[l.start:l.pos])
   230  				l.start = l.pos
   231  			}
   232  			return nil
   233  		}
   234  	}
   235  }
   236  
   237  func multilineCommentState(l *sqlLexer) stateFn {
   238  	for {
   239  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   240  		l.pos += width
   241  
   242  		switch r {
   243  		case '/':
   244  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   245  			if nextRune == '*' {
   246  				l.pos += width
   247  				l.nested++
   248  			}
   249  		case '*':
   250  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   251  			if nextRune != '/' {
   252  				continue
   253  			}
   254  
   255  			l.pos += width
   256  			if l.nested == 0 {
   257  				return rawState
   258  			}
   259  			l.nested--
   260  
   261  		case utf8.RuneError:
   262  			if l.pos-l.start > 0 {
   263  				l.parts = append(l.parts, l.src[l.start:l.pos])
   264  				l.start = l.pos
   265  			}
   266  			return nil
   267  		}
   268  	}
   269  }
   270  

View as plain text