Skip to content

Commit

Permalink
feat(embeddingModel): add embedding model into mongodb
Browse files Browse the repository at this point in the history
  • Loading branch information
neven4 committed Jul 25, 2024
1 parent 7692f71 commit f09dcd0
Show file tree
Hide file tree
Showing 18 changed files with 194 additions and 86 deletions.
7 changes: 4 additions & 3 deletions scripts/populate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import type { User } from "../src/lib/types/User";
import type { Assistant } from "../src/lib/types/Assistant";
import type { Conversation } from "../src/lib/types/Conversation";
import type { Settings } from "../src/lib/types/Settings";
import { defaultEmbeddingModel } from "../src/lib/server/embeddingModels.ts";
import { getDefaultEmbeddingModel } from "../src/lib/server/embeddingModels.ts";
import { Message } from "../src/lib/types/Message.ts";

import { addChildren } from "../src/lib/utils/tree/addChildren.ts";
Expand Down Expand Up @@ -146,6 +146,7 @@ async function seed() {
updatedAt: faker.date.recent({ days: 30 }),
customPrompts: {},
assistants: [],
disableStream: false,
};
await collections.settings.updateOne(
{ userId: user._id },
Expand Down Expand Up @@ -214,7 +215,7 @@ async function seed() {
: faker.helpers.maybe(() => faker.hacker.phrase(), { probability: 0.5 })) ?? "";

const messages = await generateMessages(preprompt);

const defaultEmbeddingModel = await getDefaultEmbeddingModel();
const conv = {
_id: new ObjectId(),
userId: user._id,
Expand All @@ -224,7 +225,7 @@ async function seed() {
updatedAt: faker.date.recent({ days: 145 }),
model: faker.helpers.arrayElement(modelIds),
title: faker.internet.emoji() + " " + faker.hacker.phrase(),
embeddingModel: defaultEmbeddingModel.id,
embeddingModel: defaultEmbeddingModel.name,
messages,
rootMessageId: messages[0].id,
} satisfies Conversation;
Expand Down
4 changes: 4 additions & 0 deletions src/hooks.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { initExitHandler } from "$lib/server/exitHandler";
import { ObjectId } from "mongodb";
import { refreshAssistantsCounts } from "$lib/jobs/refresh-assistants-counts";
import { refreshConversationStats } from "$lib/jobs/refresh-conversation-stats";
import { pupulateEmbeddingModel } from "$lib/server/embeddingModels";

// TODO: move this code on a started server hook, instead of using a "building" flag
if (!building) {
Expand All @@ -25,6 +26,9 @@ if (!building) {
if (env.ENABLE_ASSISTANTS) {
refreshAssistantsCounts();
}

await pupulateEmbeddingModel();

refreshConversationStats();

// Init metrics server
Expand Down
6 changes: 6 additions & 0 deletions src/lib/server/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import type { AssistantStats } from "$lib/types/AssistantStats";
import { logger } from "$lib/server/logger";
import { building } from "$app/environment";
import { onExit } from "./exitHandler";
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";

export const CONVERSATION_STATS_COLLECTION = "conversations.stats";

Expand Down Expand Up @@ -83,6 +84,7 @@ export class Database {
const bucket = new GridFSBucket(db, { bucketName: "files" });
const migrationResults = db.collection<MigrationResult>("migrationResults");
const semaphores = db.collection<Semaphore>("semaphores");
const embeddingModels = db.collection<EmbeddingModel>("embeddingModels");

return {
conversations,
Expand All @@ -99,6 +101,7 @@ export class Database {
bucket,
migrationResults,
semaphores,
embeddingModels,
};
}

Expand All @@ -120,6 +123,7 @@ export class Database {
sessions,
messageEvents,
semaphores,
embeddingModels,
} = this.getCollections();

conversations
Expand Down Expand Up @@ -209,6 +213,8 @@ export class Database {
semaphores
.createIndex({ createdAt: 1 }, { expireAfterSeconds: 60 })
.catch((e) => logger.error(e));

embeddingModels.createIndex({ name: 1 }, { unique: true }).catch((e) => logger.error(e));
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
embeddingEndpointOpenAIParametersSchema,
} from "./openai/embeddingEndpoints";
import { embeddingEndpointHfApi, embeddingEndpointHfApiSchema } from "./hfApi/embeddingHfApi";
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";

// parameters passed when generating text
interface EmbeddingEndpointParameters {
Expand All @@ -33,8 +34,8 @@ export const embeddingEndpointSchema = z.discriminatedUnion("type", [
type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];

// generator function that takes in a type discriminator value for defining the endpoint and returns the endpoint
export type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }>
type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }> & { model: EmbeddingModel }
) => EmbeddingEndpoint | Promise<EmbeddingEndpoint>;

// list of all endpoint generators
Expand Down
13 changes: 9 additions & 4 deletions src/lib/server/embeddingEndpoints/hfApi/embeddingHfApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,27 @@ import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
import { chunk } from "$lib/utils/chunk";
import { env } from "$env/dynamic/private";
import { logger } from "$lib/server/logger";
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";

export const embeddingEndpointHfApiSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("hfapi"),
authorization: z
.string()
.optional()
.transform((v) => (!v && env.HF_TOKEN ? "Bearer " + env.HF_TOKEN : v)), // if the header is not set but HF_TOKEN is, use it as the authorization header
});

type EmbeddingEndpointHfApiInput = z.input<typeof embeddingEndpointHfApiSchema> & {
model: EmbeddingModel;
};

export async function embeddingEndpointHfApi(
input: z.input<typeof embeddingEndpointHfApiSchema>
input: EmbeddingEndpointHfApiInput
): Promise<EmbeddingEndpoint> {
const { model, authorization } = embeddingEndpointHfApiSchema.parse(input);
const url = "https://api-inference.huggingface.co/models/" + model.id;
const { model } = input;
const { authorization } = embeddingEndpointHfApiSchema.parse(input);
const url = "https://api-inference.huggingface.co/models/" + model.name;

return async ({ inputs }) => {
const batchesInputs = chunk(inputs, 128);
Expand Down
12 changes: 8 additions & 4 deletions src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@ import { z } from "zod";
import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
import { chunk } from "$lib/utils/chunk";
import { env } from "$env/dynamic/private";
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";

export const embeddingEndpointOpenAIParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("openai"),
url: z.string().url().default("https://api.openai.com/v1/embeddings"),
apiKey: z.string().default(env.OPENAI_API_KEY),
defaultHeaders: z.record(z.string()).default({}),
});

type EmbeddingEndpointOpenAIInput = z.input<typeof embeddingEndpointOpenAIParametersSchema> & {
model: EmbeddingModel;
};

export async function embeddingEndpointOpenAI(
input: z.input<typeof embeddingEndpointOpenAIParametersSchema>
input: EmbeddingEndpointOpenAIInput
): Promise<EmbeddingEndpoint> {
const { url, model, apiKey, defaultHeaders } =
embeddingEndpointOpenAIParametersSchema.parse(input);
const { model } = input;
const { url, apiKey, defaultHeaders } = embeddingEndpointOpenAIParametersSchema.parse(input);

const maxBatchSize = model.maxBatchSize || 100;

Expand Down
11 changes: 8 additions & 3 deletions src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
import { chunk } from "$lib/utils/chunk";
import { env } from "$env/dynamic/private";
import { logger } from "$lib/server/logger";
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";

export const embeddingEndpointTeiParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("tei"),
url: z.string().url(),
authorization: z
Expand Down Expand Up @@ -35,10 +35,15 @@ const getModelInfoByUrl = async (url: string, authorization?: string) => {
}
};

type EmbeddingEndpointTeiInput = z.input<typeof embeddingEndpointTeiParametersSchema> & {
model: EmbeddingModel;
};

export async function embeddingEndpointTei(
input: z.input<typeof embeddingEndpointTeiParametersSchema>
input: EmbeddingEndpointTeiInput
): Promise<EmbeddingEndpoint> {
const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);
const { model } = input;
const { url, authorization } = embeddingEndpointTeiParametersSchema.parse(input);

const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url);
const maxBatchSize = Math.min(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import { z } from "zod";
import type { EmbeddingEndpoint } from "../embeddingEndpoints";
import type { Tensor, FeatureExtractionPipeline } from "@xenova/transformers";
import { pipeline } from "@xenova/transformers";
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";

export const embeddingEndpointTransformersJSParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("transformersjs"),
});

