diff --git a/.vscode/settings.json b/.vscode/settings.json index b176ff8..eb56df2 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,7 +5,6 @@ "spellright.documentTypes": [ "markdown", "latex", - "plaintext", - "go" + "plaintext" ] } \ No newline at end of file diff --git a/subsetter/info.go b/subsetter/info.go index 6ee385f..2517398 100644 --- a/subsetter/info.go +++ b/subsetter/info.go @@ -1,6 +1,10 @@ package subsetter -import "math" +import ( + "fmt" + "math" + "strconv" +) // GetTargetSet returns a subset of tables with the number of rows scaled by the fraction. func GetTargetSet(fraction float64, tables []Table) []Table { @@ -16,3 +20,11 @@ func GetTargetSet(fraction float64, tables []Table) []Table { return subset } + +func QuoteString(s string) string { + // if string is a number, don't quote it + if _, err := strconv.Atoi(s); err == nil { + return s + } + return fmt.Sprintf(`'%s'`, s) +} diff --git a/subsetter/query.go b/subsetter/query.go index 44148e1..1061443 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -73,6 +73,23 @@ func GetKeys(q string, conn *pgxpool.Pool) (ids []string, err error) { return } +func GetPrimaryKeyName(table string, conn *pgxpool.Pool) (name string, err error) { + q := fmt.Sprintf(`SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = '%s'::regclass + AND i.indisprimary;`, table) + rows, err := conn.Query(context.Background(), q) + for rows.Next() { + if err := rows.Scan(&name); err != nil { + return "", err + } + } + rows.Close() + return +} + func DeleteRows(table string, where string, conn *pgxpool.Pool) (err error) { q := fmt.Sprintf(`DELETE FROM %s WHERE %s`, table, where) _, err = conn.Exec(context.Background(), q) @@ -95,8 +112,8 @@ func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err err return } -func CopyTableToString(table string, limit int, conn *pgxpool.Pool) (result string, err error) { - q := fmt.Sprintf(`SELECT * FROM %s order by random() limit %d`, table, limit) +func CopyTableToString(table string, limit int, where string, conn *pgxpool.Pool) (result string, err error) { + q := fmt.Sprintf(`SELECT * FROM %s %s order by random() limit %d`, table, where, limit) return CopyQueryToString(q, conn) } diff --git a/subsetter/query_test.go b/subsetter/query_test.go index 437137a..6bad69d 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -53,7 +53,7 @@ func TestCopyTableToString(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotResult, err := CopyTableToString(tt.table, 10, tt.conn) + gotResult, err := CopyTableToString(tt.table, 10, "", tt.conn) if (err != nil) != tt.wantErr { t.Errorf("CopyTableToString() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/subsetter/relations.go b/subsetter/relations.go index 0f89813..71dc9ff 100644 --- a/subsetter/relations.go +++ b/subsetter/relations.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "regexp" - "strconv" "strings" "github.com/jackc/pgx/v5/pgxpool" @@ -21,12 +20,7 @@ type Relation struct { func (r *Relation) Query(subset []string) string { subset = lo.Map(subset, func(s string, _ int) string { - - // if string is a number, don't quote it - if _, err := strconv.Atoi(s); err == nil { - return s - } - return fmt.Sprintf(`'%s'`, s) + return QuoteString(s) }) return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (%s)`, r.PrimaryTable, r.PrimaryColumn, strings.Join(subset, ",")) diff --git a/subsetter/sync.go b/subsetter/sync.go index 2130e13..a148d08 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sort" + "strings" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" @@ -26,7 +27,11 @@ type Rule struct { Where string } -func (r Rule) Query() string { +func (r *Rule) String() string { + return fmt.Sprintf("%s:%s", r.Table, r.Where) +} + +func (r *Rule) Query() string { if r.Where == "" { return fmt.Sprintf("SELECT * FROM %s", r.Table) } @@ -71,8 +76,28 @@ func (s *Sync) Close() { // copyTableData copies the data from a table in the source database to the destination database func copyTableData(table Table, source *pgxpool.Pool, destination *pgxpool.Pool) (err error) { + // Backtrace the inserted ids from main table to related table + + // Get primary keys + primaryKeyName, err := GetPrimaryKeyName(table.Name, source) + if err != nil { + return errors.Wrapf(err, "Error getting primary key for %s", table.Name) + } + + var ignoredPrimaryKeys []string + if ignoredPrimaryKeys, err = GetKeys(fmt.Sprintf("SELECT %s FROM %s", primaryKeyName, table.Name), destination); err != nil { + return errors.Wrapf(err, "Error getting primary keys for %s", table.Name) + } + ignoredPrimaryQuery := "" + if len(ignoredPrimaryKeys) > 0 { + keys := lo.Map(ignoredPrimaryKeys, func(key string, _ int) string { + return QuoteString(key) + }) + ignoredPrimaryQuery = fmt.Sprintf("WHERE %s NOT IN (%s)", primaryKeyName, strings.Join(keys, ",")) + } + var data string - if data, err = CopyTableToString(table.Name, table.Rows, source); err != nil { + if data, err = CopyTableToString(table.Name, table.Rows, ignoredPrimaryQuery, source); err != nil { log.Error().Err(err).Msgf("Error copying table %s", table.Name) return } @@ -95,17 +120,17 @@ func ViableSubset(tables []Table) (subset []Table) { return len(table.Relations) > 0 }) - var excludedTables []string + var relatedTables []string for _, table := range tablesWithRelations { for _, relation := range table.Relations { if table.Name != relation.PrimaryTable { - excludedTables = append(excludedTables, relation.PrimaryTable) + relatedTables = append(relatedTables, relation.PrimaryTable) } } } subset = lo.Filter(subset, func(table Table, _ int) bool { - return !lo.Contains(excludedTables, table.Name) + return !lo.Contains(relatedTables, table.Name) }) sort.Slice(subset, func(i, j int) bool { @@ -117,15 +142,19 @@ func ViableSubset(tables []Table) (subset []Table) { // CopyTables copies the data from a list of tables in the source database to the destination database func (s *Sync) CopyTables(tables []Table) (err error) { - for _, table := range tables { + excludedTables := lo.Map(s.exclude, func(rule Rule, _ int) string { + return rule.Table + }) - if err = copyTableData(table, s.source, s.destination); err != nil { - return errors.Wrapf(err, "Error copying table %s", table.Name) - } + tables = lo.Filter(tables, func(table Table, _ int) bool { + return !lo.Contains(excludedTables, table.Name) + }) + + for _, table := range tables { for _, include := range s.include { if include.Table == table.Name { - log.Info().Str("query", include.Where).Msgf("Selecting forced rows for table %s", table.Name) + log.Info().Str("query", include.Where).Msgf("Copying forced rows for table %s", table.Name) var data string if data, err = CopyQueryToString(include.Query(), s.source); err != nil { return errors.Wrapf(err, "Error copying forced rows for table %s", table.Name) @@ -135,6 +164,13 @@ func (s *Sync) CopyTables(tables []Table) (err error) { } } } + } + + for _, table := range tables { + log.Info().Msgf("Preparing %s", table.Name) + if err = copyTableData(table, s.source, s.destination); err != nil { + return errors.Wrapf(err, "Error copying table %s", table.Name) + } for _, exclude := range s.exclude { if exclude.Table == table.Name { @@ -149,7 +185,12 @@ func (s *Sync) CopyTables(tables []Table) (err error) { log.Info().Int("count", count).Msgf("Copied table %s", table.Name) for _, relation := range table.Relations { + if lo.Contains(excludedTables, relation.PrimaryTable) { + continue + } + // Backtrace the inserted ids from main table to related table + log.Info().Msgf("Preparing %s for %s", relation.PrimaryTable, table.Name) var pKeys []string if pKeys, err = GetKeys(relation.PrimaryQuery(), s.destination); err != nil { return errors.Wrapf(err, "Error getting primary keys for %s", relation.PrimaryTable) @@ -160,7 +201,10 @@ func (s *Sync) CopyTables(tables []Table) (err error) { } if err = CopyStringToTable(relation.PrimaryTable, data, s.destination); err != nil { if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23503" { - log.Warn().Msgf("Skipping %s because of cyclic foreign key", relation.PrimaryTable) + log.Warn().AnErr("sql", err).Msgf("Skipping %s because of cyclic foreign key", relation.PrimaryTable) + err = nil + } else if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23505" { + log.Warn().AnErr("sql", err).Msgf("Skipping %s because of present foreign key", relation.PrimaryTable) err = nil } else { return errors.Wrapf(err, "Error inserting related table %s", relation.PrimaryTable)