Merge pull request #27 from udacity/cjfinnell/string-pkey-sorting
Fix ordering of text primary keys (PIE-1100)
cjfinnell authored Nov 15, 2023
2 parents 8da955a + 131808e commit 0e38936
Showing 5 changed files with 119 additions and 23 deletions.
7 changes: 7 additions & 0 deletions cmd/pgverify/cmd.go
@@ -16,6 +16,7 @@ var (
aliasesFlag, excludeSchemasFlag, excludeTablesFlag, includeSchemasFlag, includeTablesFlag, includeColumnsFlag, excludeColumnsFlag, testModesFlag *[]string
logLevelFlag, timestampPrecisionFlag *string
bookendLimitFlag, sparseModFlag *int
hashPrimaryKeysFlag *bool
)

func init() {
@@ -39,6 +40,8 @@ func init() {

bookendLimitFlag = rootCmd.Flags().Int("bookend-limit", pgverify.TestModeBookendDefaultLimit, "only check the first and last N rows (with --tests=bookend)")
sparseModFlag = rootCmd.Flags().Int("sparse-mod", pgverify.TestModeSparseDefaultMod, "only check every Nth row (with --tests=sparse)")

hashPrimaryKeysFlag = rootCmd.Flags().Bool("hash-primary-keys", false, "hash primary key values before comparing them (useful for TEXT primary keys)")
}

var rootCmd = &cobra.Command{
@@ -68,6 +71,10 @@ var rootCmd = &cobra.Command{
pgverify.WithTimestampPrecision(*timestampPrecisionFlag),
}

if *hashPrimaryKeysFlag {
opts = append(opts, pgverify.WithHashPrimaryKeys())
}

logger := log.New()
logger.SetFormatter(&log.TextFormatter{})
levelInt, err := log.ParseLevel(*logLevelFlag)
15 changes: 15 additions & 0 deletions config.go
@@ -47,6 +47,11 @@ type Config struct {
// SparseMod is used in the sparse test mode to deterministically select a
// subset of rows, approximately 1/mod of the total.
SparseMod int
// HashPrimaryKeys is a flag that determines whether or not to hash the values
// of primary keys before using them for ordering. This is useful when the
// primary keys contain TEXT that is sorted differently between engines.
// May impact performance.
HashPrimaryKeys bool

// Aliases is a list of aliases to use for the target databases in reporting
// output. Is ignored if the number of aliases is not equal to the number of
@@ -193,3 +198,13 @@ func WithTimestampPrecision(precision string) optionFunc {
c.TimestampPrecision = precision
}
}

// WithHashPrimaryKeys configures the verifier to hash primary keys before
// ordering results. This is useful when the primary keys contain TEXT that is
// sorted differently between engines.
// May impact performance.
func WithHashPrimaryKeys() optionFunc {
return func(c *Config) {
c.HashPrimaryKeys = true
}
}
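
To see what this option changes in the generated SQL, here is a sketch (not output from the tool) using the "testSchema"."testTable" fixture and composite (content, id) primary key from the query_test.go cases further down: without HashPrimaryKeys, the ordering key is the raw concatenation, whose sort order follows the engine's collation; with it, the MD5 hex digest of that concatenation, which compares byte-for-byte identically on any engine.

    -- A sketch of the two ordering keys for a composite (content, id) primary key.
    -- The hashed form trades a human-meaningful sort order for one that is
    -- collation-independent.
    SELECT CONCAT(content::TEXT, id::TEXT)      AS primary_key_plain,   -- collation-dependent
           MD5(CONCAT(content::TEXT, id::TEXT)) AS primary_key_hashed   -- byte-stable hex digest
    FROM "testSchema"."testTable"
    ORDER BY primary_key_hashed;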
27 changes: 21 additions & 6 deletions integration_test.go
@@ -99,6 +99,7 @@ func calculateRowCount(columnTypes map[string][]string) int {
return rowCount
}

//nolint:maintidx
func TestVerifyData(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test")
@@ -232,15 +233,20 @@ func TestVerifyData(t *testing.T) {
sort.Strings(keysWithTypes)
sort.Strings(sortedTypes)

tableNames := []string{"testtable1", "testTABLE2", "testtable3"}
createTableQueryBase := fmt.Sprintf("( id INT DEFAULT 0 NOT NULL, zid INT DEFAULT 0 NOT NULL, ignored TIMESTAMP WITH TIME ZONE DEFAULT NOW(), %s);", strings.Join(keysWithTypes, ", "))
tableNames := []string{"testtable1", "testTABLE_multi_col_2", "testtable3", "test_stringkey_table4"}
createTableQueryBase := fmt.Sprintf("( id INT DEFAULT 0 NOT NULL, zid INT DEFAULT 0 NOT NULL, sid TEXT NOT NULL, ignored TIMESTAMP WITH TIME ZONE DEFAULT NOW(), %s);", strings.Join(keysWithTypes, ", "))

rowCount := calculateRowCount(columnTypes)
insertDataQueryBase := `(id, zid,` + strings.Join(keys, ", ") + `) VALUES `
insertDataQueryBase := `(id, zid, sid,` + strings.Join(keys, ", ") + `) VALUES `
valueClauses := make([]string, 0, rowCount)

// Modulo-cycle through prefixes to re-create ORDER BY issue
textPKeyPrefixes := []string{"A", "AA", "a", "aa", "A-A", "a-a"}

for rowID := 0; rowID < rowCount; rowID++ {
valueClause := `(` + strconv.Itoa(rowID) + `, 0`
textPKeyPrefix := textPKeyPrefixes[rowID%len(textPKeyPrefixes)]
valueClause := fmt.Sprintf("( %d, 0, '%s-%d'", rowID, textPKeyPrefix, rowID)

for _, columnType := range sortedTypes {
valueClause += `, ` + columnTypes[columnType][rowID%len(columnTypes[columnType])]
}
@@ -274,11 +280,19 @@ func TestVerifyData(t *testing.T) {
_, err = conn.Exec(ctx, createTableQuery)
require.NoError(t, err, "Failed to create table %s on %v with query: %s", tableName, db.image, createTableQuery)

pkeyString := fmt.Sprintf("single_col_pkey_%s PRIMARY KEY (id)", tableName)
if tableName == tableNames[1] {
var pkeyString string

switch {
case strings.Contains(tableName, "multi_col"):
pkeyString = fmt.Sprintf("multi_col_pkey_%s PRIMARY KEY (id, zid)", tableName)
case strings.Contains(tableName, "stringkey"):
pkeyString = fmt.Sprintf("text_col_pkey_%s PRIMARY KEY (sid)", tableName)
default:
pkeyString = fmt.Sprintf("single_col_pkey_%s PRIMARY KEY (id)", tableName)
}

require.NotEmpty(t, pkeyString)

alterTableQuery := fmt.Sprintf(`ALTER TABLE ONLY "%s" ADD CONSTRAINT %s;`, tableName, pkeyString)
_, err = conn.Exec(ctx, alterTableQuery)
require.NoError(t, err, "Failed to add primary key to table %s on %v with query %s", tableName, db.image, alterTableQuery)
Expand Down Expand Up @@ -310,6 +324,7 @@ func TestVerifyData(t *testing.T) {
pgverify.ExcludeColumns("ignored", "rowid"),
pgverify.WithAliases(aliases),
pgverify.WithBookendLimit(5),
pgverify.WithHashPrimaryKeys(),
)
require.NoError(t, err)
results.WriteAsTable(os.Stdout)
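The test prefixes are chosen deliberately: case and punctuation are exactly where locale-aware collations diverge from byte order. A standalone sketch of the divergence these rows provoke (the locale-aware ordering varies by collation implementation; the C-collation ordering is plain byte order):

    -- Under the C collation (byte order) these sort as:
    --   A, A-A, AA, a, a-a, aa
    -- Under a locale-aware collation such as en_US.UTF-8, punctuation and case
    -- are typically weighted after the base letters, giving something like:
    --   a, A, a-a, A-A, aa, AA
    SELECT v
    FROM (VALUES ('A'), ('AA'), ('a'), ('aa'), ('A-A'), ('a-a')) AS t(v)
    ORDER BY v;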
32 changes: 25 additions & 7 deletions query.go
@@ -116,11 +116,17 @@ func buildFullHashQuery(config Config, schemaName, tableName string, columns []c

primaryColumnString := strings.Join(primaryKeyNamesWithCasting, ", ")

primaryColumnConcatString := fmt.Sprintf("CONCAT(%s)", primaryColumnString)

if config.HashPrimaryKeys {
primaryColumnConcatString = fmt.Sprintf("MD5(%s)", primaryColumnConcatString)
}

return formatQuery(fmt.Sprintf(`
SELECT md5(string_agg(hash, ''))
FROM (SELECT '' AS grouper, MD5(CONCAT(%s)) AS hash, CONCAT(%s) as primary_key FROM "%s"."%s") AS eachrow
FROM (SELECT '' AS grouper, MD5(CONCAT(%s)) AS hash, %s as primary_key FROM "%s"."%s") AS eachrow
GROUP BY grouper, primary_key ORDER BY primary_key
`, strings.Join(columnsWithCasting, ", "), primaryColumnString, schemaName, tableName))
`, strings.Join(columnsWithCasting, ", "), primaryColumnConcatString, schemaName, tableName))
}

// Similar to the full test query, this test differs by first selecting a subset
@@ -172,18 +178,24 @@ func buildSparseHashQuery(config Config, schemaName, tableName string, columns [

primaryColumnString := strings.Join(primaryKeyNamesWithCasting, ", ")

primaryColumnConcatString := fmt.Sprintf("CONCAT(%s)", primaryColumnString)

if config.HashPrimaryKeys {
primaryColumnConcatString = fmt.Sprintf("MD5(%s)", primaryColumnConcatString)
}

return formatQuery(fmt.Sprintf(`
SELECT md5(string_agg(hash, ''))
FROM (
SELECT '' AS grouper, MD5(CONCAT(%s)) AS hash, CONCAT(%s) as primary_key
SELECT '' AS grouper, MD5(CONCAT(%s)) AS hash, %s as primary_key
FROM "%s"."%s"
WHERE %s
ORDER BY CONCAT(%s)
) AS eachrow
GROUP BY grouper, primary_key
ORDER BY primary_key
`,
strings.Join(columnsWithCasting, ", "), primaryColumnString,
strings.Join(columnsWithCasting, ", "), primaryColumnConcatString,
schemaName, tableName, whenClausesString,
primaryKeyNamesWithCastingString))
}
@@ -209,14 +221,20 @@ func buildBookendHashQuery(config Config, schemaName, tableName string, columns
allColumnsWithCasting := strings.Join(columnsWithCasting, ", ")
allPrimaryColumnsWithCasting := strings.Join(primaryKeyNamesWithCasting, ", ")

allPrimaryColumnsConcatString := fmt.Sprintf("CONCAT(%s)", allPrimaryColumnsWithCasting)

if config.HashPrimaryKeys {
allPrimaryColumnsConcatString = fmt.Sprintf("MD5(%s)", allPrimaryColumnsConcatString)
}

return formatQuery(fmt.Sprintf(`
SELECT md5(CONCAT(starthash::TEXT, endhash::TEXT))
FROM (
SELECT md5(string_agg(hash, ''))
FROM (
SELECT '' AS grouper, MD5(CONCAT(%s)) AS hash
FROM "%s"."%s"
ORDER BY CONCAT(%s) ASC
ORDER BY %s ASC
LIMIT %d
) AS eachrow
GROUP BY grouper
@@ -225,12 +243,12 @@ func buildBookendHashQuery(config Config, schemaName, tableName string, columns
FROM (
SELECT '' AS grouper, MD5(CONCAT(%s)) AS hash
FROM "%s"."%s"
ORDER BY CONCAT(%s) DESC
ORDER BY %s DESC
LIMIT %d
) AS eachrow
GROUP BY grouper
) as endhash
`, allColumnsWithCasting, schemaName, tableName, allPrimaryColumnsWithCasting, limit, allColumnsWithCasting, schemaName, tableName, allPrimaryColumnsWithCasting, limit))
`, allColumnsWithCasting, schemaName, tableName, allPrimaryColumnsConcatString, limit, allColumnsWithCasting, schemaName, tableName, allPrimaryColumnsConcatString, limit))
}

// A minimal test that simply counts the number of rows.
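One detail worth noting from the sparse queries (the expected strings appear in query_test.go below): row selection is deterministic rather than random, so every target database samples the same subset of keys. The predicate in isolation, as a sketch:

    -- Keep roughly 1/10 of rows: md5 the concatenated primary key, reinterpret
    -- the first 16 hex digits of the digest as a 64-bit integer, and keep the
    -- row when that integer is divisible by the mod. The same key always hashes
    -- the same way, so all engines agree on the sample.
    SELECT id
    FROM "testSchema"."testTable"
    WHERE ('x' || substr(md5(CONCAT(id::TEXT)), 1, 16))::bit(64)::bigint % 10 = 0;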
61 changes: 51 additions & 10 deletions query_test.go
@@ -75,6 +75,23 @@ func TestBuildFullHashQuery(t *testing.T) {
(SELECT '' AS grouper, MD5(CONCAT((extract(epoch from date_trunc('milliseconds', when))::DECIMAL * 1000000)::BIGINT::TEXT, content::TEXT, id::TEXT)) AS hash, CONCAT(content::TEXT, id::TEXT) as primary_key
FROM "testSchema"."testTable") AS eachrow GROUP BY grouper, primary_key ORDER BY primary_key`),
},
{
name: "multi-column hashed primary key",
config: Config{TimestampPrecision: TimestampPrecisionMilliseconds, HashPrimaryKeys: true},
schemaName: "testSchema",
tableName: "testTable",
columns: []column{
{name: "id", dataType: "uuid", constraints: []string{"PRIMARY KEY", "another constraint"}},
{name: "content", dataType: "text", constraints: []string{"PRIMARY KEY"}},
{name: "when", dataType: "timestamp with time zone"},
},
primaryColumnNamesString: "id, content",
expectedQuery: formatQuery(`
SELECT md5(string_agg(hash, ''))
FROM
(SELECT '' AS grouper, MD5(CONCAT((extract(epoch from date_trunc('milliseconds', when))::DECIMAL * 1000000)::BIGINT::TEXT, content::TEXT, id::TEXT)) AS hash, MD5(CONCAT(content::TEXT, id::TEXT)) as primary_key
FROM "testSchema"."testTable") AS eachrow GROUP BY grouper, primary_key ORDER BY primary_key`),
},
} {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expectedQuery, buildFullHashQuery(tc.config, tc.schemaName, tc.tableName, tc.columns))
@@ -106,12 +123,12 @@ func TestBuildSparseHashQuery(t *testing.T) {
SELECT md5(string_agg(hash, ''))
FROM
( SELECT '' AS grouper, MD5(CONCAT((extract(epoch from date_trunc('milliseconds', when))::DECIMAL * 1000000)::BIGINT::TEXT, content::TEXT, id::TEXT)) AS hash, CONCAT(id::TEXT) as primary_key
FROM "testSchema"."testTable"
WHERE id in (
SELECT id FROM "testSchema"."testTable"
WHERE ('x' || substr(md5(CONCAT(id::TEXT)),1,16))::bit(64)::bigint % 10 = 0 )
FROM "testSchema"."testTable"
WHERE id in (
SELECT id FROM "testSchema"."testTable"
WHERE ('x' || substr(md5(CONCAT(id::TEXT)),1,16))::bit(64)::bigint % 10 = 0 )
ORDER BY CONCAT(id::TEXT)
)
)
AS eachrow GROUP BY grouper, primary_key ORDER BY primary_key`),
},
{
@@ -128,12 +145,36 @@ func TestBuildSparseHashQuery(t *testing.T) {
SELECT md5(string_agg(hash, ''))
FROM
( SELECT '' AS grouper, MD5(CONCAT((extract(epoch from date_trunc('milliseconds', when))::DECIMAL * 1000000)::BIGINT::TEXT, content::TEXT, id::TEXT)) AS hash, CONCAT(content::TEXT, id::TEXT) as primary_key
FROM "testSchema"."testTable"
WHERE content in (
SELECT content FROM "testSchema"."testTable"
FROM "testSchema"."testTable"
WHERE content in (
SELECT content FROM "testSchema"."testTable"
WHERE ('x' || substr(md5(CONCAT(content::TEXT, id::TEXT)),1,16))::bit(64)::bigint % 10 = 0
) AND id in (
SELECT id FROM "testSchema"."testTable"
WHERE ('x' || substr(md5(CONCAT(content::TEXT, id::TEXT)),1,16))::bit(64)::bigint % 10 = 0
) ORDER BY CONCAT(content::TEXT, id::TEXT) )
AS eachrow GROUP BY grouper, primary_key ORDER BY primary_key`),
},
{
name: "multi-column hashed primary key",
config: Config{TimestampPrecision: TimestampPrecisionMilliseconds, HashPrimaryKeys: true},
schemaName: "testSchema",
tableName: "testTable",
columns: []column{
{name: "id", dataType: "uuid", constraints: []string{"PRIMARY KEY", "another constraint"}},
{name: "content", dataType: "text", constraints: []string{"PRIMARY KEY"}},
{name: "when", dataType: "timestamp with time zone"},
},
expectedQuery: formatQuery(`
SELECT md5(string_agg(hash, ''))
FROM
( SELECT '' AS grouper, MD5(CONCAT((extract(epoch from date_trunc('milliseconds', when))::DECIMAL * 1000000)::BIGINT::TEXT, content::TEXT, id::TEXT)) AS hash, MD5(CONCAT(content::TEXT, id::TEXT)) as primary_key
FROM "testSchema"."testTable"
WHERE content in (
SELECT content FROM "testSchema"."testTable"
WHERE ('x' || substr(md5(CONCAT(content::TEXT, id::TEXT)),1,16))::bit(64)::bigint % 10 = 0
) AND id in (
SELECT id FROM "testSchema"."testTable"
) AND id in (
SELECT id FROM "testSchema"."testTable"
WHERE ('x' || substr(md5(CONCAT(content::TEXT, id::TEXT)),1,16))::bit(64)::bigint % 10 = 0
) ORDER BY CONCAT(content::TEXT, id::TEXT) )
AS eachrow GROUP BY grouper, primary_key ORDER BY primary_key`),
Expand Down
