-
Notifications
You must be signed in to change notification settings - Fork 0
/
insert.go
173 lines (162 loc) · 4 KB
/
insert.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package morm
import (
"context"
"github.com/NotFound1911/morm/errors"
"github.com/NotFound1911/morm/model"
)
// UpsertBuilder is the intermediate step of an Inserter's upsert chain.
// It remembers the Inserter it was created from and collects the conflict
// columns until Update attaches the finished Upsert back to that Inserter.
type UpsertBuilder[T any] struct {
	i               *Inserter[T]
	conflictColumns []Column
}

// Upsert holds the conflict-handling part of an INSERT: the conflict columns
// collected by UpsertBuilder and the assignments passed to Update. How it is
// rendered into SQL is decided by the dialect's buildUpsert.
type Upsert struct {
	conflictColumns []Column
	assigns         []Assignable
}

// ConflictColumns records the given names as the upsert's conflict columns.
// Each call replaces whatever a previous call recorded.
func (u *UpsertBuilder[T]) ConflictColumns(cols ...string) *UpsertBuilder[T] {
	conflict := make([]Column, 0, len(cols))
	for _, name := range cols {
		conflict = append(conflict, Column{name: name})
	}
	u.conflictColumns = conflict
	return u
}

// Update completes the upsert with the given assignments, stores it on the
// parent Inserter, and returns that Inserter so the builder chain can go on.
func (u *UpsertBuilder[T]) Update(assigns ...Assignable) *Inserter[T] {
	parent := u.i
	parent.onDuplicate = &Upsert{
		assigns:         assigns,
		conflictColumns: u.conflictColumns,
	}
	return parent
}
// Inserter builds INSERT statements for rows of type T.
type Inserter[T any] struct {
	builder
	values      []*T     // rows to insert
	columns     []string // explicit subset of fields to insert; empty means all model fields
	onDuplicate *Upsert  // conflict clause attached by OnDuplicateKey().Update(...); nil when absent
	sess        session
}

// OnDuplicateKey starts the conflict-handling part of the statement.
// The overall chain is: configure the Inserter -> OnDuplicateKey() to describe
// the conflict handling -> Update(...) returns the Inserter to finish configuring.
func (i *Inserter[T]) OnDuplicateKey() *UpsertBuilder[T] {
	ub := &UpsertBuilder[T]{}
	ub.i = i
	return ub
}
// NewInserter creates an Inserter bound to sess. The embedded builder shares
// the session's core and uses that core's dialect and identifier quoter.
func NewInserter[T any](sess session) *Inserter[T] {
	cr := sess.getCore()
	b := builder{
		core:    cr,
		dialect: cr.dialect,
		quoter:  cr.dialect.quoter(),
	}
	return &Inserter[T]{
		builder: b,
		sess:    sess,
	}
}
// Values sets the rows to insert. Each call replaces the rows set by any
// previous call; Build fails if no rows have been set.
func (i *Inserter[T]) Values(vals ...*T) *Inserter[T] {
	i.values = vals
	return i
}

// Columns restricts the INSERT to the given fields, emitted in the order
// listed. Names are looked up in the model's FieldMap by Build, which
// returns an error for any unknown name. With no columns, all model fields
// are inserted.
func (i *Inserter[T]) Columns(cols ...string) *Inserter[T] {
	i.columns = cols
	return i
}

