1 package db
2
3 import (
4 "context"
5 "fmt"
6 "strings"
7 )
8
9
10
11
12
// MultiInserter accumulates rows destined for a single table and renders them
// as one multi-row `INSERT INTO table (fields...) VALUES (...),(...)`
// statement, optionally followed by a RETURNING clause. It can also execute
// that statement through a Queryer.
type MultiInserter struct {
	// table, fields and returningColumn are interpolated directly into the
	// SQL text rather than bound as parameters. NewMultiInserter checks each
	// of them with validMariaDBUnquotedIdentifier before storing them. An
	// empty returningColumn means no RETURNING clause is emitted.
	table           string
	fields          []string
	returningColumn string

	// values holds one entry per row queued via Add. Add only accepts rows
	// with exactly len(fields) elements; the elements are later passed to the
	// database as bind arguments.
	values [][]interface{}
}
23
24
25
26
27
28
29
30
31
32 func NewMultiInserter(table string, fields []string, returningColumn string) (*MultiInserter, error) {
33 if len(table) == 0 || len(fields) == 0 {
34 return nil, fmt.Errorf("empty table name or fields list")
35 }
36
37 err := validMariaDBUnquotedIdentifier(table)
38 if err != nil {
39 return nil, err
40 }
41 for _, field := range fields {
42 err := validMariaDBUnquotedIdentifier(field)
43 if err != nil {
44 return nil, err
45 }
46 }
47 if returningColumn != "" {
48 err := validMariaDBUnquotedIdentifier(returningColumn)
49 if err != nil {
50 return nil, err
51 }
52 }
53
54 return &MultiInserter{
55 table: table,
56 fields: fields,
57 returningColumn: returningColumn,
58 values: make([][]interface{}, 0),
59 }, nil
60 }
61
62
63 func (mi *MultiInserter) Add(row []interface{}) error {
64 if len(row) != len(mi.fields) {
65 return fmt.Errorf("field count mismatch, got %d, expected %d", len(row), len(mi.fields))
66 }
67 mi.values = append(mi.values, row)
68 return nil
69 }
70
71
72
73
74 func (mi *MultiInserter) query() (string, []interface{}) {
75 var questionsBuf strings.Builder
76 var queryArgs []interface{}
77 for _, row := range mi.values {
78
79
80
81 fmt.Fprintf(&questionsBuf, "(%s),", QuestionMarks(len(mi.fields)))
82 queryArgs = append(queryArgs, row...)
83 }
84
85 questions := strings.TrimRight(questionsBuf.String(), ",")
86
87
88
89
90 returning := ""
91 if mi.returningColumn != "" {
92 returning = fmt.Sprintf(" RETURNING %s", mi.returningColumn)
93 }
94
95
96
97
98
99 query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s%s", mi.table, strings.Join(mi.fields, ","), questions, returning)
100
101 return query, queryArgs
102 }
103
104
105
106
107 func (mi *MultiInserter) Insert(ctx context.Context, queryer Queryer) ([]int64, error) {
108 query, queryArgs := mi.query()
109 rows, err := queryer.QueryContext(ctx, query, queryArgs...)
110 if err != nil {
111 return nil, err
112 }
113
114 ids := make([]int64, 0, len(mi.values))
115 if mi.returningColumn != "" {
116 for rows.Next() {
117 var id int64
118 err = rows.Scan(&id)
119 if err != nil {
120 rows.Close()
121 return nil, err
122 }
123 ids = append(ids, id)
124 }
125 }
126
127
128
129
130
131 if rows != nil {
132 err = rows.Close()
133 if err != nil {
134 return nil, err
135 }
136 }
137
138 return ids, nil
139 }
140
View as plain text