...

Source file src/github.com/doug-martin/goqu/v9/sqlgen/expression_sql_generator.go

Documentation: github.com/doug-martin/goqu/v9/sqlgen

     1  package sqlgen
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"reflect"
     6  	"strconv"
     7  	"time"
     8  	"unicode/utf8"
     9  
    10  	"github.com/doug-martin/goqu/v9/exp"
    11  	"github.com/doug-martin/goqu/v9/internal/errors"
    12  	"github.com/doug-martin/goqu/v9/internal/sb"
    13  	"github.com/doug-martin/goqu/v9/internal/util"
    14  )
    15  
type (
	// ExpressionSQLGenerator writes the SQL for a single value or expression into an sb.SQLBuilder
	// using the conventions of one SQL dialect.
	//
	// Dialect returns the dialect name. Generate appends the SQL for val to b; the default
	// implementation records failures on b (b.SetError) rather than returning them.
	ExpressionSQLGenerator interface {
		Dialect() string
		Generate(b sb.SQLBuilder, val interface{})
	}
	// expressionSQLGenerator is the default ExpressionSQLGenerator. Every dialect-specific piece of
	// output (quote runes, keyword fragments, operator lookups, feature flags) is read from
	// dialectOptions, so a dialect is normally supported by supplying different SQLDialectOptions
	// rather than a different generator.
	// See (github.com/doug-martin/goqu/dialect/postgres)
	expressionSQLGenerator struct {
		dialect        string             // dialect name; returned by Dialect and used in error messages
		dialectOptions *SQLDialectOptions // formatting options consulted for every generated fragment
	}
)
    31  
var (
	// replacementRune is the character in a literal expression that literalExpressionSQL replaces
	// with the next argument.
	replacementRune = '?'
	// TrueLiteral and FalseLiteral replace boolean right-hand sides of IS / IS NOT when the dialect
	// sets UseLiteralIsBools, so they are interpolated instead of bound as placeholders.
	TrueLiteral     = exp.NewLiteralExpression("TRUE")
	FalseLiteral    = exp.NewLiteralExpression("FALSE")

	// ErrEmptyIdentifier is set on the builder when an identifier has no schema, table or column.
	ErrEmptyIdentifier = errors.New(
		`a empty identifier was encountered, please specify a "schema", "table" or "column"`,
	)
	// ErrUnexpectedNamedWindow is set when a window function's inline window definition has a name.
	ErrUnexpectedNamedWindow = errors.New(`unexpected named window function`)
	// ErrEmptyCaseWhens is set when a CASE expression has no WHEN clauses.
	ErrEmptyCaseWhens        = errors.New(`when conditions not found for case statement`)
)
    43  
    44  func errUnsupportedExpressionType(e exp.Expression) error {
    45  	return errors.New("unsupported expression type %T", e)
    46  }
    47  
    48  func errUnsupportedIdentifierExpression(t interface{}) error {
    49  	return errors.New("unexpected col type must be string or LiteralExpression received %T", t)
    50  }
    51  
    52  func errUnsupportedBooleanExpressionOperator(op exp.BooleanOperation) error {
    53  	return errors.New("boolean operator '%+v' not supported", op)
    54  }
    55  
    56  func errUnsupportedBitwiseExpressionOperator(op exp.BitwiseOperation) error {
    57  	return errors.New("bitwise operator '%+v' not supported", op)
    58  }
    59  
    60  func errUnsupportedRangeExpressionOperator(op exp.RangeOperation) error {
    61  	return errors.New("range operator %+v not supported", op)
    62  }
    63  
    64  func errLateralNotSupported(dialect string) error {
    65  	return errors.New("dialect does not support lateral expressions [dialect=%s]", dialect)
    66  }
    67  
    68  func NewExpressionSQLGenerator(dialect string, do *SQLDialectOptions) ExpressionSQLGenerator {
    69  	return &expressionSQLGenerator{dialect: dialect, dialectOptions: do}
    70  }
    71  
// Dialect returns the dialect name supplied to NewExpressionSQLGenerator.
func (esg *expressionSQLGenerator) Dialect() string {
	return esg.dialect
}
    75  
// valuerReflectType is the reflect.Type of the driver.Valuer interface. Generate uses it to
// detect nil pointers whose element type implements driver.Valuer.
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
    77  