// Cloumns is the original, misspelled name of Columns and behaves identically.
//
// Deprecated: use Columns instead.
func (i *Inserter[T]) Cloumns(cols ...string) *Inserter[T] {
	return i.Columns(cols...)
}
// Build assembles the INSERT statement and its positional arguments:
//
//	INSERT INTO <table>(<col>,...) VALUES(?,...),(?,...)[<upsert clause>];
//
// Table and column identifiers are both quoted through the dialect's quoter,
// so the output is correct for dialects that do not use backticks.
//
// It returns an error when no rows were supplied via Values, when T cannot be
// resolved to a model, when an explicitly selected column is not a known
// field, when reading a field value fails, or when the dialect fails to build
// the upsert clause.
func (i *Inserter[T]) Build() (*Query, error) {
	if len(i.values) == 0 {
		return nil, errs.NewErrInsertZeroRow()
	}
	var (
		t   T
		err error
	)
	i.model, err = i.r.Get(&t)
	if err != nil {
		return nil, err
	}

	// Resolve the inserted fields before writing any SQL, so a bad column name
	// does not leave a half-written statement behind. By default every model
	// field is used; otherwise the caller's selection, in the caller's order.
	fields := i.model.Fields
	if len(i.columns) != 0 {
		fields = make([]*model.Field, 0, len(i.columns))
		for _, col := range i.columns {
			field, ok := i.model.FieldMap[col]
			if !ok {
				return nil, errs.NewErrUnknownField(col)
			}
			fields = append(fields, field)
		}
	}

	i.sqlBuilder.WriteString("INSERT INTO ")
	i.quote(i.model.TableName)
	i.sqlBuilder.WriteByte('(')
	for idx, fd := range fields {
		if idx > 0 {
			i.sqlBuilder.WriteByte(',')
		}
		// Quote via the dialect rather than hard-coding backticks.
		i.quote(fd.ColName)
	}
	i.sqlBuilder.WriteString(") VALUES")

	// One argument per field per row; the extra len(fields) of capacity
	// leaves room for arguments an upsert clause may append.
	i.args = make([]any, 0, len(fields)*(len(i.values)+1))
	for vIdx, val := range i.values { // outer loop: rows
		if vIdx > 0 {
			i.sqlBuilder.WriteByte(',')
		}
		refVal := i.valCreator(val, i.model)
		i.sqlBuilder.WriteByte('(')
		for fIdx, field := range fields { // inner loop: fields of the row
			if fIdx > 0 {
				i.sqlBuilder.WriteByte(',')
			}
			i.sqlBuilder.WriteByte('?')
			fdVal, err := refVal.Field(field.GoName)
			if err != nil {
				return nil, err
			}
			i.addArgs(fdVal)
		}
		i.sqlBuilder.WriteByte(')')
	}

	// Conflict-handling clause, rendered by the dialect.
	if i.onDuplicate != nil {
		if err := i.core.dialect.buildUpsert(&i.builder, i.onDuplicate); err != nil {
			return nil, err
		}
	}
	i.sqlBuilder.WriteByte(';')
	return &Query{
		SQL:  i.sqlBuilder.String(),
		Args: i.args,
	}, nil
}
// buildAssignment writes one assignment of an upsert's update list:
//   - Column c renders as <col>=VALUES(<col>), reusing the value from the row
//     being inserted;
//   - Assignment a renders as <col>=? and appends a.val to the arguments.
//
// Identifiers are quoted through the dialect's quoter. The field is looked up
// before anything is written, so an unknown field leaves the SQL buffer
// untouched. It returns an error for unknown fields or unsupported
// Assignable implementations.
func (i *Inserter[T]) buildAssignment(a Assignable) error {
	switch assign := a.(type) {
	case Column:
		fd, ok := i.model.FieldMap[assign.name]
		if !ok {
			return errs.NewErrUnknownField(assign.name)
		}
		i.quote(fd.ColName)
		i.sqlBuilder.WriteString("=VALUES(")
		i.quote(fd.ColName)
		i.sqlBuilder.WriteByte(')')
	case Assignment:
		fd, ok := i.model.FieldMap[assign.name]
		if !ok {
			return errs.NewErrUnknownField(assign.name)
		}
		i.quote(fd.ColName)
		i.sqlBuilder.WriteString("=?")
		i.addArgs(assign.val)
	default:
		return errs.NewErrUnsupportedAssignableType(a)
	}
	return nil
}
// Exec passes this inserter, tagged as an "INSERT" query context, to the
// package-level exec helper together with the session and core, and returns
// the Result it produces.
func (i *Inserter[T]) Exec(ctx context.Context) Result {
	qc := &QueryContext{
		Type:    "INSERT",
		Builder: i,
	}
	return exec(ctx, i.sess, i.core, qc)
}