diff --git a/subsetter/query.go b/subsetter/query.go index 8054b52..ed9dcd7 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -57,7 +57,7 @@ func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, } func CopyStringToTable(table string, data string, conn *pgx.Conn) (err error) { - q := fmt.Sprintf(`copy %s from stdout`, table) + q := fmt.Sprintf(`copy %s from stdin`, table) var buff bytes.Buffer buff.WriteString(data) if _, err = conn.PgConn().CopyFrom(context.Background(), &buff, q); err != nil { diff --git a/subsetter/query_test.go b/subsetter/query_test.go index dfad7c6..1df1459 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -57,7 +57,7 @@ func TestGetTablesWithRows(t *testing.T) { } } -func TestCopyRowToString(t *testing.T) { +func TestCopyTableToString(t *testing.T) { conn := getTestConnection() populateTests(conn) defer conn.Close(context.Background()) @@ -77,11 +77,11 @@ func TestCopyRowToString(t *testing.T) { t.Run(tt.name, func(t *testing.T) { gotResult, err := CopyTableToString(tt.table, 10, tt.conn) if (err != nil) != tt.wantErr { - t.Errorf("CopyRowToString() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("CopyTableToString() error = %v, wantErr %v", err, tt.wantErr) return } - if strings.Contains(gotResult, "test") != tt.wantResult { - t.Errorf("CopyRowToString() = %v, want %v", gotResult, tt.wantResult) + if strings.Contains(gotResult, "test") == tt.wantResult { + t.Errorf("CopyTableToString() = %v, want %v", gotResult, tt.wantResult) } }) } @@ -92,7 +92,6 @@ func TestCopyStringToTable(t *testing.T) { populateTests(conn) defer conn.Close(context.Background()) defer clearPopulateTests(conn) - populateTestsWithData(conn, "simple", 10) tests := []struct { name string @@ -103,7 +102,7 @@ func TestCopyStringToTable(t *testing.T) { wantErr bool }{ {"With tables", "simple", "cccc5f58-44d3-4d7a-bf37-a97d4f081a63 test\n", conn, 1, false}, - {"With more tables", "simple", "edcd63fe-303e-4d51-83ea-3fd00740ba2c test4\na170b0f5-3aec-469c-9589-cf25888a72e2 test7", conn, 2, false}, + {"With more tables", "simple", "edcd63fe-303e-4d51-83ea-3fd00740ba2c test4\na170b0f5-3aec-469c-9589-cf25888a72e2 test7", conn, 3, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -112,7 +111,8 @@ func TestCopyStringToTable(t *testing.T) { t.Errorf("CopyStringToTable() error = %v, wantErr %v", err, tt.wantErr) return } - if tt.wantResult == insertedRows(tt.table, tt.conn) { + gotInserted := insertedRows(tt.table, tt.conn) + if tt.wantResult != gotInserted { t.Errorf("CopyStringToTable() = %v, want %v", tt.wantResult, tt.wantResult) } @@ -121,11 +121,8 @@ func TestCopyStringToTable(t *testing.T) { } func insertedRows(s string, conn *pgx.Conn) int { - tables, _ := GetTablesWithRows(conn) - for _, table := range tables { - if table.Name == s { - return table.Rows - } - } - return 0 + q := "SELECT count(*) FROM " + s + var count int + conn.QueryRow(context.Background(), q).Scan(&count) + return count }