// Generate appends the SQL for val to b. It does nothing if b already carries an error.
//
// Dispatch works as follows:
//   - nil becomes NULL (or a placeholder when b is prepared).
//   - exp.Expression values go to expressionSQL.
//   - int, int32, int64, float32, float64, string, bool, time.Time and *time.Time go to the
//     matching literal writer.
//   - driver.Valuer values are resolved with Value() and the result is generated recursively.
//   - Anything else is handled by reflectSQL.
//
// The type switch takes the first matching case. A value implementing both exp.Expression and
// driver.Valuer is therefore treated as an expression.
func (esg *expressionSQLGenerator) Generate(b sb.SQLBuilder, val interface{}) {
	if b.Error() != nil {
		return
	}
	if val == nil {
		esg.literalNil(b)
		return
	}
	switch v := val.(type) {
	case exp.Expression:
		esg.expressionSQL(b, v)
	case int:
		esg.literalInt(b, int64(v))
	case int32:
		esg.literalInt(b, int64(v))
	case int64:
		esg.literalInt(b, v)
	case float32:
		esg.literalFloat(b, float64(v))
	case float64:
		esg.literalFloat(b, v)
	case string:
		esg.literalString(b, v)
	case bool:
		esg.literalBool(b, v)
	case time.Time:
		esg.literalTime(b, v)
	case *time.Time:
		if v == nil {
			esg.literalNil(b)
			return
		}
		esg.literalTime(b, *v)
	case driver.Valuer:
		// A nil pointer whose element type implements Valuer is emitted as NULL without calling
		// Value() (which could dereference the nil pointer), following the linked Go change.
		// See https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870
		if rv := reflect.ValueOf(val); rv.Kind() == reflect.Ptr &&
			rv.IsNil() &&
			rv.Type().Elem().Implements(valuerReflectType) {
			esg.literalNil(b)
			return
		}
		dVal, err := v.Value()
		if err != nil {
			b.SetError(err)
			return
		}
		// NOTE(review): if Value() returns another Valuer this recurses again; a Valuer that
		// returns itself would never terminate.
		esg.Generate(b, dVal)
	default:
		esg.reflectSQL(b, val)
	}
}
   129  
   130  func (esg *expressionSQLGenerator) reflectSQL(b sb.SQLBuilder, val interface{}) {
   131  	v := reflect.Indirect(reflect.ValueOf(val))
   132  	valKind := v.Kind()
   133  	switch {
   134  	case util.IsInvalid(valKind):
   135  		esg.literalNil(b)
   136  	case util.IsSlice(valKind):
   137  		switch t := val.(type) {
   138  		case []byte:
   139  			esg.literalBytes(b, t)
   140  		case []exp.CommonTableExpression:
   141  			esg.commonTablesSliceSQL(b, t)
   142  		default:
   143  			esg.sliceValueSQL(b, v)
   144  		}
   145  	case util.IsInt(valKind):
   146  		esg.Generate(b, v.Int())
   147  	case util.IsUint(valKind):
   148  		esg.Generate(b, int64(v.Uint()))
   149  	case util.IsFloat(valKind):
   150  		esg.Generate(b, v.Float())
   151  	case util.IsString(valKind):
   152  		esg.Generate(b, v.String())
   153  	case util.IsBool(valKind):
   154  		esg.Generate(b, v.Bool())
   155  	default:
   156  		b.SetError(errors.NewEncodeError(val))
   157  	}
   158  }
   159  
