diff --git a/dbutil/scan.go b/dbutil/scan.go index 3683a97..312a616 100644 --- a/dbutil/scan.go +++ b/dbutil/scan.go @@ -10,16 +10,20 @@ import ( var validate = validator.New() +// Scanable is an interface to allow scanning of sql.Row or sql.Rows +type Scannable interface { + Scan(dest ...any) error +} + // ScanJSON scans a row which is JSON into a destination struct -func ScanJSON(rows *sql.Rows, destination any) error { +func ScanJSON(src Scannable, dest any) error { var raw json.RawMessage - err := rows.Scan(&raw) - if err != nil { + + if err := src.Scan(&raw); err != nil { return fmt.Errorf("error scanning row JSON: %w", err) } - err = json.Unmarshal(raw, destination) - if err != nil { + if err := json.Unmarshal(raw, dest); err != nil { return fmt.Errorf("error unmarshalling row JSON: %w", err) } @@ -27,12 +31,12 @@ func ScanJSON(rows *sql.Rows, destination any) error { } // ScanAndValidateJSON scans a row which is JSON into a destination struct and validates it -func ScanAndValidateJSON(rows *sql.Rows, destination any) error { - if err := ScanJSON(rows, destination); err != nil { +func ScanAndValidateJSON(src Scannable, dest any) error { + if err := ScanJSON(src, dest); err != nil { return err } - err := validate.Struct(destination) + err := validate.Struct(dest) if err != nil { return fmt.Errorf("error validating unmarsalled JSON: %w", err) } diff --git a/dbutil/scan_test.go b/dbutil/scan_test.go index 7907a31..de4a96d 100644 --- a/dbutil/scan_test.go +++ b/dbutil/scan_test.go @@ -96,6 +96,12 @@ func TestScanJSON(t *testing.T) { rows.Close() + // can also scan as a single row + row := db.QueryRowContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 3) r`) + err = dbutil.ScanJSON(row, f) + assert.NoError(t, err) + assert.Equal(t, "a5850c89-dd29-46f6-9de1-d068b3c2db94", f.UUID) + // can all scan all rows with ScanAllJSON rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f) r`)