Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature request: Option to disable cross encoder models #286

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
da037d1
Currently cross encoder models are used to rank the search results bu…
azaylamba Dec 23, 2023
1c3b8ce
Enhancement: Add user feedback for responses
azaylamba Dec 24, 2023
c8dc554
Revert "Enhancement: Add user feedback for responses"
azaylamba Dec 24, 2023
550d2d0
Merge branch 'main' into main
azaylamba Jan 17, 2024
8dd11d8
Merge branch 'aws-samples:main' into main
azaylamba Jan 25, 2024
42c6edd
Merge branch 'main' of https://github.com/azaylamba/aws-genai-llm-cha…
azaylamba Feb 4, 2024
efb1a99
Addressed review comments related to cross encoding.
azaylamba Feb 4, 2024
b58737d
Removed prompt for selecting embedding models as it is not required now.
azaylamba Feb 4, 2024
cb8793d
Resolving merge conflicts
azaylamba Feb 9, 2024
cf0dfc1
Resolving merge conflicts
azaylamba Feb 9, 2024
13ce71e
Derived value of crossEncodingEnabled based on enableEmbeddingModelsV…
azaylamba Feb 9, 2024
2522839
Reverted unwanted change
azaylamba Feb 9, 2024
4669419
Merge branch 'main' into main
bigadsoleiman Feb 13, 2024
1667e9c
Merge branch 'main' into main
azaylamba Feb 24, 2024
1102491
Default embeddings model prompt was not set
azaylamba Feb 24, 2024
2047641
Merge branch 'main' into main
bigadsoleiman Mar 8, 2024
a09713e
Merge branch 'main' into main
azaylamba Apr 13, 2024
dca47d0
Corrected the NagSuppression conditions
azaylamba Apr 20, 2024
c2eabf4
Merge branch 'main' into main
azaylamba Jul 13, 2024
6a7c92b
Addressed review comments
azaylamba Jul 13, 2024
494f3b1
Added default value for cross encoder models
azaylamba Jul 15, 2024
efa9fa8
Merge branch 'main' into main
azaylamba Jul 18, 2024
61b73d2
Used enableSagemakerModels config for SM models
azaylamba Jul 18, 2024
feb5752
Merge branch 'main' of https://github.com/azaylamba/aws-genai-llm-cha…
azaylamba Jul 18, 2024
6850a9a
Merge branch 'main' into main
azaylamba Aug 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export function getConfig(): SystemConfig {
},
llms: {
// sagemaker: [SupportedSageMakerModels.FalconLite]
enableSagemakerModels: false,
sagemaker: [],
},
rag: {
Expand Down Expand Up @@ -64,6 +65,7 @@ export function getConfig(): SystemConfig {
default: true,
},
],
crossEncodingEnabled: false,
},
};
}
Expand Down
40 changes: 34 additions & 6 deletions cli/magic-create.ts
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ async function processCreateOptions(options: any): Promise<void> {
},
initial: options.bedrockRoleArn || "",
},
{
type: "multiselect",
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
name: "selectedEmbeddingModels",
hint: "SPACE to select, ENTER to confirm selection",
message: "Which Embedding Models do you want to enable",
choices: embeddingModels.map((m) => ({ name: m.name, value: m })),
validate(choices: any) {
return (this as any).skipped || choices.length > 0
? true
: "You need to select at least one model";
},
},
{
type: "confirm",
name: "enableSagemakerModels",
Expand Down Expand Up @@ -180,6 +192,12 @@ async function processCreateOptions(options: any): Promise<void> {
message: "Do you want to enable RAG",
initial: options.enableRag || false,
},
{
type: "confirm",
name: "enableCrossEncoding",
message: "Do you want to enable Cross-Encoding",
initial: options.enableCrossEncoding || false,
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
},
{
type: "multiselect",
name: "ragsToEnable",
Expand Down Expand Up @@ -319,6 +337,7 @@ async function processCreateOptions(options: any): Promise<void> {
}
: undefined,
llms: {
enableSagemakerModels: answers.enableSagemakerModels,
sagemaker: answers.sagemakerModels,
},
rag: {
Expand All @@ -339,6 +358,7 @@ async function processCreateOptions(options: any): Promise<void> {
},
embeddingsModels: [{}],
crossEncoderModels: [{}],
crossEncodingEnabled: answers.enableCrossEncoding,
},
};

Expand All @@ -347,12 +367,20 @@ async function processCreateOptions(options: any): Promise<void> {
models.defaultEmbedding = embeddingModels[0].name;
}

config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.embeddingsModels = embeddingModels;
if (answers.enableCrossEncoding && answers.sagemakerModels.length > 0) {
config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
} else {
config.rag.crossEncoderModels[0] = {
provider: "None",
name: "None",
default: true,
};
}
config.rag.embeddingsModels = embeddingModels.filter((model) => answers.selectedEmbeddingModels.includes(model.name));
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
Expand Down
3 changes: 1 addition & 2 deletions lib/aws-genai-llm-chatbot-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
identityPool: authentication.identityPool,
api: chatBotApi,
chatbotFilesBucket: chatBotApi.filesBucket,
crossEncodersEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
crossEncodersEnabled: props.config.rag.crossEncodingEnabled,
sagemakerEmbeddingsEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
});
Expand Down
27 changes: 16 additions & 11 deletions lib/chatbot-api/functions/api-handler/routes/cross_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@ def models():
@tracer.capture_method
def cross_encoders(input: dict):
request = CrossEncodersRequest(**input)
selected_model = genai_core.cross_encoder.get_cross_encoder_model(
request.provider, request.model
)

