diff --git a/scripts/populate.ts b/scripts/populate.ts
index 8e06879a634..0e4405ae971 100644
--- a/scripts/populate.ts
+++ b/scripts/populate.ts
@@ -14,7 +14,7 @@ import type { User } from "../src/lib/types/User";
 import type { Assistant } from "../src/lib/types/Assistant";
 import type { Conversation } from "../src/lib/types/Conversation";
 import type { Settings } from "../src/lib/types/Settings";
-import { defaultEmbeddingModel } from "../src/lib/server/embeddingModels.ts";
+import { getDefaultEmbeddingModel } from "../src/lib/server/embeddingModels.ts";
 import { Message } from "../src/lib/types/Message.ts";
 import { addChildren } from "../src/lib/utils/tree/addChildren.ts";
 
@@ -146,6 +146,7 @@ async function seed() {
 			updatedAt: faker.date.recent({ days: 30 }),
 			customPrompts: {},
 			assistants: [],
+			disableStream: false,
 		};
 		await collections.settings.updateOne(
 			{ userId: user._id },
@@ -214,7 +215,7 @@ async function seed() {
 					: faker.helpers.maybe(() => faker.hacker.phrase(), { probability: 0.5 })) ?? "";
 
 			const messages = await generateMessages(preprompt);
-
+			const defaultEmbeddingModel = await getDefaultEmbeddingModel();
 			const conv = {
 				_id: new ObjectId(),
 				userId: user._id,
@@ -224,7 +225,7 @@ async function seed() {
 				updatedAt: faker.date.recent({ days: 145 }),
 				model: faker.helpers.arrayElement(modelIds),
 				title: faker.internet.emoji() + " " + faker.hacker.phrase(),
-				embeddingModel: defaultEmbeddingModel.id,
+				embeddingModel: defaultEmbeddingModel.name,
 				messages,
 				rootMessageId: messages[0].id,
 			} satisfies Conversation;
diff --git a/src/hooks.server.ts b/src/hooks.server.ts
index 157f66ac072..241c184e048 100644
--- a/src/hooks.server.ts
+++ b/src/hooks.server.ts
@@ -16,6 +16,7 @@ import { initExitHandler } from "$lib/server/exitHandler";
 import { ObjectId } from "mongodb";
 import { refreshAssistantsCounts } from "$lib/jobs/refresh-assistants-counts";
 import { refreshConversationStats } from "$lib/jobs/refresh-conversation-stats";
+import { populateEmbeddingModel } from "$lib/server/embeddingModels";
 
 // TODO: move this code on a started server hook, instead of using a "building" flag
 if (!building) {
@@ -25,6 +26,9 @@ if (!building) {
 	if (env.ENABLE_ASSISTANTS) {
 		refreshAssistantsCounts();
 	}
+
+	await populateEmbeddingModel();
+
 	refreshConversationStats();
 
 	// Init metrics server
diff --git a/src/lib/server/database.ts b/src/lib/server/database.ts
index 5e5824dcc3a..7545b1ace88 100644
--- a/src/lib/server/database.ts
+++ b/src/lib/server/database.ts
@@ -16,6 +16,7 @@ import type { AssistantStats } from "$lib/types/AssistantStats";
 import { logger } from "$lib/server/logger";
 import { building } from "$app/environment";
 import { onExit } from "./exitHandler";
+import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
 
 export const CONVERSATION_STATS_COLLECTION = "conversations.stats";
 
@@ -83,6 +84,7 @@ export class Database {
 		const bucket = new GridFSBucket(db, { bucketName: "files" });
 		const migrationResults = db.collection("migrationResults");
 		const semaphores = db.collection("semaphores");
+		const embeddingModels = db.collection<EmbeddingModel>("embeddingModels");
 
 		return {
 			conversations,
@@ -99,6 +101,7 @@ export class Database {
 			bucket,
 			migrationResults,
 			semaphores,
+			embeddingModels,
 		};
 	}
 
@@ -120,6 +123,7 @@ export class Database {
 			sessions,
 			messageEvents,
 			semaphores,
+			embeddingModels,
 		} = this.getCollections();
 
 		conversations
@@ -209,6 +213,8 @@ export class Database {
 		semaphores
 			.createIndex({ createdAt: 1 }, { expireAfterSeconds: 60 })
 			.catch((e) => logger.error(e));
+
+		embeddingModels.createIndex({ name: 1 }, { unique: true }).catch((e) => logger.error(e));
 	}
 }
diff --git a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
index 053e4316605..b26303e5ffa 100644
--- a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
+++ b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
@@ -12,6 +12,7 @@ import {
 	embeddingEndpointOpenAIParametersSchema,
 } from "./openai/embeddingEndpoints";
 import { embeddingEndpointHfApi, embeddingEndpointHfApiSchema } from "./hfApi/embeddingHfApi";
+import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
 
 // parameters passed when generating text
 interface EmbeddingEndpointParameters {
@@ -33,8 +34,8 @@ export const embeddingEndpointSchema = z.discriminatedUnion("type", [
 type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
 
 // generator function that takes in type discrimantor value for defining the endpoint and return the endpoint
-export type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
-	inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }>
+type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
+	inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }> & { model: EmbeddingModel }
 ) => EmbeddingEndpoint | Promise<EmbeddingEndpoint>;
 
 // list of all endpoint generators
diff --git a/src/lib/server/embeddingEndpoints/hfApi/embeddingHfApi.ts b/src/lib/server/embeddingEndpoints/hfApi/embeddingHfApi.ts
index 7f61fb376b2..283a452a2fa 100644
--- a/src/lib/server/embeddingEndpoints/hfApi/embeddingHfApi.ts
+++ b/src/lib/server/embeddingEndpoints/hfApi/embeddingHfApi.ts
@@ -3,10 +3,10 @@ import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
 import { chunk } from "$lib/utils/chunk";
 import { env } from "$env/dynamic/private";
 import { logger } from "$lib/server/logger";
+import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
 
 export const embeddingEndpointHfApiSchema = z.object({
 	weight: z.number().int().positive().default(1),
-	model: z.any(),
 	type: z.literal("hfapi"),
 	authorization: z
 		.string()
@@ -14,11 +14,16 @@ export const embeddingEndpointHfApiSchema = z.object({
 		.transform((v) => (!v && env.HF_TOKEN ? "Bearer " + env.HF_TOKEN : v)), // if the header is not set but HF_TOKEN is, use it as the authorization header
"Bearer " + env.HF_TOKEN : v)), // if the header is not set but HF_TOKEN is, use it as the authorization header }); +type EmbeddingEndpointHfApiInput = z.input & { + model: EmbeddingModel; +}; + export async function embeddingEndpointHfApi( - input: z.input + input: EmbeddingEndpointHfApiInput ): Promise { - const { model, authorization } = embeddingEndpointHfApiSchema.parse(input); - const url = "https://api-inference.huggingface.co/models/" + model.id; + const { model } = input; + const { authorization } = embeddingEndpointHfApiSchema.parse(input); + const url = "https://api-inference.huggingface.co/models/" + model.name; return async ({ inputs }) => { const batchesInputs = chunk(inputs, 128); diff --git a/src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts index d1725ffad1a..4b973d66808 100644 --- a/src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts @@ -2,21 +2,25 @@ import { z } from "zod"; import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints"; import { chunk } from "$lib/utils/chunk"; import { env } from "$env/dynamic/private"; +import type { EmbeddingModel } from "$lib/types/EmbeddingModel"; export const embeddingEndpointOpenAIParametersSchema = z.object({ weight: z.number().int().positive().default(1), - model: z.any(), type: z.literal("openai"), url: z.string().url().default("https://api.openai.com/v1/embeddings"), apiKey: z.string().default(env.OPENAI_API_KEY), defaultHeaders: z.record(z.string()).default({}), }); +type EmbeddingEndpointOpenAIInput = z.input & { + model: EmbeddingModel; +}; + export async function embeddingEndpointOpenAI( - input: z.input + input: EmbeddingEndpointOpenAIInput ): Promise { - const { url, model, apiKey, defaultHeaders } = - embeddingEndpointOpenAIParametersSchema.parse(input); + const { model } = input; + const { url, apiKey, defaultHeaders } = embeddingEndpointOpenAIParametersSchema.parse(input); const maxBatchSize = model.maxBatchSize || 100; diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index c999ceba7da..ec9c38ed6a2 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -3,10 +3,10 @@ import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints"; import { chunk } from "$lib/utils/chunk"; import { env } from "$env/dynamic/private"; import { logger } from "$lib/server/logger"; +import type { EmbeddingModel } from "$lib/types/EmbeddingModel"; export const embeddingEndpointTeiParametersSchema = z.object({ weight: z.number().int().positive().default(1), - model: z.any(), type: z.literal("tei"), url: z.string().url(), authorization: z @@ -35,10 +35,15 @@ const getModelInfoByUrl = async (url: string, authorization?: string) => { } }; +type EmbeddingEndpointTeiInput = z.input & { + model: EmbeddingModel; +}; + export async function embeddingEndpointTei( - input: z.input + input: EmbeddingEndpointTeiInput ): Promise { - const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input); + const { model } = input; + const { url, authorization } = embeddingEndpointTeiParametersSchema.parse(input); const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url); const maxBatchSize = Math.min( diff --git 
 a/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts
index 6f24ce74715..df73d3340c4 100644
--- a/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts
+++ b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts
@@ -2,10 +2,10 @@ import { z } from "zod";
 import type { EmbeddingEndpoint } from "../embeddingEndpoints";
 import type { Tensor, FeatureExtractionPipeline } from "@xenova/transformers";
 import { pipeline } from "@xenova/transformers";
+import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
 
 export const embeddingEndpointTransformersJSParametersSchema = z.object({
 	weight: z.number().int().positive().default(1),
-	model: z.any(),
 	type: z.literal("transformersjs"),
 });
 
@@ -36,10 +36,16 @@ export async function calculateEmbedding(modelName: string, inputs: string[]) {
 	return output.tolist();
 }
 
+type EmbeddingEndpointTransformersJSInput = z.input<
+	typeof embeddingEndpointTransformersJSParametersSchema
+> & {
+	model: EmbeddingModel;
+};
+
 export function embeddingEndpointTransformersJS(
-	input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
+	input: EmbeddingEndpointTransformersJSInput
 ): EmbeddingEndpoint {
-	const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);
+	const { model } = input;
 
 	return async ({ inputs }) => {
 		return calculateEmbedding(model.name, inputs);
diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts
index 67ad8fe5b1e..c63d9115e84 100644
--- a/src/lib/server/embeddingModels.ts
+++ b/src/lib/server/embeddingModels.ts
@@ -5,11 +5,13 @@ import { sum } from "$lib/utils/sum";
 import {
 	embeddingEndpoints,
 	embeddingEndpointSchema,
-	type EmbeddingEndpoint,
 } from "$lib/server/embeddingEndpoints/embeddingEndpoints";
 import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
 
 import JSON5 from "json5";
+import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
+import { collections } from "./database";
+import { ObjectId } from "mongodb";
 
 const modelConfig = z.object({
 	/** Used as an identifier in DB */
@@ -42,67 +44,77 @@ const rawEmbeddingModelJSON =
 
 const embeddingModelsRaw = z.array(modelConfig).parse(JSON5.parse(rawEmbeddingModelJSON));
 
-const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
-	...m,
-	id: m.id || m.name,
+const embeddingModels = embeddingModelsRaw.map((rawEmbeddingModel) => {
+	const embeddingModel: EmbeddingModel = {
+		name: rawEmbeddingModel.name,
+		description: rawEmbeddingModel.description,
+		websiteUrl: rawEmbeddingModel.websiteUrl,
+		modelUrl: rawEmbeddingModel.modelUrl,
+		chunkCharLength: rawEmbeddingModel.chunkCharLength,
+		maxBatchSize: rawEmbeddingModel.maxBatchSize,
+		preQuery: rawEmbeddingModel.preQuery,
+		prePassage: rawEmbeddingModel.prePassage,
+		_id: new ObjectId(),
+		createdAt: new Date(),
+		updatedAt: new Date(),
+		endpoints: rawEmbeddingModel.endpoints,
+	};
+
+	return embeddingModel;
 });
 
-const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
-	...m,
-	getEndpoint: async (): Promise<EmbeddingEndpoint> => {
-		if (!m.endpoints) {
-			return embeddingEndpointTransformersJS({
-				type: "transformersjs",
-				weight: 1,
-				model: m,
-			});
-		}
+export const getEmbeddingEndpoint = async (embeddingModel: EmbeddingModel) => {
+	if (!embeddingModel.endpoints) {
+		return embeddingEndpointTransformersJS({
+			type: "transformersjs",
+			weight: 1,
+			model: embeddingModel,
+		});
+	}
 
-		const totalWeight = sum(m.endpoints.map((e) => e.weight));
-
-		let random = Math.random() * totalWeight;
-
-		for (const endpoint of m.endpoints) {
-			if (random < endpoint.weight) {
-				const args = { ...endpoint, model: m };
-
-				switch (args.type) {
-					case "tei":
-						return embeddingEndpoints.tei(args);
-					case "transformersjs":
-						return embeddingEndpoints.transformersjs(args);
-					case "openai":
-						return embeddingEndpoints.openai(args);
-					case "hfapi":
-						return embeddingEndpoints.hfapi(args);
-					default:
-						throw new Error(`Unknown endpoint type: ${args}`);
-				}
+	// pick one endpoint at random, weighted by its share of the total weight
+	const totalWeight = sum(embeddingModel.endpoints.map((e) => e.weight));
+
+	let random = Math.random() * totalWeight;
+
+	for (const endpoint of embeddingModel.endpoints) {
+		if (random < endpoint.weight) {
+			const args = { ...endpoint, model: embeddingModel };
+
+			switch (args.type) {
+				case "tei":
+					return embeddingEndpoints.tei(args);
+				case "transformersjs":
+					return embeddingEndpoints.transformersjs(args);
+				case "openai":
+					return embeddingEndpoints.openai(args);
+				case "hfapi":
+					return embeddingEndpoints.hfapi(args);
+				default:
+					throw new Error(`Unknown endpoint type: ${args}`);
 			}
-
-			random -= endpoint.weight;
 		}
 
-		throw new Error(`Failed to select embedding endpoint`);
-	},
-});
-
-export const embeddingModels = await Promise.all(
-	embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
-);
-
-export const defaultEmbeddingModel = embeddingModels[0];
+		random -= endpoint.weight;
+	}
 
-const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
-	return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
+	throw new Error(`Failed to select embedding endpoint`);
 };
 
-export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
-	return validateEmbeddingModel(_models, "id");
-};
+export const getDefaultEmbeddingModel = async (): Promise<EmbeddingModel> => {
+	if (!embeddingModels[0]) {
+		throw new Error(`Failed to find default embedding endpoint`);
+	}
+
+	const defaultModel = await collections.embeddingModels.findOne({
+		_id: embeddingModels[0]._id,
+	});
 
-export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
-	return validateEmbeddingModel(_models, "name");
+	return defaultModel ? defaultModel : embeddingModels[0];
 };
 
-export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
+// mimic the current behavior of recreating the embedding models from scratch during server start
+export async function populateEmbeddingModel() {
+	await collections.embeddingModels.deleteMany({});
+	await collections.embeddingModels.insertMany(embeddingModels);
+}
diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts
index 6efff6ec138..8e634e18816 100644
--- a/src/lib/server/models.ts
+++ b/src/lib/server/models.ts
@@ -5,7 +5,6 @@ import { z } from "zod";
 import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
 import { endpointTgi } from "./endpoints/tgi/endpointTgi";
 import { sum } from "$lib/utils/sum";
-import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";
 
 import type { PreTrainedTokenizer } from "@xenova/transformers";
 
@@ -13,6 +12,7 @@ import JSON5 from "json5";
 import { getTokenizer } from "$lib/utils/getTokenizer";
 import { logger } from "$lib/server/logger";
 import { ToolResultStatus } from "$lib/types/Tool";
+import { collections } from "./database";
 
 type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;
 
@@ -64,7 +64,7 @@ const modelConfig = z.object({
 	multimodal: z.boolean().default(false),
 	tools: z.boolean().default(false),
 	unlisted: z.boolean().default(false),
-	embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
+	embeddingModel: z.string().optional(),
 });
 
 const modelsRaw = z.array(modelConfig).parse(JSON5.parse(env.MODELS));
@@ -227,6 +227,10 @@ const processModel = async (m: z.infer<typeof modelConfig>) => ({
 	displayName: m.displayName || m.name,
 	preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
 	parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
+	embeddingModel:
+		(await collections.embeddingModels
+			.findOne({ name: m.embeddingModel })
+			.then((embeddingModel) => embeddingModel?.name)) ?? undefined,
 });
 
 export type ProcessedModel = Awaited<ReturnType<typeof processModel>> & {
diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts
index 45141e357b7..48e2e527c6a 100644
--- a/src/lib/server/sentenceSimilarity.ts
+++ b/src/lib/server/sentenceSimilarity.ts
@@ -1,6 +1,7 @@
 import { dot } from "@xenova/transformers";
-import type { EmbeddingBackendModel } from "$lib/server/embeddingModels";
 import type { Embedding } from "$lib/server/embeddingEndpoints/embeddingEndpoints";
+import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
+import { getEmbeddingEndpoint } from "./embeddingModels";
 
 // see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
 export function innerProduct(embeddingA: Embedding, embeddingB: Embedding) {
@@ -8,7 +9,7 @@ export function innerProduct(embeddingA: Embedding, embeddingB: Embedding) {
 }
 
 export async function getSentenceSimilarity(
-	embeddingModel: EmbeddingBackendModel,
+	embeddingModel: EmbeddingModel,
 	query: string,
 	sentences: string[]
 ): Promise<{ distance: number; embedding: Embedding; idx: number }[]> {
@@ -17,7 +18,7 @@ export async function getSentenceSimilarity(
 		...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`),
 	];
 
-	const embeddingEndpoint = await embeddingModel.getEndpoint();
+	const embeddingEndpoint = await getEmbeddingEndpoint(embeddingModel);
 	const output = await embeddingEndpoint({ inputs }).catch((err) => {
 		throw Error("Failed to generate embeddings for sentence similarity", { cause: err });
 	});
diff --git a/src/lib/server/websearch/embed/combine.ts b/src/lib/server/websearch/embed/combine.ts
index 29b4113c2d8..191f392c828 100644
--- a/src/lib/server/websearch/embed/combine.ts
+++ b/src/lib/server/websearch/embed/combine.ts
@@ -1,12 +1,12 @@
-import type { EmbeddingBackendModel } from "$lib/server/embeddingModels";
 import { getSentenceSimilarity } from "$lib/server/sentenceSimilarity";
+import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
 
 /**
  * Combines sentences together to reach the maximum character limit of the embedding model
  * Improves performance considerably when using CPU embedding
 */
 export async function getCombinedSentenceSimilarity(
-	embeddingModel: EmbeddingBackendModel,
+	embeddingModel: EmbeddingModel,
 	query: string,
 	sentences: string[]
 ): ReturnType<typeof getSentenceSimilarity> {
diff --git a/src/lib/server/websearch/embed/embed.ts b/src/lib/server/websearch/embed/embed.ts
index aba7dee13bc..3d446ebd25a 100644
--- a/src/lib/server/websearch/embed/embed.ts
+++ b/src/lib/server/websearch/embed/embed.ts
@@ -1,6 +1,6 @@
 import { MetricsServer } from "$lib/server/metrics";
+import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
 import type { WebSearchScrapedSource, WebSearchUsedSource } from "$lib/types/WebSearch";
-import type { EmbeddingBackendModel } from "../../embeddingModels";
 import { getSentenceSimilarity, innerProduct } from "../../sentenceSimilarity";
 import { MarkdownElementType, type MarkdownElement } from "../markdown/types";
 import { stringifyMarkdownElement } from "../markdown/utils/stringify";
@@ -13,7 +13,7 @@ const SOFT_MAX_CHARS = 8_000;
 
 export async function findContextSources(
 	sources: WebSearchScrapedSource[],
 	prompt: string,
-	embeddingModel: EmbeddingBackendModel
+	embeddingModel: EmbeddingModel
 ) {
 	const startTime = Date.now();
diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts
index 45c251a7d10..03dffeb4bf8 100644
--- a/src/lib/server/websearch/runWebSearch.ts
+++ b/src/lib/server/websearch/runWebSearch.ts
@@ -1,4 +1,4 @@
-import { defaultEmbeddingModel, embeddingModels } from "$lib/server/embeddingModels";
+import { getDefaultEmbeddingModel } from "$lib/server/embeddingModels";
 import type { Conversation } from "$lib/types/Conversation";
 import type { Message } from "$lib/types/Message";
 
@@ -18,6 +18,7 @@ import {
 } from "./update";
 import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
 import { MetricsServer } from "../metrics";
+import { collections } from "../database";
 
 const MAX_N_PAGES_TO_SCRAPE = 8 as const;
 const MAX_N_PAGES_TO_EMBED = 5 as const;
@@ -35,8 +36,11 @@ export async function* runWebSearch(
 	MetricsServer.getMetrics().webSearch.requestCount.inc();
 
 	try {
+		const defaultEmbeddingModel = await getDefaultEmbeddingModel();
 		const embeddingModel =
-			embeddingModels.find((m) => m.id === conv.embeddingModel) ?? defaultEmbeddingModel;
+			(await collections.embeddingModels.findOne({
+				name: conv.embeddingModel,
+			})) ?? defaultEmbeddingModel;
 		if (!embeddingModel) {
 			throw Error(`Embedding model ${conv.embeddingModel} not available anymore`);
 		}
diff --git a/src/lib/types/EmbeddingEndpoint.ts b/src/lib/types/EmbeddingEndpoint.ts
new file mode 100644
index 00000000000..2e2f66b439f
--- /dev/null
+++ b/src/lib/types/EmbeddingEndpoint.ts
@@ -0,0 +1,31 @@
+interface EmbeddingEndpointTei {
+	type: "tei";
+	weight: number;
+	url: string;
+	authorization?: string;
+}
+
+interface EmbeddingEndpointTransformersjs {
+	type: "transformersjs";
+	weight: number;
+}
+
+interface EmbeddingEndpointOpenai {
+	type: "openai";
+	weight: number;
+	url: string;
+	apiKey: string;
+	defaultHeaders: Record<string, string>;
+}
+
+interface EmbeddingEndpointHfApi {
+	type: "hfapi";
+	weight: number;
+	authorization?: string;
+}
+
+export type EmbeddingEndpoint =
+	| EmbeddingEndpointTei
+	| EmbeddingEndpointTransformersjs
+	| EmbeddingEndpointOpenai
+	| EmbeddingEndpointHfApi;
diff --git a/src/lib/types/EmbeddingModel.ts b/src/lib/types/EmbeddingModel.ts
new file mode 100644
index 00000000000..0b709c1cb8d
--- /dev/null
+++ b/src/lib/types/EmbeddingModel.ts
@@ -0,0 +1,21 @@
+import type { ObjectId } from "mongodb";
+import type { Timestamps } from "./Timestamps";
+import type { EmbeddingEndpoint } from "./EmbeddingEndpoint";
+
+export interface EmbeddingModel extends Timestamps {
+	_id: ObjectId;
+
+	name: string;
+
+	description?: string;
+	websiteUrl?: string;
+	modelUrl?: string;
+
+	chunkCharLength: number;
+	maxBatchSize?: number;
+
+	preQuery: string;
+	prePassage: string;
+
+	endpoints: EmbeddingEndpoint[];
+}
diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts
index af372f51e4d..07696c44078 100644
--- a/src/routes/conversation/+server.ts
+++ b/src/routes/conversation/+server.ts
@@ -6,7 +6,7 @@ import { base } from "$app/paths";
 import { z } from "zod";
 import type { Message } from "$lib/types/Message";
 import { models, validateModel } from "$lib/server/models";
-import { defaultEmbeddingModel } from "$lib/server/embeddingModels";
+import { getDefaultEmbeddingModel } from "$lib/server/embeddingModels";
 import { v4 } from "uuid";
 import { authCondition } from "$lib/server/auth";
 import { usageLimits } from "$lib/server/usageLimits";
@@ -78,6 +78,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
 		values.assistantId = conversation.assistantId?.toString();
 		embeddingModel = conversation.embeddingModel;
 	}
+	const defaultEmbeddingModel = await getDefaultEmbeddingModel();
 
 	embeddingModel ??= model.embeddingModel ?? defaultEmbeddingModel.name;
diff --git a/src/routes/login/callback/updateUser.spec.ts b/src/routes/login/callback/updateUser.spec.ts
index fefaf8b0f5a..ca80a9b3d6a 100644
--- a/src/routes/login/callback/updateUser.spec.ts
+++ b/src/routes/login/callback/updateUser.spec.ts
@@ -6,7 +6,7 @@ import { ObjectId } from "mongodb";
 import { DEFAULT_SETTINGS } from "$lib/types/Settings";
 import { defaultModel } from "$lib/server/models";
 import { findUser } from "$lib/server/auth";
-import { defaultEmbeddingModel } from "$lib/server/embeddingModels";
+import { getDefaultEmbeddingModel } from "$lib/server/embeddingModels";
 
 const userData = {
 	preferred_username: "new-username",
@@ -41,13 +41,15 @@ const insertRandomUser = async () => {
 };
 
 const insertRandomConversations = async (count: number) => {
+	const defaultEmbeddingModel = await getDefaultEmbeddingModel();
+
 	const res = await collections.conversations.insertMany(
 		new Array(count).fill(0).map(() => ({
 			_id: new ObjectId(),
 			title: "random title",
 			messages: [],
 			model: defaultModel.id,
-			embeddingModel: defaultEmbeddingModel.id,
+			embeddingModel: defaultEmbeddingModel.name,
 			createdAt: new Date(),
 			updatedAt: new Date(),
 			sessionId: locals.sessionId,