diff --git a/README.md b/README.md index 49f2c3d78..0cf2c2916 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube. ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.20 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy diff --git a/conn.go b/conn.go index a9cb3163f..311721459 100644 --- a/conn.go +++ b/conn.go @@ -10,7 +10,6 @@ import ( "strings" "time" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/sanitize" "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" @@ -755,7 +754,6 @@ optionLoop: } c.eqb.reset() - anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) var err error diff --git a/extended_query_builder.go b/extended_query_builder.go index 9c9de5b2c..526b0e953 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -1,10 +1,8 @@ package pgx import ( - "database/sql/driver" "fmt" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" ) @@ -23,10 +21,15 @@ type ExtendedQueryBuilder struct { func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error { eqb.reset() - anynil.NormalizeSlice(args) - if sd == nil { - return eqb.appendParamsForQueryExecModeExec(m, args) + for i := range args { + err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i]) + if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %w", i, err) + return err + } + } + return nil } if len(sd.ParamOIDs) != len(args) { @@ -113,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() { } func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { - if anynil.Is(arg) { - return nil, nil - } - if eqb.paramValueBytes == nil { eqb.paramValueBytes = make([]byte, 0, 128) } @@ -145,74 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui return m.FormatCodeForOID(oid) } - -// appendParamsForQueryExecModeExec appends the args to eqb. -// -// Parameters must be encoded in the text format because of differences in type conversion between timestamps and -// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the -// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both -// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL -// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. -// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion -// before converting it to date. This means that dates can be shifted by one day. In text format without that double -// type conversion it takes the date directly and ignores time zone (i.e. it works). -// -// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is -// no way to safely use binary or to specify the parameter OIDs. -func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error { - for _, arg := range args { - if arg == nil { - err := eqb.appendParam(m, 0, TextFormatCode, arg) - if err != nil { - return err - } - } else { - dt, ok := m.TypeForValue(arg) - if !ok { - var tv pgtype.TextValuer - if tv, ok = arg.(pgtype.TextValuer); ok { - t, err := tv.TextValue() - if err != nil { - return err - } - - dt, ok = m.TypeForOID(pgtype.TextOID) - if ok { - arg = t - } - } - } - if !ok { - var dv driver.Valuer - if dv, ok = arg.(driver.Valuer); ok { - v, err := dv.Value() - if err != nil { - return err - } - dt, ok = m.TypeForValue(v) - if ok { - arg = v - } - } - } - if !ok { - var str fmt.Stringer - if str, ok = arg.(fmt.Stringer); ok { - dt, ok = m.TypeForOID(pgtype.TextOID) - if ok { - arg = str.String() - } - } - } - if !ok { - return &unknownArgumentTypeQueryExecModeExecError{arg: arg} - } - err := eqb.appendParam(m, dt.OID, TextFormatCode, arg) - if err != nil { - return err - } - } - } - - return nil -} diff --git a/go.mod b/go.mod index c8430a417..e27bf8a48 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/jackc/pgx/v5 -go 1.19 +go 1.20 require ( github.com/jackc/pgpassfile v1.0.0 diff --git a/internal/anynil/anynil.go b/internal/anynil/anynil.go deleted file mode 100644 index 9a48c1a84..000000000 --- a/internal/anynil/anynil.go +++ /dev/null @@ -1,36 +0,0 @@ -package anynil - -import "reflect" - -// Is returns true if value is any type of nil. e.g. nil or []byte(nil). -func Is(value any) bool { - if value == nil { - return true - } - - refVal := reflect.ValueOf(value) - switch refVal.Kind() { - case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: - return refVal.IsNil() - default: - return false - } -} - -// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified. -func Normalize(v any) any { - if Is(v) { - return nil - } - return v -} - -// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is -// mutated in place. -func NormalizeSlice(s []any) { - for i := range s { - if Is(s[i]) { - s[i] = nil - } - } -} diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index c1863b32a..bf5f6989a 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -6,7 +6,6 @@ import ( "fmt" "reflect" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -230,7 +229,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan // target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the // scan of the elements. - if anynil.Is(target) { + if isNil, _ := isNilDriverValuer(target); isNil { arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter) } diff --git a/pgtype/doc.go b/pgtype/doc.go index ec9270acb..d56c1dc70 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -139,6 +139,16 @@ Compatibility with database/sql pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer interfaces. +Encoding Typed Nils + +pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec +system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil). + +However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore, +driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See +https://github.com/golang/go/issues/8415 and +https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. + Child Records pgtype's support for arrays and composite records can be used to load records and their children in a single query. See diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 2be11e820..408295683 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1912,8 +1912,17 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil + if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil { + if callNilDriverValuer { + newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf) + if err != nil { + return nil, newEncodeError(value, m, oid, formatCode, err) + } + + return newBuf, nil + } else { + return nil, nil + } } plan := m.PlanEncode(oid, formatCode, value) @@ -1968,3 +1977,55 @@ func (w *sqlScannerWrapper) Scan(src any) error { return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) } + +// canBeNil returns true if value can be nil. +func canBeNil(value any) bool { + refVal := reflect.ValueOf(value) + kind := refVal.Kind() + switch kind { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + return true + default: + return false + } +} + +// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil +// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual +// type. Yuck. +// +// This can be simplified in Go 1.22 with reflect.TypeFor. +// +// var valuerReflectType = reflect.TypeFor[driver.Valuer]() +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement +// driver.Valuer if it is only implemented by T. +func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) { + if value == nil { + return true, false + } + + refVal := reflect.ValueOf(value) + kind := refVal.Kind() + switch kind { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + if !refVal.IsNil() { + return false, false + } + + if _, ok := value.(driver.Valuer); ok { + if kind == reflect.Ptr { + // The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T + // by checking if it is not implemented on *T. + return true, !refVal.Type().Elem().Implements(valuerReflectType) + } else { + return true, true + } + } + + return true, false + default: + return false, false + } +} diff --git a/query_test.go b/query_test.go index df044cdea..a6a26ad77 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" + "encoding/json" "errors" "fmt" "os" @@ -1171,6 +1173,161 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes ensureConnValid(t, conn) } +type nilPointerAsEmptyJSONObject struct { + ID string + Name string +} + +func (v *nilPointerAsEmptyJSONObject) Value() (driver.Value, error) { + if v == nil { + return "{}", nil + } + + return json.Marshal(v) +} + +// https://github.com/jackc/pgx/issues/1566 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilPointerImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v *nilPointerAsEmptyJSONObject + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "{}", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = &nilPointerAsEmptyJSONObject{ID: "1", Name: "foo"} + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 *nilPointerAsEmptyJSONObject + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + +type nilSliceAsEmptySlice []byte + +func (j nilSliceAsEmptySlice) Value() (driver.Value, error) { + if len(j) == 0 { + return []byte("[]"), nil + } + + return []byte(j), nil +} + +func (j *nilSliceAsEmptySlice) UnmarshalJSON(data []byte) error { + *j = bytes.Clone(data) + return nil +} + +// https://github.com/jackc/pgx/issues/1860 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilSliceImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v nilSliceAsEmptySlice + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "[]", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = nilSliceAsEmptySlice(`{"name": "foo"}`) + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 nilSliceAsEmptySlice + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + +type nilMapAsEmptyObject map[string]any + +func (j nilMapAsEmptyObject) Value() (driver.Value, error) { + if j == nil { + return []byte("{}"), nil + } + + return json.Marshal(j) +} + +func (j *nilMapAsEmptyObject) UnmarshalJSON(data []byte) error { + var m map[string]any + err := json.Unmarshal(data, &m) + if err != nil { + return err + } + + *j = m + + return nil +} + +// https://github.com/jackc/pgx/pull/2019#discussion_r1605806751 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilMapImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v nilMapAsEmptyObject + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "{}", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = nilMapAsEmptyObject{"name": "foo"} + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 nilMapAsEmptyObject + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + func TestConnQueryDatabaseSQLDriverScannerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index cab717d0a..6e2ff3003 100644 --- a/values.go +++ b/values.go @@ -3,7 +3,6 @@ package pgx import ( "errors" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" ) @@ -15,10 +14,6 @@ const ( ) func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { - if anynil.Is(arg) { - return nil, nil - } - buf, err := m.Encode(0, TextFormatCode, arg, []byte{}) if err != nil { return nil, err @@ -30,10 +25,6 @@ func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { } func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { - if anynil.Is(arg) { - return pgio.AppendInt32(buf, -1), nil - } - sp := len(buf) buf = pgio.AppendInt32(buf, -1) argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)