Skip to content

Commit

Permalink
Convert null.Map to be generic
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Sep 6, 2023
1 parent 0b9f816 commit 7396e49
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ you marshal the zero value to JSON, you will get `null`.
| `null.Int` | `int(0)`
| `null.Int64` | `int64(0)`
| `null.String` | `""`
| `null.Map` | `map[string]any{}`
| `null.Map[V]` | `map[string]V{}`
| `null.JSON` | `[]byte("null")`

If you want to define a custom integer type, you need to define the following methods:
Expand Down
28 changes: 14 additions & 14 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@ import (
)

// Map is a generic map which is written to the database as JSON.
type Map map[string]any
type Map[V any] map[string]V

// Scan implements the Scanner interface
func (m *Map) Scan(value any) error { return ScanMap(value, m) }
func (m *Map[V]) Scan(value any) error { return ScanMap(value, m) }

// Value implements the Valuer interface
func (m Map) Value() (driver.Value, error) { return MapValue(m) }
func (m Map[V]) Value() (driver.Value, error) { return MapValue(m) }

// UnmarshalJSON implements the Unmarshaller interface
func (m *Map) UnmarshalJSON(data []byte) error { return UnmarshalMap(data, m) }
func (m *Map[V]) UnmarshalJSON(data []byte) error { return UnmarshalMap(data, m) }

// MarshalJSON implements the Marshaller interface
func (m Map) MarshalJSON() ([]byte, error) { return MarshalMap(m) }
func (m Map[V]) MarshalJSON() ([]byte, error) { return MarshalMap(m) }

// ScanMap scans a nullable text or JSON into a map, using an empty map for NULL.
func ScanMap(value any, m *Map) error {
func ScanMap[V any](value any, m *Map[V]) error {
if value == nil {
*m = make(Map)
*m = make(Map[V])
return nil
}

Expand All @@ -40,7 +40,7 @@ func ScanMap(value any, m *Map) error {

// empty bytes is same as nil
if len(raw) == 0 {
*m = make(Map)
*m = make(Map[V])
return nil
}

Expand All @@ -52,29 +52,29 @@ func ScanMap(value any, m *Map) error {
}

// MapValue converts a map to NULL if it is empty.
func MapValue(m Map) (driver.Value, error) {
func MapValue[V any](m Map[V]) (driver.Value, error) {
if len(m) == 0 {
return nil, nil
}
return json.Marshal(m)
}

// MarshalMap marshals a map, returning null for an empty map.
func MarshalMap(m Map) ([]byte, error) {
func MarshalMap[V any](m Map[V]) ([]byte, error) {
if len(m) == 0 {
return json.Marshal(nil)
}
return json.Marshal(map[string]any(m))
return json.Marshal(map[string]V(m))
}

func UnmarshalMap(data []byte, m *Map) error {
err := json.Unmarshal(data, (*map[string]any)(m))
func UnmarshalMap[V any](data []byte, m *Map[V]) error {
err := json.Unmarshal(data, (*map[string]V)(m))
if err != nil {
return err
}

if *m == nil {
*m = make(Map) // initialize empty map
*m = make(Map[V]) // initialize empty map
}
return nil
}
14 changes: 7 additions & 7 deletions map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ func TestMap(t *testing.T) {

testMap := func() {
tcs := []struct {
value null.Map
value null.Map[string]
dbValue driver.Value
marshaled []byte
}{
{null.Map{"foo": "bar"}, []byte(`{"foo":"bar"}`), []byte(`{"foo":"bar"}`)},
{null.Map{}, nil, []byte(`null`)},
{null.Map(nil), nil, []byte(`null`)},
{null.Map[string]{"foo": "bar"}, []byte(`{"foo":"bar"}`), []byte(`{"foo":"bar"}`)},
{null.Map[string]{}, nil, []byte(`null`)},
{null.Map[string](nil), nil, []byte(`null`)},
}

for _, tc := range tcs {
Expand All @@ -38,15 +38,15 @@ func TestMap(t *testing.T) {
rows, err := db.Query(`SELECT value FROM test;`)
assert.NoError(t, err)

scanned := null.Map{}
scanned := null.Map[string]{}
assert.True(t, rows.Next())
err = rows.Scan(&scanned)
assert.NoError(t, err)

// we never return a nil map even if that's what we wrote
expected := tc.value
if expected == nil {
expected = null.Map{}
expected = null.Map[string]{}
}

assert.Equal(t, expected, scanned, "scanned value mismatch for %v", tc.value)
Expand All @@ -55,7 +55,7 @@ func TestMap(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, tc.marshaled, marshaled, "marshaled mismatch for %v", tc.value)

unmarshaled := null.Map{}
unmarshaled := null.Map[string]{}
err = json.Unmarshal(marshaled, &unmarshaled)
assert.NoError(t, err)
assert.Equal(t, expected, unmarshaled, "unmarshaled mismatch for %v", tc.value)
Expand Down

0 comments on commit 7396e49

Please sign in to comment.