Skip to content

Commit

Permalink
Multitenancy v2 changes for prompt studio Export (#758)
Browse files Browse the repository at this point in the history
added the changes to v2

Co-authored-by: Hari John Kuriakose <[email protected]>
  • Loading branch information
jagadeeswaran-zipstack and hari-kuriakose authored Sep 30, 2024
1 parent 06128e0 commit bac1a78
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 29 deletions.
3 changes: 3 additions & 0 deletions backend/prompt_studio/prompt_studio_core_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,12 +529,15 @@ def export_tool(self, request: Request, pk: Any = None) -> Response:
serializer.is_valid(raise_exception=True)
is_shared_with_org: bool = serializer.validated_data.get("is_shared_with_org")
user_ids = set(serializer.validated_data.get("user_id"))
force_export = serializer.validated_data.get("force_export")

PromptStudioRegistryHelper.update_or_create_psr_tool(
custom_tool=custom_tool,
shared_with_org=is_shared_with_org,
user_ids=user_ids,
force_export=force_export,
)

return Response(
{"message": "Custom tool exported sucessfully."},
status=status.HTTP_200_OK,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,16 +292,6 @@ def frame_export_json(
invalidated_prompts.append(prompt.prompt_key)
continue

if not force_export:
prompt_output = PromptStudioOutputManager.objects.filter(
tool_id=tool.tool_id,
prompt_id=prompt.prompt_id,
profile_manager=prompt.profile_manager,
).all()
if not prompt_output:
invalidated_outputs.append(prompt.prompt_key)
continue

if not prompt.profile_manager:
prompt.profile_manager = default_llm_profile

Expand Down
4 changes: 2 additions & 2 deletions backend/prompt_studio/prompt_studio_registry_v2/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class EmptyToolExportError(APIException):
status_code = 500
default_detail = (
"Prompt Studio project without prompts cannot be exported. "
"Please ensure there is at least one prompt and "
"it is active before exporting."
"Please ensure there is at least one active prompt "
"that has been run before exporting."
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,24 @@ def get_tool_by_prompt_registry_id(

@staticmethod
def update_or_create_psr_tool(
custom_tool: CustomTool, shared_with_org: bool, user_ids: set[int]
custom_tool: CustomTool,
shared_with_org: bool,
user_ids: set[int],
force_export: bool,
) -> PromptStudioRegistry:
"""Updates or creates the PromptStudioRegistry record.
This appears as a separate tool in the workflow and is mapped
1:1 with the `CustomTool`.
Args:
tool_id (str): ID of the custom tool.
custom_tool (CustomTool): The instance of the custom tool to be updated
or created.
shared_with_org (bool): Flag indicating whether the tool is shared with
the organization.
user_ids (set[int]): A set of user IDs to whom the tool is shared.
force_export (bool): Indicates if the export is being forced.
Raises:
ToolSaveError
Expand All @@ -162,7 +171,7 @@ def update_or_create_psr_tool(
tool_id=custom_tool.tool_id
)
metadata = PromptStudioRegistryHelper.frame_export_json(
tool=custom_tool, prompts=prompts
tool=custom_tool, prompts=prompts, force_export=force_export
)

obj: PromptStudioRegistry
Expand Down Expand Up @@ -208,7 +217,9 @@ def update_or_create_psr_tool(

@staticmethod
def frame_export_json(
tool: CustomTool, prompts: list[ToolStudioPrompt]
tool: CustomTool,
prompts: list[ToolStudioPrompt],
force_export: bool,
) -> dict[str, Any]:
export_metadata = {}

Expand Down Expand Up @@ -283,19 +294,19 @@ def frame_export_json(
invalidated_prompts.append(prompt.prompt_key)
continue

prompt_output = PromptStudioOutputManager.objects.filter(
tool_id=tool.tool_id,
prompt_id=prompt.prompt_id,
profile_manager=prompt.profile_manager,
).all()

if not prompt_output:
invalidated_outputs.append(prompt.prompt_key)
continue

if not prompt.profile_manager:
prompt.profile_manager = default_llm_profile

if not force_export:
prompt_output = PromptStudioOutputManager.objects.filter(
tool_id=tool.tool_id,
prompt_id=prompt.prompt_id,
profile_manager=prompt.profile_manager,
).all()
if not prompt_output:
invalidated_outputs.append(prompt.prompt_key)
continue

vector_db = str(prompt.profile_manager.vector_store.id)
embedding_model = str(prompt.profile_manager.embedding_model.id)
llm = str(prompt.profile_manager.llm.id)
Expand Down Expand Up @@ -354,10 +365,12 @@ def frame_export_json(
f"Cannot export tool. Prompt(s): {', '.join(invalidated_prompts)} "
"are empty. Please enter a valid prompt."
)
if invalidated_outputs:
if not force_export and invalidated_outputs:
raise InValidCustomToolError(
f"Cannot export tool. Prompt(s): {', '.join(invalidated_outputs)} "
"were not run. Please run them before exporting."
detail="Cannot export tool. Prompt(s):"
f" {', '.join(invalidated_outputs)}"
" were not run. Please run them before exporting.",
code="warning",
)
export_metadata[JsonSchemaKey.TOOL_SETTINGS] = tool_settings
export_metadata[JsonSchemaKey.OUTPUTS] = outputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ def get_prompt_studio_users(self, obj: PromptStudioRegistry) -> Any:
class ExportToolRequestSerializer(serializers.Serializer):
is_shared_with_org = serializers.BooleanField(default=False)
user_id = serializers.ListField(child=serializers.IntegerField(), required=False)
force_export = serializers.BooleanField(default=False)

0 comments on commit bac1a78

Please sign in to comment.