diff --git a/README.md b/README.md index e4506fb..a31411c 100644 --- a/README.md +++ b/README.md @@ -28,16 +28,16 @@ Usage of subsetter: -dst string Destination database DSN -exclude value - Query to ignore tables, can be used multiple times; 'users: id = 123' for a specific user, 'users: 1=1' for all users + Query to ignore tables 'users: all', can be used multiple times -f float Fraction of rows to copy (default 0.05) -include value - Query to copy required tables, can be used multiple times; 'users: id = 123' for a specific user, 'users: 1=1' for all users + Query to copy required rows 'users: id = 1', can be used multiple times -src string Source database DSN -v Release information -verbose - Show more information during sync (default true) + Show more information during sync ``` @@ -60,7 +60,7 @@ pg_subsetter \ -f 0.5 -include "user: id=1" -include "group: id=1" - -exclude "domains: domain_name ilike '%.si'" + -exclude "domains: all" ``` diff --git a/cli/main.go b/cli/main.go index c8d1cea..5b23c69 100644 --- a/cli/main.go +++ b/cli/main.go @@ -28,8 +28,8 @@ func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack - flag.Var(&extraInclude, "include", "Query to copy required tables 'users: id = 1', can be used multiple times") - flag.Var(&extraExclude, "exclude", "Query to ignore tables 'users: id = 1', can be used multiple times") + flag.Var(&extraInclude, "include", "Query to copy required rows 'users: id = 1', can be used multiple times") + flag.Var(&extraExclude, "exclude", "Query to ignore tables 'users: all', can be used multiple times") flag.Parse() if *ver { diff --git a/subsetter/graph.go b/subsetter/graph.go index f13a515..ceba4bf 100644 --- a/subsetter/graph.go +++ b/subsetter/graph.go @@ -6,15 +6,21 @@ import ( "github.com/stevenle/topsort" ) -func TableGraph(primary string, relations []Relation) (l []string, e error) { +func TableGraph(primary string, relations []Relation) (l []string, err error) { graph := topsort.NewGraph() // Create a new graph for _, r := range relations { if !r.IsSelfRelated() { - graph.AddEdge(r.PrimaryTable, r.ForeignTable) + err = graph.AddEdge(r.PrimaryTable, r.ForeignTable) + if err != nil { + return + } } } - l, e = graph.TopSort(primary) + l, err = graph.TopSort(primary) + if err != nil { + return + } slices.Reverse(l) return } diff --git a/subsetter/graph_test.go b/subsetter/graph_test.go index 7b414ff..b39f346 100644 --- a/subsetter/graph_test.go +++ b/subsetter/graph_test.go @@ -22,8 +22,8 @@ func TestTableGraph(t *testing.T) { got, _ := TableGraph("users", relations) - if want, _ := lo.Last(got); want != "users" { - t.Fatalf("TableGraph() = %v, want %v", got, "users") + if want, _ := lo.Nth(got, 0); want != "users" { + t.Fatalf("TableGraph() = %v, want %v", got, want) } } diff --git a/subsetter/query.go b/subsetter/query.go index 24709fb..6507620 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -36,6 +36,13 @@ func (t *Table) IsSelfRelated() bool { return false } +// IsSelfRelated returns true if a table is self related. +func TableByName(tables []Table, name string) Table { + return lo.Filter(tables, func(table Table, _ int) bool { + return table.Name == name + })[0] +} + // GetTablesWithRows returns a list of tables with the number of rows in each table. // Warning reltuples used to dermine size is an estimate of the number of rows in the table and can be zero for small tables. func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { diff --git a/subsetter/relations_test.go b/subsetter/relations_test.go index 4134edb..9eb01e6 100644 --- a/subsetter/relations_test.go +++ b/subsetter/relations_test.go @@ -18,7 +18,7 @@ func TestGetRelations(t *testing.T) { conn *pgxpool.Pool wantRelations []Relation }{ - {"With relation", "simple", conn, []Relation{{"relation", "simple_id", "simple", "id"}}}, + {"With relation", "relation", conn, []Relation{{"relation", "simple_id", "simple", "id"}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/subsetter/sync.go b/subsetter/sync.go index 4f509fb..7a31e93 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -12,6 +12,14 @@ import ( "github.com/samber/lo" ) +type SyncError struct { + Retry bool +} + +func (se *SyncError) Error() string { + return fmt.Sprintf("Sync error: retry=%t", se.Retry) +} + type Sync struct { source *pgxpool.Pool destination *pgxpool.Pool @@ -38,7 +46,7 @@ func (r *Rule) Query() string { } func (r *Rule) Copy(s *Sync) (err error) { - log.Debug().Str("query", r.Where).Msgf("Copying forced rows for table %s", r.Table) + log.Debug().Str("query", r.Where).Msgf("Transfering forced rows for table %s", r.Table) var data string if data, err = CopyQueryToString(r.Query(), s.source); err != nil { return errors.Wrapf(err, "Error copying forced rows for table %s", r.Table) @@ -100,11 +108,11 @@ func copyTableData(table Table, relatedQueries []string, withLimit bool, source var data string if data, err = CopyTableToString(table.Name, limit, subselectQeury, source); err != nil { - log.Error().Err(err).Msgf("Error getting table data for %s", table.Name) + //log.Error().Err(err).Str("table", table.Name).Msg("Error getting table data") return } if err = CopyStringToTable(table.Name, data, destination); err != nil { - log.Error().Err(err).Msgf("Error pushing table data for %s", table.Name) + //log.Error().Err(err).Str("table", table.Name).Msg("Error pushing table data") return } return @@ -131,17 +139,20 @@ retry: return err } else { if len(primaryKeys) == 0 { - log.Warn().Int("depth", *depth).Msgf("No keys found for %s", relation.ForeignTable) - missingTable := lo.Filter(tables, func(table Table, _ int) bool { - return table.Name == relation.ForeignTable - })[0] - RelationalCopy(depth, tables, missingTable, visitedTables, source, destination) + + missingTable := TableByName(tables, relation.ForeignTable) + if err = RelationalCopy(depth, tables, missingTable, visitedTables, source, destination); err != nil { + return errors.Wrapf(err, "Error copying table %s", missingTable.Name) + } + + // Retry short circuit *depth++ + log.Debug().Int("depth", *depth).Msgf("Retrying keys for %s", relation.ForeignTable) if *depth < 1 { goto retry } else { - log.Warn().Int("depth", *depth).Msgf("Max depth reached for %s", relation.ForeignTable) + log.Warn().Str("table", relation.ForeignTable).Str("primary", relation.PrimaryTable).Msgf("No keys found at this time") return errors.New("Max depth reached") } @@ -178,9 +189,8 @@ func RelationalCopy( if lo.Contains(*visitedTables, tableName) { continue } - relatedTable := lo.Filter(tables, func(table Table, _ int) bool { - return table.Name == tableName - })[0] + + relatedTable := TableByName(tables, tableName) *visitedTables = append(*visitedTables, relatedTable.Name) // Use realized query to get priamry keys that are already in the destination for all related tables @@ -188,17 +198,23 @@ func RelationalCopy( relatedQueries := []string{} for _, relation := range relatedTable.Relations { - relatedQueriesBuilder(depth, tables, relation, relatedTable, source, destination, visitedTables, &relatedQueries) + err := relatedQueriesBuilder(depth, tables, relation, relatedTable, source, destination, visitedTables, &relatedQueries) + if err != nil { + return err + } } if len(relatedQueries) > 0 { - log.Debug().Str("table", relatedTable.Name).Strs("relatedQueries", relatedQueries).Msg("Copying with RelationalCopy") + log.Debug().Str("table", relatedTable.Name).Strs("relatedQueries", relatedQueries).Msg("Transfering with RelationalCopy") } if err = copyTableData(relatedTable, relatedQueries, false, source, destination); err != nil { if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23503" { // foreign key violation - RelationalCopy(depth, tables, relatedTable, visitedTables, source, destination) + if err := RelationalCopy(depth, tables, relatedTable, visitedTables, source, destination); err != nil { + return errors.Wrapf(err, "Error copying table %s", relatedTable.Name) + } } + return errors.Wrapf(err, "Error copying table %s", relatedTable.Name) } } @@ -214,14 +230,17 @@ func (s *Sync) CopyTables(tables []Table) (err error) { for _, table := range lo.Filter(tables, func(table Table, _ int) bool { return len(table.Relations) == 0 }) { - log.Info().Str("table", table.Name).Msg("Copying") + log.Info().Str("table", table.Name).Msg("Transfering") if err = copyTableData(table, []string{}, true, s.source, s.destination); err != nil { return errors.Wrapf(err, "Error copying table %s", table.Name) } for _, include := range s.include { if include.Table == table.Name { - include.Copy(s) + err = include.Copy(s) + if err != nil { + return errors.Wrapf(err, "Error copying forced rows for table %s", table.Name) + } } } @@ -231,20 +250,36 @@ func (s *Sync) CopyTables(tables []Table) (err error) { // Prevent infinite loop, by setting max depth depth := 0 // Copy tables with relations + maybeRetry := []Table{} + for _, complexTable := range lo.Filter(tables, func(table Table, _ int) bool { return len(table.Relations) > 0 }) { - log.Info().Str("table", complexTable.Name).Msg("Copying") - RelationalCopy(&depth, tables, complexTable, &visitedTables, s.source, s.destination) + log.Info().Str("table", complexTable.Name).Msg("Transfering") + if err := RelationalCopy(&depth, tables, complexTable, &visitedTables, s.source, s.destination); err != nil { + log.Info().Str("table", complexTable.Name).Msgf("Transfering failed, retrying later") + maybeRetry = append(maybeRetry, complexTable) + } for _, include := range s.include { if include.Table == complexTable.Name { - log.Warn().Str("table", complexTable.Name).Msgf("Copying forced rows for relational table is not supported.") + log.Warn().Str("table", complexTable.Name).Msgf("Transfering forced rows for relational table is not supported.") } } } + // Retry tables with relations + visitedRetriedTables := []string{} + for _, retiredTable := range maybeRetry { + log.Info().Str("table", retiredTable.Name).Msg("Transfering") + if err := RelationalCopy(&depth, tables, retiredTable, &visitedRetriedTables, s.source, s.destination); err != nil { + log.Warn().Str("table", retiredTable.Name).Msgf("Transfering failed, try increasing fraction index") + } + } + // Remove excluded rows and print reports + fmt.Println() + fmt.Println("Report:") for _, table := range tables { // to ensure no data is in excluded tables for _, exclude := range s.exclude {