Skip to content

Commit

Permalink
Merge pull request #3 from nyaruka/json
Browse files Browse the repository at this point in the history
add json nullable type
  • Loading branch information
nicpottier authored Jul 24, 2019
2 parents f32bd13 + 92dd275 commit ae74eab
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
61 changes: 61 additions & 0 deletions null.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,64 @@ func (m *Map) UnmarshalJSON(data []byte) error {
}
return json.Unmarshal(data, &m.m)
}

// JSON is a json.RawMessage that will marshall as null when empty or nil.
// null and {} values when unmashalled or scanned from a DB will result in a nil value
type JSON json.RawMessage

// Scan implements the Scanner interface for decoding from a database
func (j *JSON) Scan(src interface{}) error {
if src == nil {
*j = nil
return nil
}

var source []byte
switch src.(type) {
case string:
source = []byte(src.(string))
case []byte:
source = src.([]byte)
default:
return fmt.Errorf("incompatible type for JSON type")
}

if !json.Valid(source) {
return fmt.Errorf("invalid json: %s", source)
}
*j = source
return nil
}

// Value implements the driver Valuer interface
func (j JSON) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
}
return []byte(j), nil
}

// MarshalJSON encodes our JSON to JSON or null
func (j JSON) MarshalJSON() ([]byte, error) {
if len(j) == 0 {
return json.Marshal(nil)
}
return []byte(j), nil
}

// UnmarshalJSON sets our JSON from the passed in JSON
func (j *JSON) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
*j = nil
return nil
}

var jj json.RawMessage
err := json.Unmarshal(data, &jj)
if err != nil {
return err
}

*j = JSON(jj)
return nil
}
76 changes: 76 additions & 0 deletions null_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"database/sql/driver"
"encoding/json"
"strings"
"testing"

_ "github.com/lib/pq"
Expand Down Expand Up @@ -368,3 +369,78 @@ func TestMap(t *testing.T) {
assert.Equal(t, m2.GetString(tc.Key, ""), tc.KeyValue)
}
}

func TestJSON(t *testing.T) {
db, err := sql.Open("postgres", "postgres://localhost/null_test?sslmode=disable")
assert.NoError(t, err)

_, err = db.Exec(`DROP TABLE IF EXISTS json_test; CREATE TABLE json_test(value jsonb null);`)
assert.NoError(t, err)

sp := func(s string) *string {
return &s
}

tcs := []struct {
Value JSON
JSON json.RawMessage
DB *string
}{
{JSON(`{"foo":"bar"}`), json.RawMessage(`{"foo":"bar"}`), sp(`{"foo":"bar"}`)},
{JSON(nil), json.RawMessage(`null`), nil},
{JSON([]byte{}), json.RawMessage(`null`), nil},
}

for i, tc := range tcs {
// first test marshalling and unmarshalling to JSON
b, err := json.Marshal(tc.Value)
assert.NoError(t, err)
assert.Equal(t, string(tc.JSON), string(b), "%d: marshalled json not equal", i)

j := JSON("blah")
err = json.Unmarshal(tc.JSON, &j)
assert.NoError(t, err)
assert.Equal(t, string(tc.Value), string(j), "%d: unmarshalled json not equal", i)

// ok, now test writing and reading from DB
_, err = db.Exec(`DELETE FROM json_test;`)
assert.NoError(t, err)

_, err = db.Exec(`INSERT INTO json_test(value) VALUES($1)`, tc.DB)
assert.NoError(t, err)

rows, err := db.Query(`SELECT value FROM json_test;`)
assert.NoError(t, err)

assert.True(t, rows.Next())
j = JSON("blah")
err = rows.Scan(&j)
assert.NoError(t, err)

if tc.Value == nil {
assert.Nil(t, j, "%d: read db value should be null", i)
} else {
assert.Equal(t, string(tc.Value), strings.Replace(string(j), " ", "", -1), "%d: read db value should be equal", i)
}

_, err = db.Exec(`DELETE FROM json_test;`)
assert.NoError(t, err)

_, err = db.Exec(`INSERT INTO json_test(value) VALUES($1)`, tc.Value)
assert.NoError(t, err)

rows, err = db.Query(`SELECT value FROM json_test;`)
assert.NoError(t, err)

assert.True(t, rows.Next())
var s *string
err = rows.Scan(&s)
assert.NoError(t, err)

if tc.DB == nil {
assert.Nil(t, s, "%d: written db value should be null", i)
} else {
assert.Equal(t, *tc.DB, strings.Replace(*s, " ", "", -1), "%d: written db value should be equal", i)
}
}
}

0 comments on commit ae74eab

Please sign in to comment.