// expressionSQL dispatches an exp.Expression to the generator for its concrete kind. The type
// switch takes the first matching case, so case order matters for values that satisfy more than
// one of these types (for example, ColumnListExpression is tested before ExpressionList).
// Unrecognised kinds set an "unsupported expression type" error on b.
// nolint:gocyclo // not complex just long
func (esg *expressionSQLGenerator) expressionSQL(b sb.SQLBuilder, expression exp.Expression) {
	switch e := expression.(type) {
	case exp.ColumnListExpression:
		esg.columnListSQL(b, e)
	case exp.ExpressionList:
		esg.expressionListSQL(b, e)
	case exp.LiteralExpression:
		esg.literalExpressionSQL(b, e)
	case exp.IdentifierExpression:
		esg.identifierExpressionSQL(b, e)
	case exp.LateralExpression:
		esg.lateralExpressionSQL(b, e)
	case exp.AliasedExpression:
		esg.aliasedExpressionSQL(b, e)
	case exp.BooleanExpression:
		esg.booleanExpressionSQL(b, e)
	case exp.BitwiseExpression:
		esg.bitwiseExpressionSQL(b, e)
	case exp.RangeExpression:
		esg.rangeExpressionSQL(b, e)
	case exp.OrderedExpression:
		esg.orderedExpressionSQL(b, e)
	case exp.UpdateExpression:
		esg.updateExpressionSQL(b, e)
	case exp.SQLFunctionExpression:
		esg.sqlFunctionExpressionSQL(b, e)
	case exp.SQLWindowFunctionExpression:
		esg.sqlWindowFunctionExpression(b, e)
	case exp.WindowExpression:
		esg.windowExpressionSQL(b, e)
	case exp.CastExpression:
		esg.castExpressionSQL(b, e)
	case exp.AppendableExpression:
		// sub-queries (e.g. datasets) render themselves via AppendSQL
		esg.appendableExpressionSQL(b, e)
	case exp.CommonTableExpression:
		esg.commonTableExpressionSQL(b, e)
	case exp.CompoundExpression:
		esg.compoundExpressionSQL(b, e)
	case exp.CaseExpression:
		esg.caseExpressionSQL(b, e)
	case exp.Ex:
		esg.expressionMapSQL(b, e)
	case exp.ExOr:
		esg.expressionOrMapSQL(b, e)
	default:
		b.SetError(errUnsupportedExpressionType(e))
	}
}
   209  
   210  // Generates a placeholder (e.g. ?, $1)
   211  func (esg *expressionSQLGenerator) placeHolderSQL(b sb.SQLBuilder, i interface{}) {
   212  	b.Write(esg.dialectOptions.PlaceHolderFragment)
   213  	if esg.dialectOptions.IncludePlaceholderNum {
   214  		b.WriteStrings(strconv.FormatInt(int64(b.CurrentArgPosition()), 10))
   215  	}
   216  	b.WriteArg(i)
   217  }
   218  
   219  // Generates creates the sql for a sub select on a Dataset
   220  func (esg *expressionSQLGenerator) appendableExpressionSQL(b sb.SQLBuilder, a exp.AppendableExpression) {
   221  	b.WriteRunes(esg.dialectOptions.LeftParenRune)
   222  	a.AppendSQL(b)
   223  	b.WriteRunes(esg.dialectOptions.RightParenRune)
   224  	if a.GetAs() != nil {
   225  		b.Write(esg.dialectOptions.AsFragment)
   226  		esg.Generate(b, a.GetAs())
   227  	}
   228  }
   229  
   230  // Quotes an identifier (e.g. "col", "table"."col"
   231  func (esg *expressionSQLGenerator) identifierExpressionSQL(b sb.SQLBuilder, ident exp.IdentifierExpression) {
   232  	if ident.IsEmpty() {
   233  		b.SetError(ErrEmptyIdentifier)
   234  		return
   235  	}
   236  	schema, table, col := ident.GetSchema(), ident.GetTable(), ident.GetCol()
   237  	if schema != esg.dialectOptions.EmptyString {
   238  		b.WriteRunes(esg.dialectOptions.QuoteRune).
   239  			WriteStrings(schema).
   240  			WriteRunes(esg.dialectOptions.QuoteRune)
   241  	}
   242  	if table != esg.dialectOptions.EmptyString {
   243  		if schema != esg.dialectOptions.EmptyString {
   244  			b.WriteRunes(esg.dialectOptions.PeriodRune)
   245  		}
   246  		b.WriteRunes(esg.dialectOptions.QuoteRune).
   247  			WriteStrings(table).
   248  			WriteRunes(esg.dialectOptions.QuoteRune)
   249  	}
   250  	switch t := col.(type) {
   251  	case nil:
   252  	case string:
   253  		if col != esg.dialectOptions.EmptyString {
   254  			if table != esg.dialectOptions.EmptyString || schema != esg.dialectOptions.EmptyString {
   255  				b.WriteRunes(esg.dialectOptions.PeriodRune)
   256  			}
   257  			b.WriteRunes(esg.dialectOptions.QuoteRune).
   258  				WriteStrings(t).
   259  				WriteRunes(esg.dialectOptions.QuoteRune)
   260  		}
   261  	case exp.LiteralExpression:
   262  		if table != esg.dialectOptions.EmptyString || schema != esg.dialectOptions.EmptyString {
   263  			b.WriteRunes(esg.dialectOptions.PeriodRune)
   264  		}
   265  		esg.Generate(b, t)
   266  	default:
   267  		b.SetError(errUnsupportedIdentifierExpression(col))
   268  	}
   269  }
   270  
   271  func (esg *expressionSQLGenerator) lateralExpressionSQL(b sb.SQLBuilder, le exp.LateralExpression) {
   272  	if !esg.dialectOptions.SupportsLateral {
   273  		b.SetError(errLateralNotSupported(esg.dialect))
   274  		return
   275  	}
   276  	b.Write(esg.dialectOptions.LateralFragment)
   277  	esg.Generate(b, le.Table())
   278  }
   279  
   280  // Generates SQL NULL value
   281  func (esg *expressionSQLGenerator) literalNil(b sb.SQLBuilder) {
   282  	if b.IsPrepared() {
   283  		esg.placeHolderSQL(b, nil)
   284  		return
   285  	}
   286  	b.Write(esg.dialectOptions.Null)
   287  }
   288  
   289  // Generates SQL bool literal, (e.g. TRUE, FALSE, mysql 1, 0, sqlite3 1, 0)
   290  func (esg *expressionSQLGenerator) literalBool(b sb.SQLBuilder, bl bool) {
   291  	if b.IsPrepared() {
   292  		esg.placeHolderSQL(b, bl)
   293  		return
   294  	}
   295  	if bl {
   296  		b.Write(esg.dialectOptions.True)
   297  	} else {
   298  		b.Write(esg.dialectOptions.False)
   299  	}
   300  }
   301  
   302  // Generates SQL for a time.Time value
   303  func (esg *expressionSQLGenerator) literalTime(b sb.SQLBuilder, t time.Time) {
   304  	if b.IsPrepared() {
   305  		esg.placeHolderSQL(b, t)
   306  		return
   307  	}
   308  	esg.Generate(b, t.In(timeLocation).Format(esg.dialectOptions.TimeFormat))
   309  }
   310  
   311  // Generates SQL for a Float Value
   312  func (esg *expressionSQLGenerator) literalFloat(b sb.SQLBuilder, f float64) {
   313  	if b.IsPrepared() {
   314  		esg.placeHolderSQL(b, f)
   315  		return
   316  	}
   317  	b.WriteStrings(strconv.FormatFloat(f, 'f', -1, 64))
   318  }
   319  
   320  // Generates SQL for an int value
   321  func (esg *expressionSQLGenerator) literalInt(b sb.SQLBuilder, i int64) {
   322  	if b.IsPrepared() {
   323  		esg.placeHolderSQL(b, i)
   324  		return
   325  	}
   326  	b.WriteStrings(strconv.FormatInt(i, 10))
   327  }
   328  
   329  // Generates SQL for a string
   330  func (esg *expressionSQLGenerator) literalString(b sb.SQLBuilder, s string) {
   331  	if b.IsPrepared() {
   332  		esg.placeHolderSQL(b, s)
   333  		return
   334  	}
   335  	b.WriteRunes(esg.dialectOptions.StringQuote)
   336  	for _, char := range s {
   337  		if e, ok := esg.dialectOptions.EscapedRunes[char]; ok {
   338  			b.Write(e)
   339  		} else {
   340  			b.WriteRunes(char)
   341  		}
   342  	}
   343  
   344  	b.WriteRunes(esg.dialectOptions.StringQuote)
   345  }
   346  
   347  // Generates SQL for a slice of bytes
   348  func (esg *expressionSQLGenerator) literalBytes(b sb.SQLBuilder, bs []byte) {
   349  	if b.IsPrepared() {
   350  		esg.placeHolderSQL(b, bs)
   351  		return
   352  	}
   353  	b.WriteRunes(esg.dialectOptions.StringQuote)
   354  	i := 0
   355  	for len(bs) > 0 {
   356  		char, l := utf8.DecodeRune(bs)
   357  		if e, ok := esg.dialectOptions.EscapedRunes[char]; ok {
   358  			b.Write(e)
   359  		} else {
   360  			b.WriteRunes(char)
   361  		}
   362  		i++
   363  		bs = bs[l:]
   364  	}
   365  	b.WriteRunes(esg.dialectOptions.StringQuote)
   366  }
   367  
   368  // Generates SQL for a slice of values (e.g. []int64{1,2,3,4} -> (1,2,3,4)
   369  func (esg *expressionSQLGenerator) sliceValueSQL(b sb.SQLBuilder, slice reflect.Value) {
   370  	b.WriteRunes(esg.dialectOptions.LeftParenRune)
   371  	for i, l := 0, slice.Len(); i < l; i++ {
   372  		esg.Generate(b, slice.Index(i).Interface())
   373  		if i < l-1 {
   374  			b.WriteRunes(esg.dialectOptions.CommaRune, esg.dialectOptions.SpaceRune)
   375  		}
   376  	}
   377  	b.WriteRunes(esg.dialectOptions.RightParenRune)
   378  }
   379  
   380  // Generates SQL for an AliasedExpression (e.g. I("a").As("b") -> "a" AS "b")
   381  func (esg *expressionSQLGenerator) aliasedExpressionSQL(b sb.SQLBuilder, aliased exp.AliasedExpression) {
   382  	esg.Generate(b, aliased.Aliased())
   383  	b.Write(esg.dialectOptions.AsFragment)
   384  	esg.Generate(b, aliased.GetAs())
   385  }
   386  
   387  // Generates SQL for a BooleanExpresion (e.g. I("a").Eq(2) -> "a" = 2)
   388  func (esg *expressionSQLGenerator) booleanExpressionSQL(b sb.SQLBuilder, operator exp.BooleanExpression) {
   389  	b.WriteRunes(esg.dialectOptions.LeftParenRune)
   390  	esg.Generate(b, operator.LHS())
   391  	b.WriteRunes(esg.dialectOptions.SpaceRune)
   392  	operatorOp := operator.Op()
   393  	if val, ok := esg.dialectOptions.BooleanOperatorLookup[operatorOp]; ok {
   394  		b.Write(val)
   395  	} else {
   396  		b.SetError(errUnsupportedBooleanExpressionOperator(operatorOp))
   397  		return
   398  	}
   399  	rhs := operator.RHS()
   400  
   401  	if (operatorOp == exp.IsOp || operatorOp == exp.IsNotOp) && rhs != nil && !esg.dialectOptions.BooleanDataTypeSupported {
   402  		b.SetError(errors.New("boolean data type is not supported by dialect %q", esg.dialect))
   403  		return
   404  	}
   405  
   406  	if (operatorOp == exp.IsOp || operatorOp == exp.IsNotOp) && esg.dialectOptions.UseLiteralIsBools {
   407  		// these values must be interpolated because preparing them generates invalid SQL
   408  		switch rhs {
   409  		case true:
   410  			rhs = TrueLiteral
   411  		case false:
   412  			rhs = FalseLiteral
   413  		case nil:
   414  			rhs = exp.NewLiteralExpression(string(esg.dialectOptions.Null))
   415  		}
   416  	}
   417  	b.WriteRunes(esg.dialectOptions.SpaceRune)
   418  
   419  	if (operatorOp == exp.IsOp || operatorOp == exp.IsNotOp) && rhs == nil && !esg.dialectOptions.BooleanDataTypeSupported {
   420  		// e.g. for SQL server dialect which does not support "IS @p1" for "IS NULL"
   421  		b.Write(esg.dialectOptions.Null)
   422  	} else {
   423  		esg.Generate(b, rhs)
   424  	}
   425  
   426  	b.WriteRunes(esg.dialectOptions.RightParenRune)
   427  }
   428  
   429  // Generates SQL for a BitwiseExpresion (e.g. I("a").BitwiseOr(2) - > "a" | 2)
   430  func (esg *expressionSQLGenerator) bitwiseExpressionSQL(b sb.SQLBuilder, operator exp.BitwiseExpression) {
   431  	b.WriteRunes(esg.dialectOptions.LeftParenRune)
   432  
   433  	if operator.LHS() != nil {
   434  		esg.Generate(b, operator.LHS())
   435  		b.WriteRunes(esg.dialectOptions.SpaceRune)
   436  	}
   437  
   438  	operatorOp := operator.Op()
   439  	if val, ok := esg.dialectOptions.BitwiseOperatorLookup[operatorOp]; ok {
   440  		b.Write(val)
   441  	} else {
   442  		b.SetError(errUnsupportedBitwiseExpressionOperator(operatorOp))
   443  		return
   444  	}
   445  
   446  	b.WriteRunes(esg.dialectOptions.SpaceRune)
   447  	esg.Generate(b, operator.RHS())
   448  	b.WriteRunes(esg.dialectOptions.RightParenRune)
   449  }
   450  
   451  // Generates SQL for a RangeExpresion (e.g. I("a").Between(RangeVal{Start:2,End:5}) -> "a" BETWEEN 2 AND 5)
   452  func (esg *expressionSQLGenerator) rangeExpressionSQL(b sb.SQLBuilder, operator exp.RangeExpression) {
   453  	b.WriteRunes(esg.dialectOptions.LeftParenRune)
   454  	esg.Generate(b, operator.LHS())
   455  	b.WriteRunes(esg.dialectOptions.SpaceRune)
   456  	operatorOp := operator.Op()
   457  	if val, ok := esg.dialectOptions.RangeOperatorLookup[operatorOp]; ok {
   458  		b.Write(val)
   459  	} else {
   460  		b.SetError(errUnsupportedRangeExpressionOperator(operatorOp))
   461  		return
   462  	}
   463  	rhs := operator.RHS()
   464  	b.WriteRunes(esg.dialectOptions.SpaceRune)
   465  	esg.Generate(b, rhs.Start())
   466  	b.Write(esg.dialectOptions.AndFragment)
   467  	esg.Generate(b, rhs.End())
   468  	b.WriteRunes(esg.dialectOptions.RightParenRune)
   469  }
   470  
   471  // Generates SQL for an OrderedExpression (e.g. I("a").Asc() -> "a" ASC)
   472  func (esg *expressionSQLGenerator) orderedExpressionSQL(b sb.SQLBuilder, order exp.OrderedExpression) {
   473  	esg.Generate(b, order.SortExpression())
   474  	if order.IsAsc() {
   475  		b.Write(esg.dialectOptions.AscFragment)
   476  	} else {
   477  		b.Write(esg.dialectOptions.DescFragment)
   478  	}
   479  	switch order.NullSortType() {
   480  	case exp.NoNullsSortType:
   481  		return
   482  	case exp.NullsFirstSortType:
   483  		b.Write(esg.dialectOptions.NullsFirstFragment)
   484  	case exp.NullsLastSortType:
   485  		b.Write(esg.dialectOptions.NullsLastFragment)
   486  	}
   487  }
   488  
   489  // Generates SQL for an ExpressionList (e.g. And(I("a").Eq("a"), I("b").Eq("b")) -> (("a" = 'a') AND ("b" = 'b')))
   490  func (esg *expressionSQLGenerator) expressionListSQL(b sb.SQLBuilder, expressionList exp.ExpressionList) {
   491  	if expressionList.IsEmpty() {
   492  		return
   493  	}
   494  	var op []byte
   495  	if expressionList.Type() == exp.AndType {
   496  		op = esg.dialectOptions.AndFragment
   497  	} else {
   498  		op = esg.dialectOptions.OrFragment
   499  	}
   500  	exps := expressionList.Expressions()
   501  	expLen := len(exps) - 1
   502  	if expLen > 0 {
   503  		b.WriteRunes(esg.dialectOptions.LeftParenRune)
   504  	} else {
   505  		esg.Generate(b, exps[0])
   506  		return
   507  	}
   508  	for i, e := range exps {
   509  		esg.Generate(b, e)
   510  		if i < expLen {
   511  			b.Write(op)
   512  		}
   513  	}
   514  	b.WriteRunes(esg.dialectOptions.RightParenRune)
   515  }
   516  
   517  // Generates SQL for a ColumnListExpression
   518  func (esg *expressionSQLGenerator) columnListSQL(b sb.SQLBuilder, columnList exp.ColumnListExpression) {
   519  	cols := columnList.Columns()
   520  	colLen := len(cols)
   521  	for i, col := range cols {
   522  		esg.Generate(b, col)
   523  		if i < colLen-1 {
   524  			b.WriteRunes(esg.dialectOptions.CommaRune, esg.dialectOptions.SpaceRune)
   525  		}
   526  	}
   527  }
   528  
   529  // Generates SQL for an UpdateEpxresion
   530  func (esg *expressionSQLGenerator) updateExpressionSQL(b sb.SQLBuilder, update exp.UpdateExpression) {
   531  	esg.Generate(b, update.Col())
   532  	b.WriteRunes(esg.dialectOptions.SetOperatorRune)
   533  	esg.Generate(b, update.Val())
   534  }
   535  
   536  // Generates SQL for a LiteralExpression
   537  //    L("a + b") -> a + b
   538  //    L("a = ?", 1) -> a = 1
   539  func (esg *expressionSQLGenerator) literalExpressionSQL(b sb.SQLBuilder, literal exp.LiteralExpression) {
   540  	l := literal.Literal()
   541  	args := literal.Args()
   542  	if argsLen := len(args); argsLen > 0 {
   543  		currIndex := 0
   544  		for _, char := range l {
   545  			if char == replacementRune && currIndex < argsLen {
   546  				esg.Generate(b, args[currIndex])
   547  				currIndex++
   548  			} else {
   549  				b.WriteRunes(char)
   550  			}
   551  		}
   552  		return
   553  	}
   554  	b.WriteStrings(l)
   555  }
   556  
   557  // Generates SQL for a SQLFunctionExpression
   558  //   COUNT(I("a")) -> COUNT("a")
   559  func (esg *expressionSQLGenerator) sqlFunctionExpressionSQL(b sb.SQLBuilder, sqlFunc exp.SQLFunctionExpression) {
   560  	b.WriteStrings(sqlFunc.Name())
   561  	esg.Generate(b, sqlFunc.Args())
   562  }
   563  
   564  func (esg *expressionSQLGenerator) sqlWindowFunctionExpression(b sb.SQLBuilder, sqlWinFunc exp.SQLWindowFunctionExpression) {
   565  	if !esg.dialectOptions.SupportsWindowFunction {
   566  		b.SetError(ErrWindowNotSupported(esg.dialect))
   567  		return
   568  	}
   569  	esg.Generate(b, sqlWinFunc.Func())
   570  	b.Write(esg.dialectOptions.WindowOverFragment)
   571  	switch {
   572  	case sqlWinFunc.HasWindowName():
   573  		esg.Generate(b, sqlWinFunc.WindowName())
   574  	case sqlWinFunc.HasWindow():
   575  		if sqlWinFunc.Window().HasName() {
   576  			b.SetError(ErrUnexpectedNamedWindow)
   577  			return
   578  		}
   579  		esg.Generate(b, sqlWinFunc.Window())
   580  	default:
   581  		esg.Generate(b, exp.NewWindowExpression(nil, nil, nil, nil))
   582  	}
   583  }
   584  
   585  func (esg *expressionSQLGenerator) windowExpressionSQL(b sb.SQLBuilder, we exp.WindowExpression) {
   586  	if !esg.dialectOptions.SupportsWindowFunction {
   587  		b.SetError(ErrWindowNotSupported(esg.dialect))
   588  		return
   589  	}
   590  	if we.HasName() {
   591  		esg.Generate(b, we.Name())
   592  		b.Write(esg.dialectOptions.AsFragment)
   593  	}
   594  	b.WriteRunes(esg.dialectOptions.LeftParenRune)
   595  
   596  	hasPartition := we.HasPartitionBy()
   597  	hasOrder := we.HasOrder()
   598  
   599  	if we.HasParent() {
   600  		esg.Generate(b, we.Parent())
   601  		if hasPartition || hasOrder {
   602  			b.WriteRunes(esg.dialectOptions.SpaceRune)
   603  		}
   604  	}
   605  
   606  	if hasPartition {
   607  		b.Write(esg.dialectOptions.WindowPartitionByFragment)
   608  		esg.Generate(b, we.PartitionCols())
   609  		if hasOrder {
   610  			b.WriteRunes(esg.dialectOptions.SpaceRune)
   611  		}
   612  	}
   613  	if hasOrder {
   614  		b.Write(esg.dialectOptions.WindowOrderByFragment)
   615  		esg.Generate(b, we.OrderCols())
   616  	}
   617  
   618  	b.WriteRunes(esg.dialectOptions.RightParenRune)
   619  }
   620  
   621  // Generates SQL for a CastExpression
   622  //   I("a").Cast("NUMERIC") -> CAST("a" AS NUMERIC)
   623  func (esg *expressionSQLGenerator) castExpressionSQL(b sb.SQLBuilder, cast exp.CastExpression) {
   624  	b.Write(esg.dialectOptions.CastFragment).WriteRunes(esg.dialectOptions.LeftParenRune)
   625  	esg.Generate(b, cast.Casted())
   626  	b.Write(esg.dialectOptions.AsFragment)
   627  	esg.Generate(b, cast.Type())
   628  	b.WriteRunes(esg.dialectOptions.RightParenRune)
   629  }
   630  
   631  // Generates the sql for the WITH clauses for common table expressions (CTE)
   632  func (esg *expressionSQLGenerator) commonTablesSliceSQL(b sb.SQLBuilder, ctes []exp.CommonTableExpression) {
   633  	l := len(ctes)
   634  	if l == 0 {
   635  		return
   636  	}
   637  	if !esg.dialectOptions.SupportsWithCTE {
   638  		b.SetError(ErrCTENotSupported(esg.dialect))
   639  		return
   640  	}
   641  	b.Write(esg.dialectOptions.WithFragment)
   642  	anyRecursive := false
   643  	for _, cte := range ctes {
   644  		anyRecursive = anyRecursive || cte.IsRecursive()
   645  	}
   646  	if anyRecursive {
   647  		if !esg.dialectOptions.SupportsWithCTERecursive {
   648  			b.SetError(ErrRecursiveCTENotSupported(esg.dialect))
   649  			return
   650  		}
   651  		b.Write(esg.dialectOptions.RecursiveFragment)
   652  	}
   653  	for i, cte := range ctes {
   654  		esg.Generate(b, cte)
   655  		if i < l-1 {
   656  			b.WriteRunes(esg.dialectOptions.CommaRune, esg.dialectOptions.SpaceRune)
   657  		}
   658  	}
   659  	b.WriteRunes(esg.dialectOptions.SpaceRune)
   660  }
   661  
   662  // Generates SQL for a CommonTableExpression
   663  func (esg *expressionSQLGenerator) commonTableExpressionSQL(b sb.SQLBuilder, cte exp.CommonTableExpression) {
   664  	esg.Generate(b, cte.Name())
   665  	b.Write(esg.dialectOptions.AsFragment)
   666  	esg.Generate(b, cte.SubQuery())
   667  }
   668  
   669  // Generates SQL for a CompoundExpression
   670  func (esg *expressionSQLGenerator) compoundExpressionSQL(b sb.SQLBuilder, compound exp.CompoundExpression) {
   671  	switch compound.Type() {
   672  	case exp.UnionCompoundType:
   673  		b.Write(esg.dialectOptions.UnionFragment)
   674  	case exp.UnionAllCompoundType:
   675  		b.Write(esg.dialectOptions.UnionAllFragment)
   676  	case exp.IntersectCompoundType:
   677  		b.Write(esg.dialectOptions.IntersectFragment)
   678  	case exp.IntersectAllCompoundType:
   679  		b.Write(esg.dialectOptions.IntersectAllFragment)
   680  	}
   681  	if esg.dialectOptions.WrapCompoundsInParens {
   682  		b.WriteRunes(esg.dialectOptions.LeftParenRune)
   683  		compound.RHS().AppendSQL(b)
   684  		b.WriteRunes(esg.dialectOptions.RightParenRune)
   685  	} else {
   686  		compound.RHS().AppendSQL(b)
   687  	}
   688  }
   689  
   690  // Generates SQL for a CaseExpression
   691  func (esg *expressionSQLGenerator) caseExpressionSQL(b sb.SQLBuilder, caseExpression exp.CaseExpression) {
   692  	caseVal := caseExpression.GetValue()
   693  	whens := caseExpression.GetWhens()
   694  	elseResult := caseExpression.GetElse()
   695  
   696  	if len(whens) == 0 {
   697  		b.SetError(ErrEmptyCaseWhens)
   698  		return
   699  	}
   700  	b.Write(esg.dialectOptions.CaseFragment)
   701  	if caseVal != nil {
   702  		esg.Generate(b, caseVal)
   703  	}
   704  	for _, when := range whens {
   705  		b.Write(esg.dialectOptions.WhenFragment)
   706  		esg.Generate(b, when.Condition())
   707  		b.Write(esg.dialectOptions.ThenFragment)
   708  		esg.Generate(b, when.Result())
   709  	}
   710  	if elseResult != nil {
   711  		b.Write(esg.dialectOptions.ElseFragment)
   712  		esg.Generate(b, elseResult.Result())
   713  	}
   714  	b.Write(esg.dialectOptions.EndFragment)
   715  }
   716  
   717  func (esg *expressionSQLGenerator) expressionMapSQL(b sb.SQLBuilder, ex exp.Ex) {
   718  	expressionList, err := ex.ToExpressions()
   719  	if err != nil {
   720  		b.SetError(err)
   721  		return
   722  	}
   723  	esg.Generate(b, expressionList)
   724  }
   725  
   726  func (esg *expressionSQLGenerator) expressionOrMapSQL(b sb.SQLBuilder, ex exp.ExOr) {
   727  	expressionList, err := ex.ToExpressions()
   728  	if err != nil {
   729  		b.SetError(err)
   730  		return
   731  	}
   732  	esg.Generate(b, expressionList)
   733  }
   734  

View as plain text