diff --git a/.gitignore b/.gitignore index 6f0de07..01b41f1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .pgsql/data bin dist/ +*.sql +.nix-profile* \ No newline at end of file diff --git a/.goreleaser.yaml b/.goreleaser.yaml index e7b8aa6..9b55360 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,17 +1,22 @@ before: hooks: - go mod tidy +universal_binaries: + - replace: true + builds: - env: - CGO_ENABLED=0 goos: - linux - - windows - darwin main: ./cli - + goarch: + - amd64 + - arm64 archives: - format: tar.gz + strip_parent_binary_folder: true # this name template makes the OS and Arch compatible with the results of uname. name_template: >- {{ .ProjectName }}_ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b176ff8 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "spellright.language": [ + "en" + ], + "spellright.documentTypes": [ + "markdown", + "latex", + "plaintext", + "go" + ] +} \ No newline at end of file diff --git a/Makefile b/Makefile index 8771f2c..7356abb 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ is-postgres-running: .PHONY: pgweb pgweb:is-postgres-running - @pgweb --url "postgres://test_source@localhost:5432/test_source?sslmode=disable" + @pgweb --url "postgres://test_target@localhost:5432/test_target?sslmode=disable" build: rm -rf dist diff --git a/README.md b/README.md index fb17185..2501250 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,14 @@ [![lint](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml) [![build](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml) [![vuln](https://github.com/teamniteo/pg-subsetter/actions/workflows/vuln.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/vuln.yml) -`pg-subsetter` is a tool designed to 
synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA. This means that your target database has to have schema populated in some other way. +`pg-subsetter` is a tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA. + ### Database Fraction Synchronization `pg-subsetter` allows you to select and sync a specific subset of your database. Whether it's a fraction of a table or a particular dataset, you can have it replicated in another database without synchronizing the entire DB. ### Integrity Preservation with Foreign Keys -Foreign keys play a vital role in maintaining the relationships between tables. `pg-subsetter` ensures that all foreign keys are handled correctly during the synchronization process, maintaining the integrity and relationships of the data. +Foreign keys play a vital role in maintaining the relationships between tables. `pg-subsetter` ensures that all foreign keys (one-to-one, one-to-many, many-to-many) are handled correctly during the synchronization process, maintaining the integrity and relationships of the data. ### Efficient COPY Method Utilizing the native PostgreSQL COPY command, `pg-subsetter` performs data transfer with high efficiency. This method significantly speeds up the synchronization process, minimizing downtime and resource consumption. @@ -17,31 +18,50 @@ Utilizing the native PostgreSQL COPY command, `pg-subsetter` performs data trans ### Stateless Operation `pg-subsetter` is built to be stateless, meaning it does not maintain any internal state between runs. This ensures that each synchronization process is independent, enhancing reliability and making it easier to manage and scale. +### Sync required rows +`pg-subsetter` can be instructed to copy certain rows in specific tables, the command can be used multiple times to sync more data. 
## Usage ``` Usage of subsetter: -dst string - Destination database DSN + Destination database DSN + -exclude value + Query to ignore tables 'users: id = 1', can be used multiple times -f float - Fraction of rows to copy (default 0.05) - -force value - Query to copy required tables (users: id = 1) + Fraction of rows to copy (default 0.05) + -include value + Query to copy required tables 'users: id = 1', can be used multiple times -src string - Source database DSN + Source database DSN + -v Release information + -verbose + Show more information during sync (default true) ``` ### Example -Copy a fraction of the database and force certain rows to be also copied over. + +Prepare schema in target database: + +```bash +pg_dump --schema-only -n public -f schemadump.sql "postgres://test_source@localhost:5432/test_source?sslmode=disable" +psql -f schemadump.sql "postgres://test_target@localhost:5432/test_target?sslmode=disable" +``` + +Copy a fraction of the database and force certain rows to be also copied over: ``` pg-subsetter \ -src "postgres://test_source@localhost:5432/test_source?sslmode=disable" \ -dst "postgres://test_target@localhost:5432/test_target?sslmode=disable" \ - -f 0.05 + -f 0.5 + -include "user: id=1" + -include "group: id=1" + -exclude "domains: domain_name ilike '%.si'" + ``` # Installing diff --git a/cli/extra.go b/cli/extra.go new file mode 100644 index 0000000..29f1bd5 --- /dev/null +++ b/cli/extra.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "strings" + + "niteo.co/subsetter/subsetter" +) + +type arrayExtra []subsetter.Rule + +func (i *arrayExtra) String() string { + return fmt.Sprintf("%v", *i) +} + +func (i *arrayExtra) Set(value string) error { + q := strings.Split(strings.TrimSpace(value), ":") + + *i = append(*i, subsetter.Rule{ + Table: strings.TrimSpace(q[0]), + Where: strings.TrimSpace(q[1]), + }) + return nil +} diff --git a/cli/extra_test.go b/cli/extra_test.go new file mode 100644 index 0000000..96812ca --- /dev/null +++ 
b/cli/extra_test.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "testing" +) + +func Test_arrayExtra_Set(t *testing.T) { + + tests := []struct { + name string + value string + rules arrayExtra + wantErr bool + }{ + {"With tables", "simple: id < 10", arrayExtra{{Table: "simple", Where: "id < 10"}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := arrayExtra{} + if err := r.Set(tt.value); (err != nil) != tt.wantErr { + t.Errorf("arrayExtra.Set() error = %v, wantErr %v", err, tt.wantErr) + } + if fmt.Sprintf("%v", r) != fmt.Sprintf("%v", tt.rules) { + t.Errorf("arrayExtra.Set() = %v, want %v", r, tt.rules) + } + }) + } +} diff --git a/cli/force.go b/cli/force.go deleted file mode 100644 index 56a4635..0000000 --- a/cli/force.go +++ /dev/null @@ -1,24 +0,0 @@ -package main - -import ( - "fmt" - "strings" - - "niteo.co/subsetter/subsetter" -) - -type arrayForce []subsetter.Force - -func (i *arrayForce) String() string { - return fmt.Sprintf("%v", *i) -} - -func (i *arrayForce) Set(value string) error { - q := strings.SplitAfter(strings.TrimSpace(value), ":") - - *i = append(*i, subsetter.Force{ - Table: q[0], - Where: q[1], - }) - return nil -} diff --git a/cli/main.go b/cli/main.go index cdd6c27..7adebd8 100644 --- a/cli/main.go +++ b/cli/main.go @@ -10,19 +10,33 @@ import ( "niteo.co/subsetter/subsetter" ) +var ( + version = "dev" + commit = "none" + date = "unknown" +) + var src = flag.String("src", "", "Source database DSN") var dst = flag.String("dst", "", "Destination database DSN") var fraction = flag.Float64("f", 0.05, "Fraction of rows to copy") var verbose = flag.Bool("verbose", true, "Show more information during sync") -var forceSync arrayForce +var ver = flag.Bool("v", false, "Release information") +var extraInclude arrayExtra +var extraExclude arrayExtra func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack - flag.Var(&forceSync, 
"force", "Query to copy required tables (users: id = 1)") + flag.Var(&extraInclude, "include", "Query to copy required tables 'users: id = 1', can be used multiple times") + flag.Var(&extraExclude, "exclude", "Query to ignore tables 'users: id = 1', can be used multiple times") flag.Parse() + if *ver { + log.Info().Str("version", version).Str("commit", commit).Str("date", date).Msg("Version") + os.Exit(0) + } + if *src == "" || *dst == "" { log.Fatal().Msg("Source and destination DSNs are required") } @@ -31,20 +45,23 @@ func main() { log.Fatal().Msg("Fraction must be between 0 and 1") } - if len(forceSync) > 0 { - log.Info().Str("forced", forceSync.String()).Msg("Forcing sync for tables") + if len(extraInclude) > 0 { + log.Info().Str("include", extraInclude.String()).Msg("Forcibly including") + } + if len(extraExclude) > 0 { + log.Info().Str("exclude", extraExclude.String()).Msg("Forcibly ignoring") } - s, err := subsetter.NewSync(*src, *dst, *fraction, forceSync, *verbose) + s, err := subsetter.NewSync(*src, *dst, *fraction, extraInclude, extraExclude, *verbose) if err != nil { - log.Fatal().Stack().Err(err).Msg("Failed to configure sync") + log.Fatal().Err(err).Msg("Failed to configure sync") } defer s.Close() err = s.Sync() if err != nil { - log.Fatal().Stack().Err(err).Msg("Failed to sync") + log.Fatal().Err(err).Msg("Failed to sync") } } diff --git a/flake.lock b/flake.lock index a95c7d6..1231073 100644 --- a/flake.lock +++ b/flake.lock @@ -20,16 +20,16 @@ }, "nixpkgs": { "locked": { - "lastModified": 1691674635, - "narHash": "sha256-dTWUqEf7lb7k67cFZXIG7Xe9ES6XEvKzzUT24A4hGa4=", + "lastModified": 1691709280, + "narHash": "sha256-zmfH2OlZEXwv572d0g8f6M5Ac6RiO8TxymOpY3uuqrM=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "1eef5102c9fcb3281fbf94de90e7d59c92664373", + "rev": "cf73a86c35a84de0e2f3ba494327cf6fb51c0dfd", "type": "github" }, "original": { "owner": "NixOS", - "ref": "release-23.05", + "ref": "nixpkgs-unstable", "repo": "nixpkgs", "type": "github" } 
diff --git a/flake.nix b/flake.nix index 6367fbc..76b7e50 100644 --- a/flake.nix +++ b/flake.nix @@ -3,7 +3,7 @@ allowed-users = [ "@wheel" "@staff" ]; # allow compiling on every device/machine }; inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/release-23.05"; + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; flake-parts.url = "github:hercules-ci/flake-parts"; }; outputs = inputs@{ self, nixpkgs, flake-parts, ... }: @@ -35,7 +35,7 @@ go goreleaser golangci-lint - postgresql + postgresql_15 process-compose nixpkgs-fmt pgweb diff --git a/go.mod b/go.mod index 105b568..e75a2ee 100644 --- a/go.mod +++ b/go.mod @@ -2,18 +2,21 @@ module niteo.co/subsetter go 1.20 -require github.com/rs/zerolog v1.30.0 +require ( + github.com/pkg/errors v0.9.1 + github.com/rs/zerolog v1.30.0 +) require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect - github.com/pkg/errors v0.9.1 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/sync v0.1.0 // indirect ) require ( + github.com/davecgh/go-spew v1.1.1 github.com/jackc/pgx/v5 v5.4.3 github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect diff --git a/go.sum b/go.sum index 95de6cf..131fcde 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,7 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod 
h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= diff --git a/subsetter/query.go b/subsetter/query.go index 546d45d..44148e1 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -4,8 +4,10 @@ import ( "bytes" "context" "fmt" + "strings" "github.com/jackc/pgx/v5/pgxpool" + "github.com/samber/lo" ) type Table struct { @@ -14,6 +16,16 @@ type Table struct { Relations []Relation } +func (t *Table) RelationNames() (names string) { + rel := lo.Map(t.Relations, func(r Relation, _ int) string { + return r.PrimaryTable + ">" + r.PrimaryColumn + }) + if len(rel) > 0 { + return strings.Join(rel, ", ") + } + return "none" +} + func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { q := `SELECT relname, @@ -46,6 +58,27 @@ func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { return } +func GetKeys(q string, conn *pgxpool.Pool) (ids []string, err error) { + rows, err := conn.Query(context.Background(), q) + for rows.Next() { + var id string + + if err := rows.Scan(&id); err == nil { + ids = append(ids, id) + } + + } + rows.Close() + + return +} + +func DeleteRows(table string, where string, conn *pgxpool.Pool) (err error) { + q := fmt.Sprintf(`DELETE FROM %s WHERE %s`, table, where) + _, err = conn.Exec(context.Background(), q) + return +} + func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err error) { q := fmt.Sprintf(`copy (%s) to stdout`, query) var buff bytes.Buffer @@ -53,10 +86,12 @@ func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err err if err != nil { return } + defer c.Release() if _, err = c.Conn().PgConn().CopyTo(context.Background(), &buff, q); err != nil { return } result = buff.String() + return } @@ -73,9 +108,20 @@ func CopyStringToTable(table string, data string, conn *pgxpool.Pool) (err error if err != nil { return } + defer c.Release() + if _, err = c.Conn().PgConn().CopyFrom(context.Background(), &buff, q); err != nil { return } return } + +func CountRows(s string, 
conn *pgxpool.Pool) (count int, err error) { + q := "SELECT count(*) FROM " + s + err = conn.QueryRow(context.Background(), q).Scan(&count) + if err != nil { + return + } + return +} diff --git a/subsetter/query_test.go b/subsetter/query_test.go index ff067f7..437137a 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -1,7 +1,6 @@ package subsetter import ( - "context" "strings" "testing" @@ -89,7 +88,7 @@ func TestCopyStringToTable(t *testing.T) { t.Errorf("CopyStringToTable() error = %v, wantErr %v", err, tt.wantErr) return } - gotInserted := insertedRows(tt.table, tt.conn) + gotInserted, _ := CountRows(tt.table, tt.conn) if tt.wantResult != gotInserted { t.Errorf("CopyStringToTable() = %v, want %v", tt.wantResult, tt.wantResult) } @@ -98,12 +97,29 @@ func TestCopyStringToTable(t *testing.T) { } } -func insertedRows(s string, conn *pgxpool.Pool) int { - q := "SELECT count(*) FROM " + s - var count int - err := conn.QueryRow(context.Background(), q).Scan(&count) - if err != nil { - panic(err) +func TestDeleteRows(t *testing.T) { + + conn := getTestConnection() + initSchema(conn) + defer clearSchema(conn) + tests := []struct { + name string + conn *pgxpool.Pool + table string + where string + count int + wantErr bool + }{ + {"With tables", conn, "simple", "1 = 1", 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DeleteRows(tt.table, tt.where, tt.conn); (err != nil) != tt.wantErr { + t.Errorf("DeleteRows() error = %v, wantErr %v", err, tt.wantErr) + } + if gotCount, _ := CountRows(tt.table, tt.conn); gotCount != tt.count { + t.Errorf("DeleteRows() = %v, want %v", gotCount, tt.count) + } + }) } - return count } diff --git a/subsetter/relations.go b/subsetter/relations.go index 3d027d7..0f89813 100644 --- a/subsetter/relations.go +++ b/subsetter/relations.go @@ -3,8 +3,12 @@ package subsetter import ( "context" "fmt" + "regexp" + "strconv" + "strings" "github.com/jackc/pgx/v5/pgxpool" + 
"github.com/samber/lo" ) type Relation struct { @@ -14,44 +18,75 @@ type Relation struct { ForeignColumn string } -func (r *Relation) Query() string { - return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (SELECT %s FROM %s)`, r.ForeignTable, r.ForeignColumn, r.PrimaryColumn, r.PrimaryTable) +func (r *Relation) Query(subset []string) string { + + subset = lo.Map(subset, func(s string, _ int) string { + + // if string is a number, don't quote it + if _, err := strconv.Atoi(s); err == nil { + return s + } + return fmt.Sprintf(`'%s'`, s) + }) + + return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (%s)`, r.PrimaryTable, r.PrimaryColumn, strings.Join(subset, ",")) +} + +func (r *Relation) PrimaryQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s`, r.ForeignColumn, r.ForeignTable) +} + +type RelationInfo struct { + TableName string + ForeignTable string + SQL string +} + +func (r *RelationInfo) toRelation() Relation { + var rel Relation + re := regexp.MustCompile(`FOREIGN KEY \((\w+)\) REFERENCES (\w+)\((\w+)\).*`) + matches := re.FindStringSubmatch(r.SQL) + if len(matches) == 4 { + rel.PrimaryColumn = matches[1] + rel.ForeignTable = matches[2] + rel.ForeignColumn = matches[3] + } + rel.PrimaryTable = r.TableName + return rel } // GetRelations returns a list of tables that have a foreign key for particular table. 
func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation, err error) { q := `SELECT - kcu.table_name AS foreign_table_name, - kcu.column_name AS foreign_column_name, - ccu.table_name, - ccu.column_name + conrelid::regclass AS table_name, + confrelid::regclass AS referenced_table, + pg_get_constraintdef(c.oid, TRUE) AS sql FROM - information_schema.table_constraints tc - JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - JOIN information_schema.referential_constraints rc ON tc.constraint_name = rc.constraint_name - AND tc.table_schema = rc.constraint_schema - JOIN information_schema.constraint_column_usage ccu ON rc.unique_constraint_name = ccu.constraint_name + pg_constraint c + JOIN pg_namespace n ON n.oid = c.connamespace WHERE - tc.constraint_type = 'FOREIGN KEY' - AND ccu.table_name = $1 - AND tc.table_schema = 'public';` + c.contype = 'f' + AND n.nspname = 'public';` - rows, err := conn.Query(context.Background(), q, table) + rows, err := conn.Query(context.Background(), q) if err != nil { return } defer rows.Close() for rows.Next() { - var rel Relation - err = rows.Scan(&rel.ForeignTable, &rel.ForeignColumn, &rel.PrimaryTable, &rel.PrimaryColumn) + var rel RelationInfo + + err = rows.Scan(&rel.TableName, &rel.ForeignTable, &rel.SQL) if err != nil { return } - relations = append(relations, rel) + relations = append(relations, rel.toRelation()) } + relations = lo.Filter(relations, func(rel Relation, _ int) bool { + return rel.ForeignTable == table + }) return } diff --git a/subsetter/relations_test.go b/subsetter/relations_test.go index b1471c4..edb3534 100644 --- a/subsetter/relations_test.go +++ b/subsetter/relations_test.go @@ -4,6 +4,7 @@ import ( "reflect" "testing" + "github.com/davecgh/go-spew/spew" "github.com/jackc/pgx/v5/pgxpool" ) @@ -17,7 +18,7 @@ func TestGetRelations(t *testing.T) { conn *pgxpool.Pool wantRelations []Relation }{ - {"With relation", 
"simple", conn, []Relation{{"simple", "id", "relation", "simple_id"}}}, + {"With relation", "simple", conn, []Relation{{"relation", "simple_id", "simple", "id"}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -34,13 +35,45 @@ func TestRelation_Query(t *testing.T) { r Relation want string }{ - {"Simple", Relation{"simple", "id", "relation", "simple_id"}, "SELECT * FROM relation WHERE simple_id IN (SELECT id FROM simple)"}, + {"Simple", Relation{"simple", "id", "relation", "simple_id"}, "SELECT * FROM simple WHERE id IN (1)"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := tt.r.Query(); got != tt.want { + if got := tt.r.Query([]string{"1"}); got != tt.want { t.Errorf("Relation.Query() = %v, want %v", got, tt.want) } }) } } + +func TestRelationInfo_toRelation(t *testing.T) { + + tests := []struct { + name string + fields RelationInfo + want Relation + }{ + { + "Simple", + RelationInfo{"relation", "simple", "FOREIGN KEY (simple_id) REFERENCES simple(id)"}, + Relation{"relation", "simple_id", "simple", "id"}, + }, + { + "Simple with cascade", + RelationInfo{"relation", "simple", "FOREIGN KEY (simple_id) REFERENCES simple(id) ON DELETE CASCADE"}, + Relation{"relation", "simple_id", "simple", "id"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &RelationInfo{ + TableName: tt.fields.TableName, + ForeignTable: tt.fields.ForeignTable, + SQL: tt.fields.SQL, + } + if got := r.toRelation(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("RelationInfo.toRelation() = %v, want %v", spew.Sdump(got), spew.Sdump(tt.want)) + } + }) + } +} diff --git a/subsetter/sync.go b/subsetter/sync.go index 1eb95d8..570c4f1 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -2,8 +2,11 @@ package subsetter import ( "context" + "sort" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" + "github.com/pkg/errors" "github.com/rs/zerolog/log" "github.com/samber/lo" ) @@ -13,15 +16,16 @@ type 
Sync struct { destination *pgxpool.Pool fraction float64 verbose bool - force []Force + include []Rule + exclude []Rule } -type Force struct { +type Rule struct { Table string Where string } -func NewSync(source string, target string, fraction float64, force []Force, verbose bool) (*Sync, error) { +func NewSync(source string, target string, fraction float64, include []Rule, exclude []Rule, verbose bool) (*Sync, error) { src, err := pgxpool.New(context.Background(), source) if err != nil { return nil, err @@ -32,7 +36,7 @@ func NewSync(source string, target string, fraction float64, force []Force, verb return nil, err } - dst, err := pgxpool.New(context.Background(), source) + dst, err := pgxpool.New(context.Background(), target) if err != nil { return nil, err } @@ -46,7 +50,8 @@ func NewSync(source string, target string, fraction float64, force []Force, verb destination: dst, fraction: fraction, verbose: verbose, - force: force, + include: include, + exclude: exclude, }, nil } @@ -60,9 +65,11 @@ func (s *Sync) Close() { func copyTableData(table Table, source *pgxpool.Pool, destination *pgxpool.Pool) (err error) { var data string if data, err = CopyTableToString(table.Name, table.Rows, source); err != nil { + log.Error().Err(err).Msgf("Error copying table %s", table.Name) return } if err = CopyStringToTable(table.Name, data, destination); err != nil { + log.Error().Err(err).Msgf("Error pasting table %s", table.Name) return } return @@ -74,51 +81,86 @@ func ViableSubset(tables []Table) (subset []Table) { // Filter out tables with no rows subset = lo.Filter(tables, func(table Table, _ int) bool { return table.Rows > 0 }) - // Get all relations - relationsR := lo.FlatMap(subset, func(table Table, _ int) []Relation { return table.Relations }) - relations := lo.Map(relationsR, func(relation Relation, _ int) string { return relation.ForeignTable }) + // Ignore tables with relations to tables + // they are populated by the primary table + tablesWithRelations := 
lo.Filter(tables, func(table Table, _ int) bool { + return len(table.Relations) > 0 + }) + + var excludedTables []string + for _, table := range tablesWithRelations { + for _, relation := range table.Relations { + if table.Name != relation.PrimaryTable { + excludedTables = append(excludedTables, relation.PrimaryTable) + } + } + } - // Filter out tables that are relations of other tables - // they will be copied later subset = lo.Filter(subset, func(table Table, _ int) bool { - return !lo.Contains(relations, table.Name) + return !lo.Contains(excludedTables, table.Name) }) + sort.Slice(subset, func(i, j int) bool { + return len(subset[i].Relations) < len(subset[j].Relations) + }) return } // CopyTables copies the data from a list of tables in the source database to the destination database func (s *Sync) CopyTables(tables []Table) (err error) { + for _, table := range tables { - log.Info().Msgf("Copying table %s", table.Name) + if err = copyTableData(table, s.source, s.destination); err != nil { - return + return errors.Wrapf(err, "Error copying table %s", table.Name) } - for _, force := range s.force { - if force.Table == table.Name { - log.Info().Msgf("Selecting forced rows for table %s", table.Name) + for _, include := range s.include { + if include.Table == table.Name { + log.Info().Str("query", include.Where).Msgf("Selecting forced rows for table %s", table.Name) var data string - if data, err = CopyQueryToString(force.Where, s.source); err != nil { - return + if data, err = CopyQueryToString(include.Where, s.source); err != nil { + return errors.Wrapf(err, "Error copying forced rows for table %s", table.Name) } if err = CopyStringToTable(table.Name, data, s.destination); err != nil { - return + return errors.Wrapf(err, "Error inserting forced rows for table %s", table.Name) } } } + for _, exclude := range s.exclude { + if exclude.Table == table.Name { + log.Info().Str("query", exclude.Where).Msgf("Deleting excluded rows for table %s", table.Name) + if err = 
DeleteRows(exclude.Table, exclude.Where, s.destination); err != nil { + return errors.Wrapf(err, "Error deleting excluded rows for table %s", table.Name) + } + } + } + + count, _ := CountRows(table.Name, s.destination) + log.Info().Int("count", count).Msgf("Copied table %s", table.Name) + for _, relation := range table.Relations { // Backtrace the inserted ids from main table to related table - - log.Info().Msgf("Copying relation %s for table %s", relation, table.Name) + var pKeys []string + if pKeys, err = GetKeys(relation.PrimaryQuery(), s.destination); err != nil { + return errors.Wrapf(err, "Error getting primary keys for %s", relation.PrimaryTable) + } var data string - if data, err = CopyQueryToString(relation.Query(), s.source); err != nil { - return + if data, err = CopyQueryToString(relation.Query(pKeys), s.source); err != nil { + return errors.Wrapf(err, "Error copying related table %s", relation.PrimaryTable) } - if err = CopyStringToTable(table.Name, data, s.destination); err != nil { - return + if err = CopyStringToTable(relation.PrimaryTable, data, s.destination); err != nil { + if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23503" { + log.Warn().Msgf("Skipping %s because of foreign key violation", relation.PrimaryTable) + err = nil + } else { + return errors.Wrapf(err, "Error inserting related table %s", relation.PrimaryTable) + } + } + count, _ := CountRows(relation.PrimaryTable, s.destination) + log.Info().Int("count", count).Msgf("Copied %s for %s", relation.PrimaryTable, table.Name) } } return @@ -140,8 +182,11 @@ func (s *Sync) Sync() (err error) { if s.verbose { for _, t := range subset { - log.Info().Msgf("Copying table %s with %d rows", t.Name, t.Rows) - log.Info().Msgf("Relations: %v", t.Relations) + log.Info(). + Str("table", t.Name). + Int("rows", t.Rows). + Str("related", t.RelationNames()). 
+ Msg("Prepared for sync") } } diff --git a/subsetter/sync_test.go b/subsetter/sync_test.go index 67e85a3..b3f9c8f 100644 --- a/subsetter/sync_test.go +++ b/subsetter/sync_test.go @@ -1,17 +1,10 @@ package subsetter import ( - "os" "reflect" "testing" ) -func skipCI(t *testing.T) { - if os.Getenv("CI") != "" { - t.Skip("Skipping testing in CI environment") - } -} - func TestViableSubset(t *testing.T) { tests := []struct { name string @@ -27,11 +20,6 @@ func TestViableSubset(t *testing.T) { "No rows", []Table{{"simple", 0, []Relation{}}}, []Table{}}, - { - "Complex, related tables must be excluded", - []Table{{"simple", 10, []Relation{}}, {"complex", 10, []Relation{{"simple", "id", "complex", "simple_id"}}}}, - []Table{{"simple", 10, []Relation{}}}, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -43,7 +31,6 @@ func TestViableSubset(t *testing.T) { } func TestSync_CopyTables(t *testing.T) { - skipCI(t) src := getTestConnection() dst := getTestConnectionDst() initSchema(src) @@ -51,6 +38,8 @@ func TestSync_CopyTables(t *testing.T) { defer clearSchema(src) defer clearSchema(dst) + populateTestsWithData(src, "simple", 1000) + s := &Sync{ source: src, destination: dst,