Skip to content

Commit

Permalink
Support Unity Catalog Registered Models in bundles (#846)
Browse files Browse the repository at this point in the history
## Changes
<!-- Summary of your changes that are easy to understand -->
Add UC Registered Models support to Databricks Asset Bundles as a new
resource `registered_model`. Also add UC permission support via a new
resource `grant`.

## Tests
<!-- How is this tested? -->
Tested via unit tests and manual testing with [example
PR](https://github.com/databricks/bundle-examples-internal/pull/80) and
[custom Terraform
provider](databricks/terraform-provider-databricks#2771).
<img width="698" alt="Screenshot 2023-10-08 at 4 57 23 PM"
src="https://github.com/databricks/cli/assets/87999496/bcf605a9-7894-443b-865a-f7e240037815">
<img width="1109" alt="Screenshot 2023-10-08 at 4 56 47 PM"
src="https://github.com/databricks/cli/assets/87999496/e4d6e424-cd70-4809-8843-6939ed2e172f">
<img width="1091" alt="Screenshot 2023-10-08 at 4 56 57 PM"
src="https://github.com/databricks/cli/assets/87999496/88ebaabb-67db-4a11-88a5-df087e2e41c0">

---------

Signed-off-by: Arpit Jasapara <[email protected]>
Co-authored-by: Andrew Nester <[email protected]>
Co-authored-by: Pieter Noordhuis <[email protected]>
  • Loading branch information
3 people authored Oct 16, 2023
1 parent 61cf4fb commit 24cc675
Show file tree
Hide file tree
Showing 14 changed files with 302 additions and 4 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ vendor:
@echo "✓ Filling vendor folder with library code ..."
@go mod vendor

.PHONY: build vendor coverage test lint fmt
.PHONY: build vendor coverage test lint fmt

6 changes: 6 additions & 0 deletions bundle/config/mutator/process_target_mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error {
// (model serving doesn't yet support tags)
}

for i := range r.RegisteredModels {
prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_"
r.RegisteredModels[i].Name = prefix + r.RegisteredModels[i].Name
// (registered models in Unity Catalog don't yet support tags)
}

return nil
}

Expand Down
11 changes: 10 additions & 1 deletion bundle/config/mutator/process_target_mode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/tags"
sdkconfig "github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml"
Expand Down Expand Up @@ -59,6 +60,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle {
ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
"servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}},
},
RegisteredModels: map[string]*resources.RegisteredModel{
"registeredmodel1": {CreateRegisteredModelRequest: &catalog.CreateRegisteredModelRequest{Name: "registeredmodel1"}},
},
},
},
// Use AWS implementation for testing.
Expand Down Expand Up @@ -86,6 +90,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) {
// Experiment 1
assert.Equal(t, "/Users/[email protected]/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name)
assert.Contains(t, bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags, ml.ExperimentTag{Key: "dev", Value: "lennart"})
assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)

// Experiment 2
assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name)
Expand All @@ -96,7 +101,9 @@ func TestProcessTargetModeDevelopment(t *testing.T) {

// Model serving endpoint 1
assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)

// Registered model 1
assert.Equal(t, "dev_lennart_registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
}

func TestProcessTargetModeDevelopmentTagNormalizationForAws(t *testing.T) {
Expand Down Expand Up @@ -151,6 +158,7 @@ func TestProcessTargetModeDefault(t *testing.T) {
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
}

func TestProcessTargetModeProduction(t *testing.T) {
Expand Down Expand Up @@ -187,6 +195,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
}

func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) {
Expand Down
17 changes: 17 additions & 0 deletions bundle/config/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Resources struct {
Models map[string]*resources.MlflowModel `json:"models,omitempty"`
Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"`
ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"`
RegisteredModels map[string]*resources.RegisteredModel `json:"registered_models,omitempty"`
}

type UniqueResourceIdTracker struct {
Expand Down Expand Up @@ -107,6 +108,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker,
tracker.Type[k] = "model_serving_endpoint"
tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath
}
for k := range r.RegisteredModels {
if _, ok := tracker.Type[k]; ok {
return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)",
k,
tracker.Type[k],
tracker.ConfigPath[k],
"registered_model",
r.RegisteredModels[k].ConfigFilePath,
)
}
tracker.Type[k] = "registered_model"
tracker.ConfigPath[k] = r.RegisteredModels[k].ConfigFilePath
}
return tracker, nil
}

Expand All @@ -129,6 +143,9 @@ func (r *Resources) SetConfigFilePath(path string) {
for _, e := range r.ModelServingEndpoints {
e.ConfigFilePath = path
}
for _, e := range r.RegisteredModels {
e.ConfigFilePath = path
}
}

// Merge iterates over all resources and merges chunks of the
Expand Down
9 changes: 9 additions & 0 deletions bundle/config/resources/grant.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package resources