if selected_model is None:
raise genai_core.types.CommonError("Model not found")

ret_value = genai_core.cross_encoder.rank_passages(
selected_model, request.reference, request.passages
)
return [{"score": v, "passage": p} for v, p in zip(ret_value, request.passages)]
config = genai_core.parameters.get_config()
crossEncodingEnabled = config["rag"]["crossEncodingEnabled"]
if (crossEncodingEnabled):
selected_model = genai_core.cross_encoder.get_cross_encoder_model(
request.provider, request.model
)

if selected_model is None:
raise genai_core.types.CommonError("Model not found")

ret_value = genai_core.cross_encoder.rank_passages(
selected_model, request.reference, request.passages
)
return [{"score": v, "passage": p} for v, p in zip(ret_value, request.passages)]

return [{"score": 0, "passage": p} for p in request.passages]
2 changes: 1 addition & 1 deletion lib/rag-engines/data-import/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export class DataImport extends Construct {
processingBucket,
auroraDatabase: props.auroraDatabase,
ragDynamoDBTables: props.ragDynamoDBTables,
sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint,
sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model?.endpoint,
openSearchVector: props.openSearchVector,
}
);
Expand Down
5 changes: 1 addition & 4 deletions lib/rag-engines/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ export class RagEngines extends Construct {
const tables = new RagDynamoDBTables(this, "RagDynamoDBTables");

let sageMakerRagModels: SageMakerRagModels | null = null;
if (
props.config.rag.engines.aurora.enabled ||
props.config.rag.engines.opensearch.enabled
) {
if (props.config.llms.enableSagemakerModels) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be checking crossEncodingEnabled and not enableSageMakerModels

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, but won't that be confusing that crossEncodingEnabled is driving the Sagemaker models instead of the config props.config.llms.enableSagemakerModels which is specific for sagemaker models?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right and props.config.llms.enableSagemakerModels is better

sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", {
shared: props.shared,
config: props.config,
Expand Down
30 changes: 16 additions & 14 deletions lib/rag-engines/sagemaker-rag-models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,22 @@ export class SageMakerRagModels extends Construct {
.filter((c) => c.provider === "sagemaker")
.map((c) => c.name);

const model = new SageMakerModel(this, "Model", {
vpc: props.shared.vpc,
region: cdk.Aws.REGION,
model: {
type: DeploymentType.CustomInferenceScript,
modelId: [
...sageMakerEmbeddingsModelIds,
...sageMakerCrossEncoderModelIds,
],
codeFolder: path.join(__dirname, "./model"),
instanceType: "ml.g4dn.xlarge",
},
});
if (sageMakerEmbeddingsModelIds?.length > 0 || sageMakerCrossEncoderModelIds?.length > 0) {
const model = new SageMakerModel(this, "Model", {
vpc: props.shared.vpc,
region: cdk.Aws.REGION,
model: {
type: DeploymentType.CustomInferenceScript,
modelId: [
...sageMakerEmbeddingsModelIds,
...sageMakerCrossEncoderModelIds,
],
codeFolder: path.join(__dirname, "./model"),
instanceType: "ml.g4dn.xlarge",
},
});

this.model = model;
this.model = model;
}
}
}
121 changes: 57 additions & 64 deletions lib/shared/layers/python-sdk/python/genai_core/aurora/query.py
azaylamba marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def query_workspace_aurora(
full_response: bool,
threshold: int = 0,
):
config = genai_core.parameters.get_config()
table_name = sql.Identifier(workspace_id.replace("-", ""))
embeddings_model_provider = workspace["embeddings_model_provider"]
embeddings_model_name = workspace["embeddings_model_name"]
Expand All @@ -37,12 +38,13 @@ def query_workspace_aurora(
if selected_model is None:
raise genai_core.types.CommonError("Embeddings model not found")

cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model(
cross_encoder_model_provider, cross_encoder_model_name
)
if (config["rag"]["crossEncodingEnabled"]):
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model(
cross_encoder_model_provider, cross_encoder_model_name
)

if cross_encoder_model is None:
raise genai_core.types.CommonError("Cross encoder model not found")
if cross_encoder_model is None:
raise genai_core.types.CommonError("Cross encoder model not found")

query_embeddings = genai_core.embeddings.generate_embeddings(
selected_model, [query]
Expand Down Expand Up @@ -185,78 +187,69 @@ def query_workspace_aurora(
item["keyword_search_score"] = current["keyword_search_score"]

unique_items = list(unique_items.values())
score_dict = dict({})
if len(unique_items) > 0:
passages = [record["content"] for record in unique_items]
passage_scores = genai_core.cross_encoder.rank_passages(
cross_encoder_model, query, passages
)

for i in range(len(unique_items)):
score = passage_scores[i]
unique_items[i]["score"] = score
score_dict[unique_items[i]["chunk_id"]] = score

unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True)

for record in vector_search_records:
record["score"] = score_dict[record["chunk_id"]]
for record in keyword_search_records:
record["score"] = score_dict[record["chunk_id"]]
if (config["rag"]["crossEncodingEnabled"]):
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
score_dict = dict({})
if len(unique_items) > 0:
passages = [record["content"] for record in unique_items]
passage_scores = genai_core.cross_encoder.rank_passages(
cross_encoder_model, query, passages
)

if full_response:
unique_items = unique_items[:limit]
ret_value = {
"engine": "aurora",
"query_language": language_name,
"supported_languages": languages,
"detected_languages": detected_languages,
"items": convert_types(unique_items),
"vector_search_metric": metric,
"vector_search_items": convert_types(vector_search_records),
"keyword_search_items": convert_types(keyword_search_records),
}
else:
ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[
:limit
]
if len(ret_items) < limit:
azaylamba marked this conversation as resolved.
Show resolved Hide resolved
# inner product metric is negative hence we sort ascending
if metric == "inner":
unique_items = sorted(
unique_items,
key=lambda x: x["vector_search_score"] or 1,
reverse=False,
)
ret_items = ret_items + (
list(
filter(
lambda val: (val["vector_search_score"] or 1) < -0.5,
unique_items,
)
)[: (limit - len(ret_items))]
)
else:
for i in range(len(unique_items)):
score = passage_scores[i]
unique_items[i]["score"] = score
score_dict[unique_items[i]["chunk_id"]] = score

unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True)

if (config["rag"]["crossEncodingEnabled"]):
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
for record in vector_search_records:
record["score"] = score_dict[record["chunk_id"]]
for record in keyword_search_records:
record["score"] = score_dict[record["chunk_id"]]

if full_response:
ret_value = {
"engine": "aurora",
"query_language": language_name,
"supported_languages": languages,
"detected_languages": detected_languages,
"items": convert_types(unique_items),
"vector_search_metric": metric,
"vector_search_items": convert_types(vector_search_records),
"keyword_search_items": convert_types(keyword_search_records),
}
else:
ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))
if len(ret_items) < limit:
unique_items = sorted(
unique_items,
key=lambda x: x["vector_search_score"] or -1,
reverse=True,
unique_items, key=lambda x: x["vector_search_score"], reverse=True
)
ret_items = ret_items + (
list(
filter(
lambda val: (val["vector_search_score"] or -1) > 0.5,
unique_items,
)
filter(lambda val: val["vector_search_score"] > 0.5, unique_items)
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
)[: (limit - len(ret_items))]
)

ret_value = {
"engine": "aurora",
"query_language": language_name,
"supported_languages": languages,
"detected_languages": detected_languages,
"items": convert_types(
list(filter(lambda val: val["score"] > 0, unique_items))
),
}
else:
ret_value = {
"engine": "aurora",
"query_language": language_name,
"supported_languages": languages,
"detected_languages": detected_languages,
"items": convert_types(ret_items),
"items": convert_types(unique_items),
"vector_search_metric": metric,
"vector_search_items": convert_types(vector_search_records),
"keyword_search_items": convert_types(keyword_search_records),
}

logger.info(ret_value)
Expand Down
Loading