diff --git a/internal/anynil/anynil.go b/internal/anynil/anynil.go deleted file mode 100644 index 061680e08..000000000 --- a/internal/anynil/anynil.go +++ /dev/null @@ -1,46 +0,0 @@ -package anynil - -import ( - "database/sql/driver" - "reflect" -) - -// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil -// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual -// type. Yuck. -// -// This can be simplified in Go 1.22 with reflect.TypeFor. -// -// var valuerReflectType = reflect.TypeFor[driver.Valuer]() -var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - -// Is returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement -// driver.Valuer if it is only implemented by T. -func Is(value any) bool { - if value == nil { - return true - } - - refVal := reflect.ValueOf(value) - kind := refVal.Kind() - switch kind { - case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: - if !refVal.IsNil() { - return false - } - - if _, ok := value.(driver.Valuer); ok { - if kind == reflect.Ptr { - // The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on T - // to see if it is not implemented on *T. - return refVal.Type().Elem().Implements(valuerReflectType) - } else { - return false - } - } - - return true - default: - return false - } -} diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index c1863b32a..bf5f6989a 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -6,7 +6,6 @@ import ( "fmt" "reflect" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -230,7 +229,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan // target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the // scan of the elements. - if anynil.Is(target) { + if isNil, _ := isNilDriverValuer(target); isNil { arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter) } diff --git a/pgtype/doc.go b/pgtype/doc.go index 2039fcf1d..d56c1dc70 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -141,11 +141,11 @@ interfaces. Encoding Typed Nils -pgtype normalizes typed nils (e.g. []byte(nil)) into nil. nil is always encoded is the SQL NULL value without going -through the Codec system. This means that Codecs and other encoding logic does not have to handle nil or *T(nil). +pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec +system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil). However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore, -driver.Valuer values are not normalized to nil unless it is a *T(nil) where driver.Valuer is implemented on T. See +driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See https://github.com/golang/go/issues/8415 and https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 810eb5771..408295683 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -9,8 +9,6 @@ import ( "net/netip" "reflect" "time" - - "github.com/jackc/pgx/v5/internal/anynil" ) // PostgreSQL oids for common types @@ -1914,8 +1912,17 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) { - if anynil.Is(value) { - return nil, nil + if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil { + if callNilDriverValuer { + newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf) + if err != nil { + return nil, newEncodeError(value, m, oid, formatCode, err) + } + + return newBuf, nil + } else { + return nil, nil + } } plan := m.PlanEncode(oid, formatCode, value) @@ -1970,3 +1977,55 @@ func (w *sqlScannerWrapper) Scan(src any) error { return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) } + +// canBeNil returns true if value can be nil. +func canBeNil(value any) bool { + refVal := reflect.ValueOf(value) + kind := refVal.Kind() + switch kind { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + return true + default: + return false + } +} + +// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil +// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual +// type. Yuck. +// +// This can be simplified in Go 1.22 with reflect.TypeFor. +// +// var valuerReflectType = reflect.TypeFor[driver.Valuer]() +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement +// driver.Valuer if it is only implemented by T. +func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) { + if value == nil { + return true, false + } + + refVal := reflect.ValueOf(value) + kind := refVal.Kind() + switch kind { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + if !refVal.IsNil() { + return false, false + } + + if _, ok := value.(driver.Valuer); ok { + if kind == reflect.Ptr { + // The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T + // by checking if it is not implemented on *T. + return true, !refVal.Type().Elem().Implements(valuerReflectType) + } else { + return true, true + } + } + + return true, false + default: + return false, false + } +}