// Grant holds the grant level settings for a single principal in Unity Catalog.
// Multiple of these can be defined on any Unity Catalog resource.
type Grant struct {
// Privileges lists the privileges given to the principal, as plain
// strings (for example "EXECUTE").
Privileges []string `json:"privileges"`

// Principal identifies who receives the privileges.
// NOTE(review): presumably a user, group, or service principal name — confirm.
Principal string `json:"principal"`
}
4 changes: 2 additions & 2 deletions bundle/config/resources/model_serving_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ type ModelServingEndpoint struct {
// as a reference in other resources. This value is returned by terraform.
ID string

// Local path where the bundle is defined. All bundle resources include
// this for interpolation purposes.
// Path to config file where the resource is defined. All bundle resources
// include this for interpolation purposes.
paths.Paths

// This is a resource agnostic implementation of permissions for ACLs.
Expand Down
34 changes: 34 additions & 0 deletions bundle/config/resources/registered_model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package resources

import (
"github.com/databricks/cli/bundle/config/paths"
"github.com/databricks/databricks-sdk-go/marshal"
"github.com/databricks/databricks-sdk-go/service/catalog"
)

// RegisteredModel is the bundle configuration for a registered model in
// Unity Catalog. It combines the embedded create request with bundle-level
// metadata: grants, the resource ID, and the path of the defining config file.
type RegisteredModel struct {
// This is a resource agnostic implementation of grants.
// Implementation could be different based on the resource type.
Grants []Grant `json:"grants,omitempty"`

// This represents the id which is the full name of the model
// (catalog_name.schema_name.model_name) that can be used
// as a reference in other resources. This value is returned by terraform.
ID string

// Path to config file where the resource is defined. All bundle resources
// include this for interpolation purposes.
paths.Paths

// This represents the input args for terraform, and will get converted
// to a HCL representation for CRUD
*catalog.CreateRegisteredModelRequest
}

// UnmarshalJSON decodes b into s by delegating to the SDK's marshal.Unmarshal.
// NOTE(review): presumably defined so that JSON handling promoted from the
// embedded request type does not take over decoding of the whole struct — confirm.
func (s *RegisteredModel) UnmarshalJSON(b []byte) error {
return marshal.Unmarshal(b, s)
}

// MarshalJSON encodes s by delegating to the SDK's marshal.Marshal.
// It uses a value receiver, so it applies to both RegisteredModel values
// and pointers.
func (s RegisteredModel) MarshalJSON() ([]byte, error) {
return marshal.Marshal(s)
}
30 changes: 30 additions & 0 deletions bundle/config/resources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,33 @@ func TestVerifySafeMergeForSameResourceType(t *testing.T) {
err := r.VerifySafeMerge(&other)
assert.ErrorContains(t, err, "multiple resources named foo (job at foo.yml, job at foo2.yml)")
}

// TestVerifySafeMergeForRegisteredModels checks that merging two resource
// sets which both define a registered model under the same key is rejected,
// and that the error names both config files involved.
func TestVerifySafeMergeForRegisteredModels(t *testing.T) {
	existing := Resources{
		Jobs: map[string]*resources.Job{
			"foo": {Paths: paths.Paths{ConfigFilePath: "foo.yml"}},
		},
		RegisteredModels: map[string]*resources.RegisteredModel{
			"bar": {Paths: paths.Paths{ConfigFilePath: "bar.yml"}},
		},
	}
	incoming := Resources{
		RegisteredModels: map[string]*resources.RegisteredModel{
			"bar": {Paths: paths.Paths{ConfigFilePath: "bar2.yml"}},
		},
	}

	const want = "multiple resources named bar (registered_model at bar.yml, registered_model at bar2.yml)"
	assert.ErrorContains(t, existing.VerifySafeMerge(&incoming), want)
}
36 changes: 36 additions & 0 deletions bundle/deploy/terraform/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ func convPermission(ac resources.Permission) schema.ResourcePermissionsAccessCon
return dst
}

// convGrants translates the bundle-level grants of a resource into the
// Terraform grants resource. It returns nil when no grants are configured
// so that callers can skip emitting a grants resource altogether.
func convGrants(acl []resources.Grant) *schema.ResourceGrants {
	if len(acl) == 0 {
		return nil
	}

	// One Terraform grant entry per configured principal, in input order.
	grants := make([]schema.ResourceGrantsGrant, len(acl))
	for i := range acl {
		grants[i] = schema.ResourceGrantsGrant{
			Principal:  acl[i].Principal,
			Privileges: acl[i].Privileges,
		}
	}

	return &schema.ResourceGrants{Grant: grants}
}

// BundleToTerraform converts resources in a bundle configuration
// to the equivalent Terraform JSON representation.
//
Expand Down Expand Up @@ -174,6 +190,19 @@ func BundleToTerraform(config *config.Root) *schema.Root {
}
}

