Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql support placeholder #73

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ import (

// hiveOptions holds the settings used for opened Hive sessions.
type hiveOptions struct {
PollIntervalSeconds int64
BatchSize int64
PollIntervalSeconds int64
BatchSize int64
ColumnsWithoutTableName bool // column names do not contain the table name
}

// hiveConnection bundles the state of one opened connection: the Thrift
// client, the session handle sent with each statement, the session options,
// a context, and the interpolator used to substitute query arguments.
type hiveConnection struct {
thrift *hiveserver2.TCLIServiceClient
session *hiveserver2.TSessionHandle
options hiveOptions
ctx context.Context
thrift *hiveserver2.TCLIServiceClient
session *hiveserver2.TSessionHandle
options hiveOptions
ctx context.Context
paramsInterpolator *ParamsInterpolator
}

func (c *hiveConnection) Begin() (driver.Tx, error) {
Expand Down Expand Up @@ -81,6 +83,13 @@ func removeLastSemicolon(s string) string {
}

func (c *hiveConnection) execute(ctx context.Context, query string, args []driver.NamedValue) (*hiveserver2.TExecuteStatementResp, error) {
var err error
if len(args) != 0 {
query, err = c.paramsInterpolator.InterpolateNamedValue(query, args)
if err != nil {
return nil, err
}
}
executeReq := hiveserver2.NewTExecuteStatementReq()
executeReq.SessionHandle = c.session
executeReq.Statement = removeLastSemicolon(query)
Expand Down
15 changes: 10 additions & 5 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ func (d drv) Open(dsn string) (driver.Conn, error) {
return nil, err
}

options := hiveOptions{PollIntervalSeconds: 5, BatchSize: int64(cfg.Batch)}
options := hiveOptions{
PollIntervalSeconds: 5,
BatchSize: int64(cfg.Batch),
ColumnsWithoutTableName: cfg.ColumnsWithoutTableName,
}
conn := &hiveConnection{
thrift: client,
session: session.SessionHandle,
options: options,
ctx: context.Background(),
thrift: client,
session: session.SessionHandle,
options: options,
ctx: context.Background(),
paramsInterpolator: NewParamsInterpolator(),
}
return conn, nil
}
Expand Down
9 changes: 9 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gohive

import (
"database/sql"
"database/sql/driver"
"fmt"
"os"
"reflect"
Expand Down Expand Up @@ -163,3 +164,11 @@ func TestExec(t *testing.T) {
defer db.Close()
a.NoError(err)
}

// TestExecArgs checks that Exec accepts a positional argument for a "?"
// placeholder in the statement.
func TestExecArgs(t *testing.T) {
	a := assert.New(t)
	db, _ := newDB("churn")
	defer db.Close()

	// Exec is variadic, so the value is passed directly. Wrapping it in a
	// []driver.Value would deliver the whole slice as a single argument, which
	// is not a type the interpolator can render into the statement.
	_, err := db.Exec("insert into churn.test (gender) values (?)", driver.Value("Female"))
	a.NoError(err)
}
52 changes: 33 additions & 19 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ import (
)

// Config holds the connection settings that ParseDSN extracts from a DSN.
type Config struct {
User string
Passwd string
Addr string
DBName string
Auth string
Batch int
SessionCfg map[string]string
User string
Passwd string
Addr string
DBName string
Auth string
Batch int
ColumnsWithoutTableName bool // column names do not contain the table name
SessionCfg map[string]string
}

var (
Expand All @@ -25,11 +26,12 @@ var (
)

// Names of the DSN query parameters recognised by ParseDSN, together with the
// defaults used when the auth and batch parameters are absent.
const (
sessionConfPrefix = "session."
authConfName = "auth"
defaultAuth = "NOSASL"
batchSizeName = "batch"
defaultBatchSize = 10000
sessionConfPrefix = "session."
authConfName = "auth"
defaultAuth = "NOSASL"
batchSizeName = "batch"
columnsWithoutTableNameName = "columns_without_table_name"
defaultBatchSize = 10000
)

// ParseDSN requires DSN names in the format [user[:password]@]addr/dbname.
Expand Down Expand Up @@ -60,6 +62,8 @@ func ParseDSN(dsn string) (*Config, error) {

auth := defaultAuth
batch := defaultBatchSize
columnsWithoutTableName := false
var err error
sc := make(map[string]string)
if len(sub[3]) > 0 && sub[3][0] == '?' {
qry, _ := url.ParseQuery(sub[3][1:])
Expand All @@ -74,6 +78,12 @@ func ParseDSN(dsn string) (*Config, error) {
}
batch = bch
}
if v, found := qry[columnsWithoutTableNameName]; found {
columnsWithoutTableName, err = strconv.ParseBool(v[0])
if err != nil {
return nil, err
}
}

for k, v := range qry {
if strings.HasPrefix(k, sessionConfPrefix) {
Expand All @@ -83,13 +93,14 @@ func ParseDSN(dsn string) (*Config, error) {
}

return &Config{
User: user,
Passwd: passwd,
Addr: addr,
DBName: dbname,
Auth: auth,
Batch: batch,
SessionCfg: sc,
User: user,
Passwd: passwd,
Addr: addr,
DBName: dbname,
Auth: auth,
Batch: batch,
ColumnsWithoutTableName: columnsWithoutTableName,
SessionCfg: sc,
}, nil
}

Expand All @@ -103,6 +114,9 @@ func (cfg *Config) FormatDSN() string {
if len(cfg.Auth) > 0 {
dsn += fmt.Sprintf("&auth=%s", cfg.Auth)
}
if cfg.ColumnsWithoutTableName {
dsn += "&columns_without_table_name=true"
}
if len(cfg.SessionCfg) > 0 {
for k, v := range cfg.SessionCfg {
dsn += fmt.Sprintf("&%s%s=%s", sessionConfPrefix, k, v)
Expand Down
12 changes: 12 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,15 @@ func TestFormatDSNWithoutDBName(t *testing.T) {
ds2 := cfg.FormatDSN()
assert.Equal(t, ds2, ds)
}

// TestFormatDSNColumnsWithoutTableNameName checks that ParseDSN sets
// Config.ColumnsWithoutTableName when the columns_without_table_name query
// parameter is "true" and leaves it false when the parameter is absent.
func TestFormatDSNColumnsWithoutTableNameName(t *testing.T) {
	// NOTE(review): the host part of both DSN literals looks mangled; confirm
	// the intended address against the repository.
	withFlag, err := ParseDSN("user:[email protected]?columns_without_table_name=true")
	assert.Nil(t, err)
	assert.True(t, withFlag.ColumnsWithoutTableName)

	withoutFlag, err := ParseDSN("user:[email protected]")
	assert.Nil(t, err)
	assert.False(t, withoutFlag.ColumnsWithoutTableName)
}
201 changes: 201 additions & 0 deletions params_replacer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package gohive

import (
"database/sql/driver"
"encoding/hex"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
)

const (
// TimeStampLayout is the time layout used to render time.Time arguments:
// date and time of day with up to nine fractional-second digits.
TimeStampLayout = "2006-01-02 15:04:05.999999999"
// DateLayout is a date-only (year-month-day) time layout.
DateLayout = "2006-01-02"
)

// ParamsInterpolator substitutes driver argument values for the "?"
// placeholders of a query string.
type ParamsInterpolator struct {
// Local is the location time.Time arguments are converted to before they
// are formatted into the query.
Local *time.Location
}

// NewParamsInterpolator returns a ParamsInterpolator whose Local location is
// time.Local.
func NewParamsInterpolator() *ParamsInterpolator {
	interpolator := new(ParamsInterpolator)
	interpolator.Local = time.Local
	return interpolator
}

// InterpolateNamedValue converts namedArgs to plain driver values and
// substitutes them for the "?" placeholders in query.
//
// It returns an error if any argument carries a name or if Interpolate
// rejects the query/argument combination.
func (p *ParamsInterpolator) InterpolateNamedValue(query string, namedArgs []driver.NamedValue) (string, error) {
	values, convErr := namedValueToValue(namedArgs)
	if convErr != nil {
		return "", convErr
	}

	return p.Interpolate(query, values)
}

// Interpolate replaces each "?" in query with the SQL literal form of the
// corresponding element of args, in order, and returns the resulting
// statement.
//
// Every '?' byte counts as a placeholder, including any that appear inside
// string literals or comments of the query. An error is returned when the
// number of placeholders differs from len(args) or when an argument has a
// type that cannot be rendered.
func (p *ParamsInterpolator) Interpolate(query string, args []driver.Value) (string, error) {
	if n := strings.Count(query, "?"); n != len(args) {
		return "", fmt.Errorf("gohive driver: number of ? [%d] must be equal to len(args): [%d]",
			n, len(args))
	}

	// 15 bytes per argument is only an initial size estimate; append grows
	// the buffer as needed.
	buf := make([]byte, 0, len(query)+len(args)*15)
	rest := query
	for i, arg := range args {
		// The count check above guarantees one remaining '?' per remaining
		// argument, so q is never -1 here.
		q := strings.IndexByte(rest, '?')
		buf = append(buf, rest[:q]...)
		rest = rest[q+1:]

		var err error
		buf, err = p.interpolateOne(buf, arg)
		if err != nil {
			return "", fmt.Errorf("gohive driver: failed to interpolate args[%d] [%v]: %w",
				i, arg, err)
		}
	}
	// Whatever follows the last placeholder is copied through unchanged.
	buf = append(buf, rest...)
	return string(buf), nil
}

// interpolateOne appends the SQL literal form of arg to buf and returns the
// extended buffer.
//
// Supported argument types are nil, int64, uint64, float64, bool, time.Time,
// json.RawMessage, []byte and string; any other type yields an error and a
// nil buffer.
func (p *ParamsInterpolator) interpolateOne(buf []byte, arg driver.Value) ([]byte, error) {
	if arg == nil {
		return append(buf, "NULL"...), nil
	}

	switch v := arg.(type) {
	case int64:
		buf = strconv.AppendInt(buf, v, 10)
	case uint64:
		// Handled explicitly in case a custom value converter passes unsigned
		// values through to the driver.
		buf = strconv.AppendUint(buf, v, 10)
	case float64:
		buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
	case bool:
		if v {
			buf = append(buf, "'true'"...)
		} else {
			buf = append(buf, "'false'"...)
		}
	case time.Time:
		if v.IsZero() {
			buf = append(buf, "'0000-00-00'"...)
		} else {
			// A zero-value ParamsInterpolator has no location; fall back to
			// the same default NewParamsInterpolator uses instead of letting
			// Time.In panic on a nil *Location.
			loc := p.Local
			if loc == nil {
				loc = time.Local
			}
			buf = append(buf, '\'')
			buf = appendDateTime(buf, v.In(loc))
			buf = append(buf, '\'')
		}
	case json.RawMessage:
		// JSON is text: emit it as an escaped string literal. Hex-encoding it
		// (as is done for raw bytes below) would send the hex digits instead
		// of the document.
		buf = append(buf, '\'')
		buf = escapeStringBackslash(buf, string(v))
		buf = append(buf, '\'')
	case []byte:
		if v == nil {
			buf = append(buf, "NULL"...)
		} else {
			// NOTE(review): confirm the target server accepts the X'..'
			// hexadecimal literal form.
			buf = append(buf, "X'"...)
			buf = appendBytes(buf, v)
			buf = append(buf, '\'')
		}
	case string:
		buf = append(buf, '\'')
		buf = escapeStringBackslash(buf, v)
		buf = append(buf, '\'')
	default:
		return nil, fmt.Errorf("gohive driver: unexpected args type: %T", arg)
	}
	return buf, nil
}

// namedValueToValue extracts the Value of every element of named, preserving
// order. It returns an error as soon as an element with a non-empty Name is
// found, because named parameters are not supported.
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
	values := make([]driver.Value, len(named))
	for i := range named {
		if named[i].Name != "" {
			return nil, fmt.Errorf("gohive driver: driver does not support the use of Named Parameters")
		}
		values[i] = named[i].Value
	}
	return values, nil
}

// appendBytes appends the lowercase hexadecimal encoding of v to buf and
// returns the extended buffer.
func appendBytes(buf, v []byte) []byte {
	start := len(buf)
	// Grow by exactly the encoded length (two hex digits per input byte);
	// nothing beyond that is ever written.
	buf = append(buf, make([]byte, hex.EncodedLen(len(v)))...)
	hex.Encode(buf[start:], v)
	return buf
}

// appendDateTime appends t, formatted with TimeStampLayout, to buf and
// returns the extended buffer.
func appendDateTime(buf []byte, t time.Time) []byte {
	return t.AppendFormat(buf, TimeStampLayout)
}

// escapeStringBackslash appends v to buf with backslash escaping applied and
// returns the extended buffer.
//
// NUL, newline, carriage return and 0x1A are written as \0, \n, \r and \Z;
// single quote, double quote and backslash are prefixed with a backslash.
// All other bytes are copied unchanged.
func escapeStringBackslash(buf []byte, v string) []byte {
	w := len(buf)
	// Worst case every byte is escaped and doubles in size.
	out := reserveBuffer(buf, len(v)*2)

	for i := 0; i < len(v); i++ {
		ch := v[i]

		// esc stays 0 when ch needs no escaping; no escape letter is 0.
		var esc byte
		switch ch {
		case '\x00':
			esc = '0'
		case '\n':
			esc = 'n'
		case '\r':
			esc = 'r'
		case '\x1a':
			esc = 'Z'
		case '\'', '"', '\\':
			esc = ch
		}

		if esc == 0 {
			out[w] = ch
			w++
			continue
		}
		out[w] = '\\'
		out[w+1] = esc
		w += 2
	}

	return out[:w]
}

// reserveBuffer returns buf resliced to len(buf)+appendSize.
// When cap(buf) is too small, the contents are first copied into a newly
// allocated buffer of size len(buf)*2+appendSize.
func reserveBuffer(buf []byte, appendSize int) []byte {
	need := len(buf) + appendSize
	if need <= cap(buf) {
		return buf[:need]
	}

	// Not enough room: grow exponentially and carry the existing bytes over.
	grown := make([]byte, len(buf)*2+appendSize)
	copy(grown, buf)
	return grown[:need]
}
Loading