Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Prompt studio Coverage #907

Merged
merged 7 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions backend/prompt_studio/prompt_studio_core_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,48 +44,71 @@ class Meta:

def to_representation(self, instance): # type: ignore
data = super().to_representation(instance)
default_profile = None

# Fetch summarize LLM profile
try:
profile_manager = ProfileManager.objects.get(
summarize_profile = ProfileManager.objects.get(
prompt_studio_tool=instance, is_summarize_llm=True
)
data[TSKeys.SUMMARIZE_LLM_PROFILE] = profile_manager.profile_id
data[TSKeys.SUMMARIZE_LLM_PROFILE] = summarize_profile.profile_id
except ObjectDoesNotExist:
logger.info(
"Summarize LLM profile doesnt exist for prompt tool %s",
"Summarize LLM profile doesn't exist for prompt tool %s",
str(instance.tool_id),
)

# Fetch default LLM profile
try:
profile_manager = ProfileManager.get_default_llm_profile(instance)
data[TSKeys.DEFAULT_PROFILE] = profile_manager.profile_id
default_profile = ProfileManager.get_default_llm_profile(instance)
data[TSKeys.DEFAULT_PROFILE] = default_profile.profile_id
except DefaultProfileError:
# To make it compatible with older projects error suppressed with warning.
logger.warning(
"Default LLM profile doesnt exist for prompt tool %s",
"Default LLM profile doesn't exist for prompt tool %s",
str(instance.tool_id),
)
prompt_instance: ToolStudioPrompt = ToolStudioPrompt.objects.filter(

# Fetch prompt instances
prompt_instances: ToolStudioPrompt = ToolStudioPrompt.objects.filter(
tool_id=data.get(TSKeys.TOOL_ID)
).order_by("sequence_number")
data[TSKeys.PROMPTS] = []

if not prompt_instances.exists():
data[TSKeys.PROMPTS] = []
return data

# Process prompt instances
output: list[Any] = []
# Appending prompt instances of the tool for FE Processing
if prompt_instance.count() != 0:
for prompt in prompt_instance:
profile_manager_id = prompt.prompt_id
if instance.single_pass_extraction_mode:
# use projects default profile
profile_manager_id = profile_manager.profile_id
prompt_serializer = ToolStudioPromptSerializer(prompt)
for prompt in prompt_instances:
prompt_serializer = ToolStudioPromptSerializer(prompt)
serialized_data = prompt_serializer.data

# Determine coverage
coverage: list[Any] = []
profile_manager_id = prompt.profile_manager
if default_profile and instance.single_pass_extraction_mode:
jagadeeswaran-zipstack marked this conversation as resolved.
Show resolved Hide resolved
harini-venkataraman marked this conversation as resolved.
Show resolved Hide resolved
profile_manager_id = default_profile.profile_id

if profile_manager_id:
coverage = OutputManagerUtils.get_coverage(
data.get(TSKeys.TOOL_ID),
profile_manager_id,
prompt.prompt_id,
instance.single_pass_extraction_mode,
)
serialized_data = prompt_serializer.data
serialized_data["coverage"] = coverage
output.append(serialized_data)
data[TSKeys.PROMPTS] = output
else:
logger.info(
jagadeeswaran-zipstack marked this conversation as resolved.
Show resolved Hide resolved
"Skipping coverage calculation for prompt %s "
"due to missing profile ID",
str(prompt.prompt_key),
)

# Add coverage to serialized data
serialized_data["coverage"] = coverage
output.append(serialized_data)

data[TSKeys.PROMPTS] = output
data["created_by_email"] = instance.created_by.email

return data
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from django.db.models import Count
from prompt_studio.prompt_studio_output_manager_v2.models import (
PromptStudioOutputManager,
)
Expand All @@ -11,41 +10,33 @@ def get_coverage(
profile_manager_id: str,
prompt_id: str = None,
is_single_pass: bool = False,
) -> dict[str, int]:
) -> list[str]:
"""
Method to fetch coverage data for given tool and profile manager.

Args:
tool (CustomTool): The tool instance or ID for which coverage is fetched.
tool_id (str): The ID of the tool for which coverage is fetched.
profile_manager_id (str): The ID of the profile manager
for which coverage is calculated.
prompt_id (Optional[str]): The ID of the prompt (optional).
is_single_pass (Optional[bool]): Singlepass enabled or not
is_single_pass (Optional[bool]): Singlepass enabled or not.
If provided, coverage is fetched for the specific prompt.

Returns:
dict[str, int]: A dictionary containing coverage information.
dict[str, list[str]]: A dictionary containing coverage information.
Keys are formatted as "coverage_<prompt_id>_<profile_manager_id>".
Values are the count of documents associated with each prompt
Values are lists of document IDs associated with each prompt
and profile combination.
"""
# TODO: remove singlepass reference
prompt_outputs = (
PromptStudioOutputManager.objects.filter(
tool_id=tool_id,
profile_manager_id=profile_manager_id,
prompt_id=prompt_id,
is_single_pass_extract=is_single_pass,
)
.values("prompt_id", "profile_manager_id")
.annotate(document_count=Count("document_manager_id"))
)
prompt_outputs = PromptStudioOutputManager.objects.filter(
tool_id=tool_id,
profile_manager_id=profile_manager_id,
prompt_id=prompt_id,
is_single_pass_extract=is_single_pass,
).values("prompt_id", "profile_manager_id", "document_manager_id")

coverage = {}
coverage = []
for prompt_output in prompt_outputs:
prompt_key = str(prompt_output["prompt_id"])
profile_key = str(prompt_output["profile_manager_id"])
coverage[f"coverage_{prompt_key}_{profile_key}"] = prompt_output[
"document_count"
]
coverage.append(str(prompt_output["document_manager_id"]))
return coverage
Original file line number Diff line number Diff line change
Expand Up @@ -180,29 +180,6 @@ function DocumentParser({
return outputs;
};

const getPromptCoverageCount = (promptId) => {
const keys = Object.keys(promptOutputs || {});
const coverageKey = `coverage_${promptId}`;
const outputs = {};
if (!keys?.length) {
details?.prompts?.forEach((prompt) => {
if (prompt?.coverage) {
const key = Object.keys(prompt?.coverage)[0];
if (key?.startsWith(coverageKey)) {
outputs[key] = prompt?.coverage[key];
}
}
});
return outputs;
}
keys?.forEach((key) => {
if (key?.startsWith(coverageKey)) {
outputs[key] = promptOutputs[key];
}
});
return outputs;
};

if (!details?.prompts?.length) {
if (isSimplePromptStudio && SpsPromptsEmptyState) {
return <SpsPromptsEmptyState />;
Expand Down Expand Up @@ -230,7 +207,7 @@ function DocumentParser({
outputs={getPromptOutputs(item?.prompt_id)}
enforceTypeList={enforceTypeList}
setUpdatedPromptsCopy={setUpdatedPromptsCopy}
coverageCountData={getPromptCoverageCount(item?.prompt_id)}
coverageCountData={item?.coverage}
isChallenge={isChallenge}
/>
<div ref={bottomRef} className="doc-parser-pad-bottom" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import SpaceWrapper from "../../widgets/space-wrapper/SpaceWrapper";
import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader";
import "./ManageDocsModal.css";
import usePostHogEvents from "../../../hooks/usePostHogEvents";
import { usePromptOutputStore } from "../../../store/prompt-output-store";

let SummarizeStatusTitle = null;
try {
Expand Down Expand Up @@ -90,6 +91,7 @@ function ManageDocsModal({
const axiosPrivate = useAxiosPrivate();
const handleException = useExceptionHandler();
const { setPostHogCustomEvent } = usePostHogEvents();
const { promptOutputs, updatePromptOutput } = usePromptOutputStore();

const successIndex = (
<Typography.Text>
Expand Down Expand Up @@ -543,21 +545,32 @@ function ManageDocsModal({
);
updateCustomTool({ listOfDocs: newListOfDocs });

if (newListOfDocs?.length === 1 && selectedDoc?.document_id !== docId) {
const doc = newListOfDocs[1];
if (selectedDoc?.document_id === docId) {
const doc = newListOfDocs[0];
handleDocChange(doc);
}

if (docId === selectedDoc?.document_id) {
updateCustomTool({ selectedDoc: "" });
handleUpdateTool({ output: "" });
}
const updatedPromptOutput = removeIdFromCoverage(promptOutputs, docId);
updatePromptOutput(updatedPromptOutput);
})
.catch((err) => {
setAlertDetails(handleException(err, "Failed to delete"));
});
};

const removeIdFromCoverage = (data, idToRemove) => {
return Object.entries(data).reduce((updatedData, [key, value]) => {
// Create a new object for the current entry
updatedData[key] = {
...value,
// Update the coverage array if it exists
coverage: value?.coverage
? value?.coverage?.filter((id) => id !== idToRemove)
: value?.coverage,
jagadeeswaran-zipstack marked this conversation as resolved.
Show resolved Hide resolved
};
return updatedData;
}, {});
};

return (
<Modal
className="pre-post-amble-modal"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { Header } from "./Header";
import { OutputForIndex } from "./OutputForIndex";
import { PromptOutput } from "./PromptOutput";
import { TABLE_ENFORCE_TYPE, RECORD_ENFORCE_TYPE } from "./constants";
import { generateCoverageKey } from "../../../helpers/GetStaticData";
import usePromptOutput from "../../../hooks/usePromptOutput";

let TableExtractionSettingsBtn;
try {
Expand Down Expand Up @@ -66,6 +66,8 @@ function PromptCardItems({
defaultLlmProfile,
singlePassExtractMode,
} = useCustomToolStore();

const { generatePromptOutputKey } = usePromptOutput();
const [isEditingPrompt, setIsEditingPrompt] = useState(false);
const [isEditingTitle, setIsEditingTitle] = useState(false);
const [expandCard, setExpandCard] = useState(true);
Expand All @@ -78,10 +80,18 @@ function PromptCardItems({
const isNotSingleLlmProfile = llmProfiles.length > 1;
const divRef = useRef(null);
const [enforceType, setEnforceType] = useState("");
const profileId = singlePassExtractMode
? defaultLlmProfile
: selectedLlmProfileId || defaultLlmProfile;
const coverageKey = generateCoverageKey(promptDetails?.prompt_id, profileId);
const promptId = promptDetails?.prompt_id;
const docId = selectedDoc?.document_id;
const promptProfile = promptDetails?.profile_manager || defaultLlmProfile;
const promptOutputKey = generatePromptOutputKey(
promptId,
docId,
promptProfile,
singlePassExtractMode,
true
);
const promptCoverage =
promptOutputs[promptOutputKey]?.coverage || coverageCountData;
tahierhussain marked this conversation as resolved.
Show resolved Hide resolved

useEffect(() => {
if (enforceType !== promptDetails?.enforce_type) {
Expand Down Expand Up @@ -213,7 +223,7 @@ function PromptCardItems({
<SearchOutlined className="font-size-12" />
)}
<Typography.Link className="font-size-12">
Coverage: {coverageCountData[coverageKey] || 0} of{" "}
Coverage: {promptCoverage?.length || 0} of{" "}
{listOfDocs?.length || 0} docs
</Typography.Link>
</Space>
Expand Down
5 changes: 1 addition & 4 deletions frontend/src/hooks/usePromptOutput.js
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ const usePromptOutput = () => {

let isTokenUsageForSinglePassAdded = false;
const tokenUsageDetails = {};

data.forEach((item) => {
const promptId = item?.prompt_id;
const docId = item?.document_manager;
Expand All @@ -109,7 +108,6 @@ const usePromptOutput = () => {
isSinglePass,
true
);
const coverageKey = `coverage_${item?.prompt_id}_${llmProfile}`;
outputs[key] = {
runId: item?.run_id,
promptOutputId: item?.prompt_output_id,
Expand All @@ -119,8 +117,8 @@ const usePromptOutput = () => {
tokenUsage: item?.token_usage,
output: item?.output,
timer,
coverage: item?.coverage,
};
outputs[coverageKey] = item?.coverage[coverageKey] || 0;

if (item?.is_single_pass_extract && isTokenUsageForSinglePassAdded)
return;
Expand Down Expand Up @@ -150,7 +148,6 @@ const usePromptOutput = () => {
);
tokenUsageDetails[tokenUsageId] = item?.token_usage;
});

if (isReset) {
setPromptOutput(outputs);
setTokenUsage(tokenUsageDetails);
Expand Down
Loading