Skip to content

Commit

Permalink
change to generic map interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nicpottier committed Apr 3, 2019
1 parent cca5300 commit f32bd13
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 24 deletions.
56 changes: 42 additions & 14 deletions null.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,28 +138,56 @@ func (s String) Value() (driver.Value, error) {
return string(s), nil
}

// StringMap is a one level deep dictionary that is represented as JSON text in the database.
// Map is a one level deep dictionary that is represented as JSON text in the database.
// Empty maps will be written as null to the database and to JSON.
type StringMap struct {
m map[string]string
type Map struct {
m map[string]interface{}
}

// NewStringMap creates a new StringMap
func NewStringMap(m map[string]string) StringMap {
return StringMap{m: m}
// NewMap creates a new Map
func NewMap(m map[string]interface{}) Map {
return Map{m: m}
}

// Map returns our underlying map
func (m *StringMap) Map() map[string]string {
func (m *Map) Map() map[string]interface{} {
if m.m == nil {
m.m = make(map[string]string)
m.m = make(map[string]interface{})
}
return m.m
}

// GetString returns the string value with the passed in key, or def if not found or of wrong type
func (m *Map) GetString(key string, def string) string {
if m.m == nil {
return def
}
val := m.m[key]
if val == nil {
return def
}
str, isStr := val.(string)
if !isStr {
return def
}
return str
}

// Get returns the value with the passed in key, or def if not found
func (m *Map) Get(key string, def interface{}) interface{} {
if m.m == nil {
return def
}
val := m.m[key]
if val == nil {
return def
}
return val
}

// Scan implements the Scanner interface for decoding from a database
func (m *StringMap) Scan(src interface{}) error {
m.m = make(map[string]string)
func (m *Map) Scan(src interface{}) error {
m.m = make(map[string]interface{})
if src == nil {
return nil
}
Expand Down Expand Up @@ -187,24 +215,24 @@ func (m *StringMap) Scan(src interface{}) error {
}

// Value implements the driver Valuer interface
func (m StringMap) Value() (driver.Value, error) {
func (m Map) Value() (driver.Value, error) {
if m.m == nil || len(m.m) == 0 {
return nil, nil
}
return json.Marshal(m.m)
}

// MarshalJSON encodes our map to JSON
func (m StringMap) MarshalJSON() ([]byte, error) {
func (m Map) MarshalJSON() ([]byte, error) {
if m.m == nil || len(m.m) == 0 {
return json.Marshal(nil)
}
return json.Marshal(m.m)
}

// UnmarshalJSON sets our map from the passed in JSON
func (m *StringMap) UnmarshalJSON(data []byte) error {
m.m = make(map[string]string)
func (m *Map) UnmarshalJSON(data []byte) error {
m.m = make(map[string]interface{})
if len(data) == 0 {
return nil
}
Expand Down
25 changes: 15 additions & 10 deletions null_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,16 @@ func TestMap(t *testing.T) {
}

tcs := []struct {
Value StringMap
JSON string
DB *string
Value Map
JSON string
DB *string
Key string
KeyValue string
}{
{NewStringMap(map[string]string{"foo": "bar"}), `{"foo":"bar"}`, sp(`{"foo": "bar"}`)},
{NewStringMap(map[string]string{}), "null", nil},
{NewStringMap(nil), "null", nil},
{NewStringMap(nil), "null", sp("")},
{NewMap(map[string]interface{}{"foo": "bar"}), `{"foo":"bar"}`, sp(`{"foo": "bar"}`), "foo", "bar"},
{NewMap(map[string]interface{}{}), "null", nil, "foo", ""},
{NewMap(nil), "null", nil, "foo", ""},
{NewMap(nil), "null", sp(""), "foo", ""},
}

for i, tc := range tcs {
Expand All @@ -328,23 +330,25 @@ func TestMap(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, tc.JSON, string(b), "%d: %s not equal to %s", i, tc.JSON, string(b))

m := StringMap{}
m := Map{}
err = json.Unmarshal(b, &m)
assert.NoError(t, err)
assert.Equal(t, tc.Value.Map(), m.Map(), "%d: %s not equal to %s", i, tc.Value, m)
assert.Equal(t, m.GetString(tc.Key, ""), tc.KeyValue)

_, err = db.Exec(`INSERT INTO map(value) VALUES($1)`, tc.Value)
assert.NoError(t, err)

rows, err := db.Query(`SELECT value FROM map;`)
assert.NoError(t, err)

m2 := StringMap{}
m2 := Map{}
assert.True(t, rows.Next())
err = rows.Scan(&m2)
assert.NoError(t, err)

assert.Equal(t, tc.Value.Map(), m2.Map())
assert.Equal(t, m2.GetString(tc.Key, ""), tc.KeyValue)

_, err = db.Exec(`DELETE FROM map;`)
assert.NoError(t, err)
Expand All @@ -355,11 +359,12 @@ func TestMap(t *testing.T) {
rows, err = db.Query(`SELECT value FROM map;`)
assert.NoError(t, err)

m2 = StringMap{}
m2 = Map{}
assert.True(t, rows.Next())
err = rows.Scan(&m2)
assert.NoError(t, err)

assert.Equal(t, tc.Value.Map(), m2.Map())
assert.Equal(t, m2.GetString(tc.Key, ""), tc.KeyValue)
}
}

0 comments on commit f32bd13

Please sign in to comment.