Skip to content

Commit

Permalink
Add support for custom JSON marshal and unmarshal.
Browse files Browse the repository at this point in the history
Fixes #2005.
  • Loading branch information
mitar committed May 13, 2024
1 parent 523411a commit a5f47de
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 32 deletions.
47 changes: 30 additions & 17 deletions pgtype/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@ import (
"reflect"
)

type JSONCodec struct{}
type JSONCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}

func (JSONCodec) FormatSupported(format int16) bool {
func (*JSONCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}

func (JSONCodec) PreferredFormat() int16 {
func (*JSONCodec) PreferredFormat() int16 {
return TextFormatCode
}

func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch value.(type) {
case string:
return encodePlanJSONCodecEitherFormatString{}
Expand All @@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
//
// https://github.com/jackc/pgx/issues/1681
case json.Marshaler:
return encodePlanJSONCodecEitherFormatMarshal{}
return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
}

// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
Expand All @@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
}
}

return encodePlanJSONCodecEitherFormatMarshal{}
return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
}

type encodePlanJSONCodecEitherFormatString struct{}
Expand Down Expand Up @@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt
return buf, nil
}

type encodePlanJSONCodecEitherFormatMarshal struct{}
type encodePlanJSONCodecEitherFormatMarshal struct {
marshal func(v any) ([]byte, error)
}

func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := json.Marshal(value)
func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := e.marshal(value)
if err != nil {
return nil, err
}
Expand All @@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new
return buf, nil
}

func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch target.(type) {
case *string:
return scanPlanAnyToString{}
Expand Down Expand Up @@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
return &scanPlanSQLScanner{formatCode: format}
}

return scanPlanJSONToJSONUnmarshal{}
return &scanPlanJSONToJSONUnmarshal{
unmarshal: c.Unmarshal,
}
}

type scanPlanAnyToString struct{}
Expand Down Expand Up @@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
return scanner.ScanBytes(src)
}

type scanPlanJSONToJSONUnmarshal struct{}
type scanPlanJSONToJSONUnmarshal struct {
unmarshal func(data []byte, v any) error
}

func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
if src == nil {
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() == reflect.Ptr {
Expand All @@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
elem := reflect.ValueOf(dst).Elem()
elem.Set(reflect.Zero(elem.Type()))

return json.Unmarshal(src, dst)
return s.unmarshal(src, dst)
}

func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
Expand All @@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
return dstBuf, nil
}

func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}

var dst any
err := json.Unmarshal(src, &dst)
err := c.Unmarshal(src, &dst)
return dst, err
}
18 changes: 18 additions & 0 deletions pgtype/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -224,3 +225,20 @@ func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
require.Equal(t, `{"custom":"thing"}`, jsonStr)
})
}

func TestJSONCodecCustomMarshal(t *testing.T) {
connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{
Marshal: func(v any) ([]byte, error) {
return []byte(`{"custom":"value"}`), nil
},
Unmarshal: json.Unmarshal,
}})
}

pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
{[]byte(`"hello"`), new(string), isExpectedEq(`{"custom":"value"}`)},
})
}
28 changes: 15 additions & 13 deletions pgtype/jsonb.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,31 @@ package pgtype

import (
"database/sql/driver"
"encoding/json"
"fmt"
)

type JSONBCodec struct{}
type JSONBCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}

func (JSONBCodec) FormatSupported(format int16) bool {
func (*JSONBCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}

func (JSONBCodec) PreferredFormat() int16 {
func (*JSONBCodec) PreferredFormat() int16 {
return TextFormatCode
}

func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format {
case BinaryFormatCode:
plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value)
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value)
if plan != nil {
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
}
case TextFormatCode:
return JSONCodec{}.PlanEncode(m, oid, format, value)
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value)
}

return nil
Expand All @@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne
return plan.textPlan.Encode(value, buf)
}

func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target)
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target)
if plan != nil {
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
}
case TextFormatCode:
return JSONCodec{}.PlanScan(m, oid, format, target)
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target)
}

return nil
Expand All @@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error {
return plan.textPlan.Scan(src[1:], dst)
}

func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
Expand All @@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
}
}

func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
Expand All @@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a
}

var dst any
err := json.Unmarshal(src, &dst)
err := c.Unmarshal(src, &dst)
return dst, err
}
4 changes: 2 additions & 2 deletions pgtype/pgtype_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
Expand Down

0 comments on commit a5f47de

Please sign in to comment.