diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 8bf367c17..361532f01 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -419,12 +419,17 @@ func (w timeWrapper) DateValue() (Date, error) { return Date{Time: time.Time(w), Valid: true}, nil } -func (w *timeWrapper) ScanTimestamp(v Timestamp) error { +func (w *timeWrapper) ScanTimestamp(v Timestamp, infinityTsEnabled bool) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *time.Time") } - switch v.InfinityModifier { + infinityModifier := v.InfinityModifier + if infinityTsEnabled { + infinityModifier = Finite + } + + switch infinityModifier { case Finite: *w = timeWrapper(v.Time) return nil diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index eb9526725..5cb7406d9 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -276,7 +276,7 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) m.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) m.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) m.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) m.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) m.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) @@ -418,6 +418,24 @@ func NewMap() *Map { return m } +// EnableInfinityTs controls the handling of Postgres' "-infinity" and "infinity" "timestamp +func (m *Map) EnableInfinityTs(negativeInfinity, positiveInfinity time.Time) error { + if negativeInfinity.Unix() >= positiveInfinity.Unix() { + return errors.New("invalid timerange between negative and positive infinity") + } + + ts := m.nameToType["timestamp"] + tsc, _ := ts.Codec.(*TimestampCodec) + + tsc.InfinityTsEnabled = true + tsc.Min = negativeInfinity + tsc.Max = positiveInfinity + + delete(m.memoizedScanPlans, ts.OID) + delete(m.memoizedEncodePlans, ts.OID) + return nil +} + func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 5affc0c71..9d2db2309 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -11,6 +11,7 @@ import ( "regexp" "strconv" "testing" + "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" @@ -614,3 +615,76 @@ func isExpectedEq(a any) func(any) bool { return a == v } } + +func TestMapEnableInfinityTs(t *testing.T) { + type args struct { + negativeInfinity time.Time + positiveInfinity time.Time + } + + type want struct { + err bool + assertionFn func(*pgtype.Map) + } + + tests := []struct { + name string + args args + want want + }{ + { + name: "InfinityTs should be enabled", + args: args{ + negativeInfinity: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC), + positiveInfinity: time.Date(2100, 1, 1, 1, 0, 0, 0, time.UTC), + }, + want: want{ + err: false, + assertionFn: func(m *pgtype.Map) { + ts, exists := m.TypeForName("timestamp") + assert.True(t, exists) + + tsc, ok := ts.Codec.(*pgtype.TimestampCodec) + assert.True(t, ok) + + assert.True(t, tsc.InfinityTsEnabled) + }, + }, + }, + { + name: "Negative infinity should not be equal to positive infinity", + args: args{ + negativeInfinity: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC), + positiveInfinity: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC), + }, + want: want{ + err: true, + }, + }, + { + name: "Negative infinity should be lower than positive infinity", + args: args{ + negativeInfinity: time.Date(3000, 1, 1, 1, 0, 0, 0, time.UTC), + positiveInfinity: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC), + }, + want: want{ + err: true, + }, + }, + } + typeMap := pgtype.NewMap() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := typeMap.EnableInfinityTs(tt.args.negativeInfinity, tt.args.positiveInfinity) + if tt.want.err { + assert.Error(t, err, tt.name) + } else { + assert.NoError(t, err, tt.name) + } + + if tt.want.assertionFn != nil { + tt.want.assertionFn(typeMap) + } + }) + } +} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 9f3de2c59..ce0a34d2f 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -13,7 +13,7 @@ import ( const pgTimestampFormat = "2006-01-02 15:04:05.999999999" type TimestampScanner interface { - ScanTimestamp(v Timestamp) error + ScanTimestamp(v Timestamp, infinityTsEnabled bool) error } type TimestampValuer interface { @@ -27,7 +27,7 @@ type Timestamp struct { Valid bool } -func (ts *Timestamp) ScanTimestamp(v Timestamp) error { +func (ts *Timestamp) ScanTimestamp(v Timestamp, infinityTsEnabled bool) error { *ts = v return nil } @@ -66,7 +66,11 @@ func (ts Timestamp) Value() (driver.Value, error) { return ts.Time, nil } -type TimestampCodec struct{} +type TimestampCodec struct { + InfinityTsEnabled bool + Min time.Time + Max time.Time +} func (TimestampCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode @@ -76,24 +80,28 @@ func (TimestampCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestampValuer); !ok { return nil } switch format { case BinaryFormatCode: - return encodePlanTimestampCodecBinary{} + return encodePlanTimestampCodecBinary{infinityTsEnabled: c.InfinityTsEnabled, min: c.Min, max: c.Max} case TextFormatCode: - return encodePlanTimestampCodecText{} + return encodePlanTimestampCodecText{infinityTsEnabled: c.InfinityTsEnabled, min: c.Min, max: c.Max} } return nil } -type encodePlanTimestampCodecBinary struct{} +type encodePlanTimestampCodecBinary struct { + infinityTsEnabled bool + min time.Time + max time.Time +} -func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { +func (e encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { ts, err := value.(TimestampValuer).TimestampValue() if err != nil { return nil, err @@ -103,8 +111,17 @@ func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []by return nil, nil } + infinityModifier := ts.InfinityModifier + if e.infinityTsEnabled { + if ts.Time.Unix() <= e.min.Unix() { + infinityModifier = -Infinity + } else if ts.Time.Unix() >= e.max.Unix() { + infinityModifier = Infinity + } + } + var microsecSinceY2K int64 - switch ts.InfinityModifier { + switch infinityModifier { case Finite: t := discardTimeZone(ts.Time) microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 @@ -120,7 +137,11 @@ func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []by return buf, nil } -type encodePlanTimestampCodecText struct{} +type encodePlanTimestampCodecText struct { + infinityTsEnabled bool + min time.Time + max time.Time +} func (encodePlanTimestampCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { ts, err := value.(TimestampValuer).TimestampValue() @@ -170,31 +191,35 @@ func discardTimeZone(t time.Time) time.Time { return t } -func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanBinaryTimestampToTimestampScanner{} + return scanPlanBinaryTimestampToTimestampScanner{infinityTsEnabled: c.InfinityTsEnabled, min: c.Min, max: c.Max} } case TextFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanTextTimestampToTimestampScanner{} + return scanPlanTextTimestampToTimestampScanner{infinityTsEnabled: c.InfinityTsEnabled, min: c.Min, max: c.Max} } } return nil } -type scanPlanBinaryTimestampToTimestampScanner struct{} +type scanPlanBinaryTimestampToTimestampScanner struct { + infinityTsEnabled bool + min time.Time + max time.Time +} -func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (s scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { - return scanner.ScanTimestamp(Timestamp{}) + return scanner.ScanTimestamp(Timestamp{}, s.infinityTsEnabled) } if len(src) != 8 { @@ -206,9 +231,9 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error switch microsecSinceY2K { case infinityMicrosecondOffset: - ts = Timestamp{Valid: true, InfinityModifier: Infinity} + ts = Timestamp{Valid: true, InfinityModifier: Infinity, Time: s.max} case negativeInfinityMicrosecondOffset: - ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + ts = Timestamp{Valid: true, InfinityModifier: -Infinity, Time: s.min} default: tim := time.Unix( microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, @@ -217,25 +242,29 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error ts = Timestamp{Time: tim, Valid: true} } - return scanner.ScanTimestamp(ts) + return scanner.ScanTimestamp(ts, s.infinityTsEnabled) } -type scanPlanTextTimestampToTimestampScanner struct{} +type scanPlanTextTimestampToTimestampScanner struct { + infinityTsEnabled bool + min time.Time + max time.Time +} -func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (s scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { - return scanner.ScanTimestamp(Timestamp{}) + return scanner.ScanTimestamp(Timestamp{}, s.infinityTsEnabled) } var ts Timestamp sbuf := string(src) switch sbuf { case "infinity": - ts = Timestamp{Valid: true, InfinityModifier: Infinity} + ts = Timestamp{Valid: true, InfinityModifier: Infinity, Time: s.max} case "-infinity": - ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + ts = Timestamp{Valid: true, InfinityModifier: -Infinity, Time: s.min} default: bc := false if strings.HasSuffix(sbuf, " BC") { @@ -255,7 +284,7 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { ts = Timestamp{Time: tim, Valid: true} } - return scanner.ScanTimestamp(ts) + return scanner.ScanTimestamp(ts, s.infinityTsEnabled) } func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 849f55f66..f08edf72f 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -62,3 +62,45 @@ func TestTimestampCodecDecodeTextInvalid(t *testing.T) { err := plan.Scan([]byte(`eeeee`), &ts) require.Error(t, err) } + +func TestTimestampDecodeInfinity(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var inf time.Time + err := conn. + QueryRow(context.Background(), "select 'infinity'::timestamp"). + Scan(&inf) + require.Error(t, err, "Cannot decode infinite as timestamp. Use EnableInfinityTs to interpret inf to a min and max date") + + negInf, posInf := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) + jan1st2023 := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + conn.TypeMap().EnableInfinityTs(negInf, posInf) + + var min, max, tim time.Time + err = conn. + QueryRow(context.Background(), "select '-infinity'::timestamp, 'infinity'::timestamp, '2023-01-01T00:00:00Z'::timestamp"). + Scan(&min, &max, &tim) + + require.NoError(t, err, "Inf timestamp should be properly scanned when EnableInfinityTs() is valid") + require.Equal(t, negInf, min, "Negative infinity should be decoded as negative time supplied in EnableInfinityTs") + require.Equal(t, posInf, max, "Positive infinity should be decoded as positive time supplied in EnableInfinityTs") + require.Equal(t, tim, jan1st2023, "Normal timestamp should be properly decoded") + }) +} + +func TestTimestampEncodeInfinity(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + negInf, posInf := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) + conn.TypeMap().EnableInfinityTs(negInf, posInf) + + _, err := conn.Exec(ctx, "create temporary table tts(neg timestamp NOT NULL, pos timestamp NOT NULL)") + require.NoError(t, err, "Temp table creation should not cause an error") + + _, err = conn.Exec(ctx, "insert into tts(neg, pos) values($1, $2)", negInf, posInf) + require.NoError(t, err, "Inserting -infinity and infinity to temp tts table should not cause an error") + + var min, max string + conn.QueryRow(ctx, "select neg::text, pos::text from tts limit 1").Scan(&min, &max) + require.Equal(t, "-infinity", min, "Inserting {negInf} value to temp tts table should be converted to -infinity") + require.Equal(t, "infinity", max, "Inserting {posInf} value to temp tts table should be converted to infinity") + }) +}