Skip to content

Commit

Permalink
[patch-1] chore: use cast-safe option
Browse files Browse the repository at this point in the history
  • Loading branch information
KoNekoD committed Dec 1, 2024
1 parent 70c3f3b commit 2aa476d
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 1 deletion.
15 changes: 14 additions & 1 deletion named_args.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,20 @@ func rawState(l *sqlLexer) stateFn {
return singleQuoteState
case '"':
return doubleQuoteState
case '@', ':':
case ':':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
prevRune := rune(0)
if l.pos > 1 {
prevRune, _ = utf8.DecodeRuneInString(l.src[l.pos-2:])
}
if nextRune != ':' && prevRune != ':' && (isLetter(nextRune) || nextRune == '_') {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return namedArgState
}
case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) || nextRune == '_' {
if l.pos-l.start > 0 {
Expand Down
152 changes: 152 additions & 0 deletions named_args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,155 @@ func TestStrictNamedArgsRewriteQuery(t *testing.T) {
}
}
}

func TestNamedArgsRewriteQuery2(t *testing.T) {
t.Parallel()

for i, tt := range []struct {
sql string
args []any
namedArgs pgx.NamedArgs
expectedSQL string
expectedArgs []any
}{
{
sql: "select * from users where id = :id",
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: "select * from users where id = $1",
expectedArgs: []any{int32(42)},
},
{
sql: "select * from t where foo < :abc and baz = :def and bar < :abc",
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
expectedArgs: []any{int32(42), int32(1)},
},
{
sql: "select :a::int, :b::text",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "select $1::int, $2::text",
expectedArgs: []any{int32(42), "foo"},
},
{
sql: "select :Abc::int, :b_4::text, :_c::int",
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)},
expectedSQL: "select $1::int, $2::text, $3::int",
expectedArgs: []any{int32(42), "foo", int32(1)},
},
{
sql: "at end :",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "at end :",
expectedArgs: []any{},
},
{
sql: "ignores without valid character after : foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "ignores without valid character after : foo bar",
expectedArgs: []any{},
},
{
sql: "name cannot start with number :1 foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "name cannot start with number :1 foo bar",
expectedArgs: []any{},
},
{
sql: `select *, ':foo' as ":bar" from users where id = :id`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select *, ':foo' as ":bar" from users where id = $1`,
expectedArgs: []any{int32(42)},
},
{
sql: `select * -- :foo
from users -- :single line comments
where id = :id;`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select * -- :foo
from users -- :single line comments
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: `select * /* :multi line
:comment
*/
/* /* with :nesting */ */
from users
where id = :id;`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select * /* :multi line
:comment
*/
/* /* with :nesting */ */
from users
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: "extra provided argument",
namedArgs: pgx.NamedArgs{"extra": int32(1)},
expectedSQL: "extra provided argument",
expectedArgs: []any{},
},
{
sql: ":missing argument",
namedArgs: pgx.NamedArgs{},
expectedSQL: "$1 argument",
expectedArgs: []any{nil},
},

// test comments and quotes
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
require.NoError(t, err)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}

func TestStrictNamedArgsRewriteQuery2(t *testing.T) {
t.Parallel()

for i, tt := range []struct {
sql string
namedArgs pgx.StrictNamedArgs
expectedSQL string
expectedArgs []any
isExpectedError bool
}{
{
sql: "no arguments",
namedArgs: pgx.StrictNamedArgs{},
expectedSQL: "no arguments",
expectedArgs: []any{},
isExpectedError: false,
},
{
sql: ":all :matches",
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
expectedSQL: "$1 $2",
expectedArgs: []any{int32(1), int32(2)},
isExpectedError: false,
},
{
sql: "extra provided argument",
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
isExpectedError: true,
},
{
sql: ":missing argument",
namedArgs: pgx.StrictNamedArgs{},
isExpectedError: true,
},
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
if tt.isExpectedError {
assert.Errorf(t, err, "%d", i)
} else {
require.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
}

0 comments on commit 2aa476d

Please sign in to comment.