for k, src := range config.Resources.RegisteredModels {
noResources = false
var dst schema.ResourceRegisteredModel
conv(src, &dst)
tfroot.Resource.RegisteredModel[k] = &dst

// Configure permissions for this resource.
if rp := convGrants(src.Grants); rp != nil {
rp.Function = fmt.Sprintf("${databricks_registered_model.%s.id}", k)
tfroot.Resource.Grants["registered_model_"+k] = rp
}
}

// We explicitly set "resource" to nil to omit it from a JSON encoding.
// This is required because the terraform CLI requires >= 1 resources defined
// if the "resource" property is used in a .tf.json file.
Expand Down Expand Up @@ -221,7 +250,14 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error {
cur := config.Resources.ModelServingEndpoints[resource.Name]
conv(tmp, &cur)
config.Resources.ModelServingEndpoints[resource.Name] = cur
case "databricks_registered_model":
var tmp schema.ResourceRegisteredModel
conv(resource.AttributeValues, &tmp)
cur := config.Resources.RegisteredModels[resource.Name]
conv(tmp, &cur)
config.Resources.RegisteredModels[resource.Name] = cur
case "databricks_permissions":
case "databricks_grants":
// Ignore; no need to pull these back into the configuration.
default:
return fmt.Errorf("missing mapping for %s", resource.Type)
Expand Down
56 changes: 56 additions & 0 deletions bundle/deploy/terraform/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml"
Expand Down Expand Up @@ -366,3 +367,58 @@ func TestConvertModelServingPermissions(t *testing.T) {
assert.Equal(t, "CAN_VIEW", p.PermissionLevel)

}

func TestConvertRegisteredModel(t *testing.T) {
var src = resources.RegisteredModel{
CreateRegisteredModelRequest: &catalog.CreateRegisteredModelRequest{
Name: "name",
CatalogName: "catalog",
SchemaName: "schema",
Comment: "comment",
},
}

var config = config.Root{
Resources: config.Resources{
RegisteredModels: map[string]*resources.RegisteredModel{
"my_registered_model": &src,
},
},
}

out := BundleToTerraform(&config)
resource := out.Resource.RegisteredModel["my_registered_model"]
assert.Equal(t, "name", resource.Name)
assert.Equal(t, "catalog", resource.CatalogName)
assert.Equal(t, "schema", resource.SchemaName)
assert.Equal(t, "comment", resource.Comment)
assert.Nil(t, out.Data)
}

// TestConvertRegisteredModelGrants verifies that grants configured on a
// registered model produce a Terraform grants resource keyed by
// "registered_model_<name>" that references the model and carries the
// principal and privileges through unchanged.
func TestConvertRegisteredModelGrants(t *testing.T) {
	model := resources.RegisteredModel{
		Grants: []resources.Grant{
			{
				Privileges: []string{"EXECUTE"},
				Principal:  "[email protected]",
			},
		},
	}
	root := config.Root{
		Resources: config.Resources{
			RegisteredModels: map[string]*resources.RegisteredModel{
				"my_registered_model": &model,
			},
		},
	}

	tf := BundleToTerraform(&root)
	grants := tf.Resource.Grants["registered_model_my_registered_model"]
	assert.NotEmpty(t, grants.Function)
	assert.Len(t, grants.Grant, 1)

	first := grants.Grant[0]
	assert.Equal(t, "[email protected]", first.Principal)
	assert.Equal(t, "EXECUTE", first.Privileges[0])
}
3 changes: 3 additions & 0 deletions bundle/deploy/terraform/interpolate.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri
case "model_serving_endpoints":
path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
case "registered_models":
path = strings.Join(append([]string{"databricks_registered_model"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
default:
panic("TODO: " + parts[1])
}
Expand Down
18 changes: 18 additions & 0 deletions bundle/schema/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,19 @@ func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) {
return modelServingEndpointsAllDocs, nil
}

// registeredModelDocs builds the documentation node for the
// "registered_models" resource map. Each entry is documented from the
// resolved schema of catalog.CreateRegisteredModelRequest.
func (reader *OpenapiReader) registeredModelDocs() (*Docs, error) {
	spec, err := reader.readResolvedSchema(SchemaPathPrefix + "catalog.CreateRegisteredModelRequest")
	if err != nil {
		return nil, err
	}

	return &Docs{
		Description:          "List of Registered Models",
		AdditionalProperties: schemaToDocs(spec),
	}, nil
}

func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
jobsDocs, err := reader.jobsDocs()
if err != nil {
Expand All @@ -244,6 +257,10 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
if err != nil {
return nil, err
}
registeredModelsDocs, err := reader.registeredModelDocs()
if err != nil {
return nil, err
}

return &Docs{
Description: "Collection of Databricks resources to deploy.",
Expand All @@ -253,6 +270,7 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
"experiments": experimentsDocs,
"models": modelsDocs,
"model_serving_endpoints": modelServingEndpointsDocs,
"registered_models": registeredModelsDocs,
},
}, nil
}
Loading

0 comments on commit 24cc675

Please sign in to comment.