Expand Down Expand Up @@ -36,10 +36,16 @@ export async function calculateEmbedding(modelName: string, inputs: string[]) {
return output.tolist();
}

type EmbeddingEndpointTransformersJSInput = z.input<
typeof embeddingEndpointTransformersJSParametersSchema
> & {
model: EmbeddingModel;
};

export function embeddingEndpointTransformersJS(
input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
input: EmbeddingEndpointTransformersJSInput
): EmbeddingEndpoint {
const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);
const { model } = input;

return async ({ inputs }) => {
return calculateEmbedding(model.name, inputs);
Expand Down
118 changes: 65 additions & 53 deletions src/lib/server/embeddingModels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import { sum } from "$lib/utils/sum";
import {
embeddingEndpoints,
embeddingEndpointSchema,
type EmbeddingEndpoint,
} from "$lib/server/embeddingEndpoints/embeddingEndpoints";
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";

import JSON5 from "json5";
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
import { collections } from "./database";
import { ObjectId } from "mongodb";

const modelConfig = z.object({
/** Used as an identifier in DB */
Expand Down Expand Up @@ -42,67 +44,77 @@ const rawEmbeddingModelJSON =

const embeddingModelsRaw = z.array(modelConfig).parse(JSON5.parse(rawEmbeddingModelJSON));

const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
...m,
id: m.id || m.name,
const embeddingModels = embeddingModelsRaw.map((rawEmbeddingModel) => {
const embeddingModel: EmbeddingModel = {
name: rawEmbeddingModel.name,
description: rawEmbeddingModel.description,
websiteUrl: rawEmbeddingModel.websiteUrl,
modelUrl: rawEmbeddingModel.modelUrl,
chunkCharLength: rawEmbeddingModel.chunkCharLength,
maxBatchSize: rawEmbeddingModel.maxBatchSize,
preQuery: rawEmbeddingModel.preQuery,
prePassage: rawEmbeddingModel.prePassage,
_id: new ObjectId(),
createdAt: new Date(),
updatedAt: new Date(),
endpoints: rawEmbeddingModel.endpoints,
};

return embeddingModel;
});

const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
...m,
getEndpoint: async (): Promise<EmbeddingEndpoint> => {
if (!m.endpoints) {
return embeddingEndpointTransformersJS({
type: "transformersjs",
weight: 1,
model: m,
});
}
export const getEmbeddingEndpoint = async (embeddingModel: EmbeddingModel) => {
if (!embeddingModel.endpoints) {
return embeddingEndpointTransformersJS({
type: "transformersjs",
weight: 1,
model: embeddingModel,
});
}

const totalWeight = sum(m.endpoints.map((e) => e.weight));

let random = Math.random() * totalWeight;

for (const endpoint of m.endpoints) {
if (random < endpoint.weight) {
const args = { ...endpoint, model: m };

switch (args.type) {
case "tei":
return embeddingEndpoints.tei(args);
case "transformersjs":
return embeddingEndpoints.transformersjs(args);
case "openai":
return embeddingEndpoints.openai(args);
case "hfapi":
return embeddingEndpoints.hfapi(args);
default:
throw new Error(`Unknown endpoint type: ${args}`);
}
const totalWeight = sum(embeddingModel.endpoints.map((e) => e.weight));

let random = Math.random() * totalWeight;

for (const endpoint of embeddingModel.endpoints) {
if (random < endpoint.weight) {
const args = { ...endpoint, model: embeddingModel };
console.log(args.type);

switch (args.type) {
case "tei":
return embeddingEndpoints.tei(args);
case "transformersjs":
return embeddingEndpoints.transformersjs(args);
case "openai":
return embeddingEndpoints.openai(args);
case "hfapi":
return embeddingEndpoints.hfapi(args);
default:
throw new Error(`Unknown endpoint type: ${args}`);
}

random -= endpoint.weight;
}

throw new Error(`Failed to select embedding endpoint`);
},
});

export const embeddingModels = await Promise.all(
embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
);

export const defaultEmbeddingModel = embeddingModels[0];
random -= endpoint.weight;
}

const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
throw new Error(`Failed to select embedding endpoint`);
};

export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
return validateEmbeddingModel(_models, "id");
};
export const getDefaultEmbeddingModel = async (): Promise<EmbeddingModel> => {
if (!embeddingModels[0]) {
throw new Error(`Failed to find default embedding endpoint`);
}

const defaultModel = await collections.embeddingModels.findOne({
_id: embeddingModels[0]._id,
});

export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
return validateEmbeddingModel(_models, "name");
return defaultModel ? defaultModel : embeddingModels[0];
};

export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
// to mimic the current behavior of creating embedding models from scratch during server start
export async function pupulateEmbeddingModel() {
await collections.embeddingModels.deleteMany({});
await collections.embeddingModels.insertMany(embeddingModels);
}
Loading

0 comments on commit f09dcd0

Please sign in to comment.