-
Notifications
You must be signed in to change notification settings - Fork 397
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add support for constraints in databricks_sql_table resource #4205
base: main
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,15 @@ type SqlColumnInfo struct { | |
TypeJson string `json:"type_json,omitempty" tf:"computed"` | ||
} | ||
|
||
type ConstraintInfo struct { | ||
Name string `json:"name"` | ||
Type string `json:"type"` | ||
KeyColumns []string `json:"key_columns"` | ||
ParentTable string `json:"parent_table,omitempty"` | ||
ParentColumns []string `json:"parent_columns,omitempty"` | ||
Rely bool `json:"rely,omitempty" tf:"default:false"` | ||
} | ||
|
||
type TypeJson struct { | ||
Metadata map[string]any `json:"metadata,omitempty"` | ||
} | ||
|
@@ -51,6 +60,7 @@ type SqlTableInfo struct { | |
ColumnInfos []SqlColumnInfo `json:"columns,omitempty" tf:"alias:column,computed"` | ||
Partitions []string `json:"partitions,omitempty" tf:"force_new"` | ||
ClusterKeys []string `json:"cluster_keys,omitempty"` | ||
Constraints []ConstraintInfo `json:"constraints,omitempty" tf:"alias:constraint"` | ||
StorageLocation string `json:"storage_location,omitempty" tf:"suppress_diff"` | ||
StorageCredentialName string `json:"storage_credential_name,omitempty" tf:"force_new"` | ||
ViewDefinition string `json:"view_definition,omitempty"` | ||
|
@@ -89,6 +99,9 @@ func (ti SqlTableInfo) CustomizeSchema(s *common.CustomizableSchema) *common.Cus | |
s.SchemaPath("column", "type").SetCustomSuppressDiff(func(k, old, new string, d *schema.ResourceData) bool { | ||
return getColumnType(old) == getColumnType(new) | ||
}) | ||
s.SchemaPath("constraint", "type").SetCustomSuppressDiff(func(k, old, new string, d *schema.ResourceData) bool { | ||
return getColumnType(old) == getColumnType(new) | ||
}) | ||
return s | ||
} | ||
|
||
|
@@ -242,6 +255,39 @@ func (ti *SqlTableInfo) serializeColumnInfos() string { | |
return strings.Join(columnFragments[:], ", ") // id INT NOT NULL, name STRING, age INT | ||
} | ||
|
||
func serializePrimaryKeyConstraint(constraint ConstraintInfo) string { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we make this as part of the ConstraintInfo struct that we do for getWrappedConstraintName as it will only be used for each constraint info and internally uses ConstraintInfo. getWrappedConstraintName, getWrappedKeyColumnNames and Rely? We do the same for SqlTableInfo. serializeColumnInfos, serializeColumnInfo |
||
constraint_clause := fmt.Sprintf("CONSTRAINT %s PRIMARY KEY(%s)", constraint.getWrappedConstraintName(), constraint.getWrappedKeyColumnNames()) | ||
if constraint.Rely { | ||
constraint_clause += " RELY" | ||
} | ||
return constraint_clause | ||
} | ||
|
||
func serializeForeignKeyConstraint(constraint ConstraintInfo) string { | ||
constraint_clause := fmt.Sprintf("CONSTRAINT %s FOREIGN KEY(%s) REFERENCES %s", constraint.getWrappedConstraintName(), constraint.getWrappedKeyColumnNames(), constraint.ParentTable) | ||
if len(constraint.ParentColumns) > 0 { | ||
constraint_clause += fmt.Sprintf("(%s)", constraint.getWrappedParentColumnNames()) | ||
} | ||
if constraint.Rely { | ||
constraint_clause += " RELY" | ||
} | ||
return constraint_clause | ||
} | ||
|
||
func (ti *SqlTableInfo) serializeConstraints() string { | ||
constraintFragments := make([]string, len(ti.Constraints)) | ||
|
||
for i, constraint := range ti.Constraints { | ||
if constraint.Type == "PRIMARY KEY" { | ||
constraintFragments[i] = serializePrimaryKeyConstraint(constraint) | ||
} else if constraint.Type == "FOREIGN KEY" { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we only support |
||
constraintFragments[i] = serializeForeignKeyConstraint(constraint) | ||
} | ||
} | ||
|
||
return strings.Join(constraintFragments[:], ", ") // CONSTRAINT `pk`` PRIMARY KEY (`id`, `nickname`), CONSTRAINT `fk` FOREIGN KEY (`player_id`) REFERENCES players | ||
} | ||
|
||
func (ti *SqlTableInfo) serializeProperties() string { | ||
propsMap := make([]string, 0, len(ti.Properties)) | ||
for key, value := range ti.Properties { | ||
|
@@ -290,7 +336,11 @@ func (ti *SqlTableInfo) buildTableCreateStatement() string { | |
statements = append(statements, fmt.Sprintf("CREATE %s%s %s", externalFragment, createType, ti.SQLFullName())) | ||
|
||
if len(ti.ColumnInfos) > 0 { | ||
statements = append(statements, fmt.Sprintf(" (%s)", ti.serializeColumnInfos())) | ||
columnInfosClause := ti.serializeColumnInfos() | ||
if len(ti.Constraints) > 0 { | ||
columnInfosClause += fmt.Sprintf(", %s", ti.serializeConstraints()) | ||
} | ||
statements = append(statements, fmt.Sprintf(" (%s)", columnInfosClause)) | ||
} | ||
|
||
if !isView { | ||
|
@@ -342,6 +392,21 @@ func (ti *SqlTableInfo) getWrappedClusterKeys() string { | |
return "`" + strings.Join(ti.ClusterKeys, "`,`") + "`" | ||
} | ||
|
||
// Wrapping the constraint name with backticks to avoid special character messing things up. | ||
func (ci ConstraintInfo) getWrappedConstraintName() string { | ||
return fmt.Sprintf("`%s`", ci.Name) | ||
} | ||
|
||
// Wrapping constraint column names with backticks to avoid special character messing things up. | ||
func (ci ConstraintInfo) getWrappedKeyColumnNames() string { | ||
return "`" + strings.Join(ci.KeyColumns, "`,`") + "`" | ||
} | ||
|
||
// Wrapping parent column name with backticks to avoid special character messing things up. | ||
func (ci ConstraintInfo) getWrappedParentColumnNames() string { | ||
return "`" + strings.Join(ci.ParentColumns, "`,`") + "`" | ||
} | ||
|
||
func (ti *SqlTableInfo) getStatementsForColumnDiffs(oldti *SqlTableInfo, statements []string, typestring string) []string { | ||
if len(ti.ColumnInfos) != len(oldti.ColumnInfos) { | ||
statements = ti.addOrRemoveColumnStatements(oldti, statements, typestring) | ||
|
@@ -413,6 +478,43 @@ func (ti *SqlTableInfo) alterExistingColumnStatements(oldti *SqlTableInfo, state | |
return statements | ||
} | ||
|
||
func (ti *SqlTableInfo) addOrRemoveConstraintStatements(oldti *SqlTableInfo, statements []string, typestring string) []string { | ||
nameToOldConstraint := make(map[string]ConstraintInfo) | ||
nameToNewConstraint := make(map[string]ConstraintInfo) | ||
for _, ci := range oldti.Constraints { | ||
nameToOldConstraint[ci.Name] = ci | ||
} | ||
for _, newCi := range ti.Constraints { | ||
nameToNewConstraint[newCi.Name] = newCi | ||
} | ||
|
||
removeConstraintStatements := make([]string, 0) | ||
|
||
for name, oldCi := range nameToOldConstraint { | ||
if _, exists := nameToNewConstraint[name]; !exists { | ||
// Remove old constraint if old constraint is no longer found in the config. | ||
removeConstraintStatements = append(removeConstraintStatements, oldCi.getWrappedConstraintName()) | ||
} | ||
} | ||
for _, removeStatement := range removeConstraintStatements { | ||
statements = append(statements, fmt.Sprintf("ALTER %s %s DROP CONSTRAINT IF EXISTS %s", typestring, ti.SQLFullName(), removeStatement)) | ||
} | ||
|
||
for _, newCi := range ti.Constraints { | ||
if _, exists := nameToOldConstraint[newCi.Name]; !exists { | ||
var newConstraintStatement string | ||
if newCi.Type == "PRIMARY KEY" { | ||
newConstraintStatement = serializePrimaryKeyConstraint(newCi) | ||
} else if newCi.Type == "FOREIGN KEY" { | ||
newConstraintStatement = serializeForeignKeyConstraint(newCi) | ||
} | ||
statements = append(statements, fmt.Sprintf("ALTER %s %s ADD %s", typestring, ti.SQLFullName(), newConstraintStatement)) | ||
} | ||
} | ||
|
||
return statements | ||
} | ||
|
||
func (ti *SqlTableInfo) diff(oldti *SqlTableInfo) ([]string, error) { | ||
statements := make([]string, 0) | ||
typestring := ti.getTableTypeString() | ||
|
@@ -454,6 +556,7 @@ func (ti *SqlTableInfo) diff(oldti *SqlTableInfo) ([]string, error) { | |
} | ||
|
||
statements = ti.getStatementsForColumnDiffs(oldti, statements, typestring) | ||
statements = ti.addOrRemoveConstraintStatements(oldti, statements, typestring) | ||
|
||
return statements, nil | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would work because we are essentially checking for case insensitivity and ignoring the mapping here:
We shouldn't use getColumnType here.