Skip to content

Commit

Permalink
integrating feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
shaper committed Dec 25, 2024
1 parent cdca1e0 commit 0b57fe1
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 208 deletions.
56 changes: 26 additions & 30 deletions examples/ai-core/src/e2e/deepinfra.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,34 @@ import { createFeatureTestSuite } from './feature-test-suite';

createFeatureTestSuite({
name: 'DeepInfra',
createChatModelFn: provider.chatModel,
createCompletionModelFn: provider.completionModel,
createEmbeddingModelFn: provider.textEmbeddingModel,
models: {
chat: [
'google/codegemma-7b-it', // no tools, objects, or images
'google/gemma-2-9b-it', // no tools, objects, or images
'meta-llama/Llama-3.2-11B-Vision-Instruct', // no tools, *does* support images
'meta-llama/Llama-3.2-90B-Vision-Instruct', // no tools, *does* support images
'meta-llama/Llama-3.3-70B-Instruct-Turbo', // no image input
'meta-llama/Llama-3.3-70B-Instruct', // no image input
'meta-llama/Meta-Llama-3.1-405B-Instruct', // no image input
'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', // no image input
'meta-llama/Meta-Llama-3.1-70B-Instruct', // no image input
'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', // no *streaming* tools, no image input
'meta-llama/Meta-Llama-3.1-8B-Instruct', // no image input
'microsoft/WizardLM-2-8x22B', // no objects, tools, or images
'mistralai/Mixtral-8x7B-Instruct-v0.1', // no *streaming* tools, no image input
'nvidia/Llama-3.1-Nemotron-70B-Instruct', // no images
'Qwen/Qwen2-7B-Instruct', // no tools, no image input
'Qwen/Qwen2.5-72B-Instruct', // no images
'Qwen/Qwen2.5-Coder-32B-Instruct', // no tool calls, no image input
'Qwen/QwQ-32B-Preview', // no tools, no image input
invalidModel: provider.chatModel('no-such-model'),
languageModels: [
provider.chatModel('google/codegemma-7b-it'), // no tools, objects, or images
provider.chatModel('google/gemma-2-9b-it'), // no tools, objects, or images
provider.chatModel('meta-llama/Llama-3.2-11B-Vision-Instruct'), // no tools, *does* support images
provider.chatModel('meta-llama/Llama-3.2-90B-Vision-Instruct'), // no tools, *does* support images
provider.chatModel('meta-llama/Llama-3.3-70B-Instruct-Turbo'), // no image input
provider.chatModel('meta-llama/Llama-3.3-70B-Instruct'), // no image input
provider.chatModel('meta-llama/Meta-Llama-3.1-405B-Instruct'), // no image input
provider.chatModel('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo'), // no image input
provider.chatModel('meta-llama/Meta-Llama-3.1-70B-Instruct'), // no image input
provider.chatModel('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'), // no *streaming* tools, no image input
provider.chatModel('meta-llama/Meta-Llama-3.1-8B-Instruct'), // no image input
provider.chatModel('microsoft/WizardLM-2-8x22B'), // no objects, tools, or images
provider.chatModel('mistralai/Mixtral-8x7B-Instruct-v0.1'), // no *streaming* tools, no image input
provider.chatModel('nvidia/Llama-3.1-Nemotron-70B-Instruct'), // no images
provider.chatModel('Qwen/Qwen2-7B-Instruct'), // no tools, no image input
provider.chatModel('Qwen/Qwen2.5-72B-Instruct'), // no images
provider.chatModel('Qwen/Qwen2.5-Coder-32B-Instruct'), // no tool calls, no image input
provider.chatModel('Qwen/QwQ-32B-Preview'), // no tools, no image input
provider.completionModel('meta-llama/Meta-Llama-3.1-8B-Instruct'),
provider.completionModel('Qwen/Qwen2-7B-Instruct'),
],
completion: [
'meta-llama/Meta-Llama-3.1-8B-Instruct',
'Qwen/Qwen2-7B-Instruct',
],
embedding: [
'BAAI/bge-base-en-v1.5',
'intfloat/e5-base-v2',
'sentence-transformers/all-mpnet-base-v2',
embeddingModels: [
provider.textEmbeddingModel('BAAI/bge-base-en-v1.5'),
provider.textEmbeddingModel('intfloat/e5-base-v2'),
provider.textEmbeddingModel('sentence-transformers/all-mpnet-base-v2'),
],
},
timeout: 10000,
Expand Down
191 changes: 57 additions & 134 deletions examples/ai-core/src/e2e/feature-test-suite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,17 @@ import {
APICallError,
} from 'ai';
import fs from 'fs';
import { describe, it, expect, vi } from 'vitest';
import type { LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider';
import { describe, expect, it, vi } from 'vitest';
import type { EmbeddingModelV1, LanguageModelV1 } from '@ai-sdk/provider';

export interface ModelVariants {
chat?: string[];
completion?: string[];
embedding?: string[];
invalidModel?: LanguageModelV1;
languageModels?: LanguageModelV1[];
embeddingModels?: EmbeddingModelV1<string>[];
}

export interface TestSuiteOptions {
name: string;
createChatModelFn: (modelId: string) => LanguageModelV1;
createCompletionModelFn: (modelId: string) => LanguageModelV1;
createEmbeddingModelFn: (modelId: string) => EmbeddingModelV1<string>;
models: ModelVariants;
timeout?: number;
customAssertions?: {
Expand All @@ -31,11 +28,16 @@ export interface TestSuiteOptions {
};
}

const createModelObjects = <T extends { modelId: string }>(
models: T[] | undefined,
) =>
models?.map(model => ({
modelId: model.modelId,
model,
})) || [];

export function createFeatureTestSuite({
name,
createChatModelFn,
createCompletionModelFn,
createEmbeddingModelFn,
models,
timeout = 10000,
customAssertions = { skipUsage: false },
Expand All @@ -46,14 +48,13 @@ export function createFeatureTestSuite({
((error: APICallError) => {
throw new Error('errorValidator not implemented');
});

describe(`${name} Feature Test Suite`, () => {
vi.setConfig({ testTimeout: timeout });

// Chat Model Tests
if (models.chat?.length) {
describe.each(models.chat)('Chat Model: %s', modelId => {
const model = createChatModelFn(modelId);

describe.each(createModelObjects(models.languageModels))(
'Language Model: $modelId',
({ model }) => {
it('should generate text', async () => {
const result = await generateText({
model,
Expand Down Expand Up @@ -292,139 +293,26 @@ export function createFeatureTestSuite({
expect((await result.usage)?.totalTokens).toBeGreaterThan(0);
}
});
});
}

describe('Chat Model Error Handling:', () => {
it('should throw error on generate text attempt with invalid model ID', async () => {
const invalidModel = createChatModelFn('no-such-model');

try {
await generateText({
model: invalidModel,
prompt: 'This should fail',
});
// If we reach here, the test should fail
expect(true).toBe(false); // Force test to fail if no error is thrown
} catch (error) {
expect(error).toBeInstanceOf(APICallError);
errorValidator(error as APICallError);
}
});

it('should throw error on stream text attempt with invalid model ID', async () => {
const invalidModel = createChatModelFn('no-such-model');

try {
const result = streamText({
model: invalidModel,
prompt: 'This should fail',
});

// Try to consume the stream to trigger the error
for await (const _ of result.textStream) {
// Do nothing with the chunks
}

// If we reach here, the test should fail
expect(true).toBe(false); // Force test to fail if no error is thrown
} catch (error) {
expect(error).toBeInstanceOf(APICallError);
errorValidator(error as APICallError);
}
});
});

// Embedding Model Tests
if (models.embedding?.length) {
describe.each(models.embedding)('Embedding Model: %s', modelId => {
const model = createEmbeddingModelFn(modelId);

it('should generate single embedding', async () => {
const result = await embed({
model,
value: 'This is a test sentence for embedding.',
});

expect(Array.isArray(result.embedding)).toBe(true);
expect(result.embedding.length).toBeGreaterThan(0);
if (!customAssertions.skipUsage) {
expect(result.usage?.tokens).toBeGreaterThan(0);
}
});
},
);

it('should generate multiple embeddings', async () => {
const result = await embedMany({
model,
values: [
'First test sentence.',
'Second test sentence.',
'Third test sentence.',
],
});
if (models.invalidModel) {
describe('Chat Model Error Handling:', () => {
const invalidModel = models.invalidModel!;

expect(Array.isArray(result.embeddings)).toBe(true);
expect(result.embeddings.length).toBe(3);
if (!customAssertions.skipUsage) {
expect(result.usage?.tokens).toBeGreaterThan(0);
}
});
});
}

if (models.completion?.length) {
describe.each(models.completion)('Completion Model: %s', modelId => {
const model = createCompletionModelFn(modelId);

it('should generate text', async () => {
const result = await generateText({
model,
prompt: 'Complete this code: function fibonacci(n) {',
});

expect(result.text).toBeTruthy();
expect(result.usage?.totalTokens).toBeGreaterThan(0);
});

it('should stream text', async () => {
const result = streamText({
model,
prompt: 'Write a Python function that sorts a list:',
});

const chunks: string[] = [];
for await (const chunk of result.textStream) {
chunks.push(chunk);
}

expect(chunks.length).toBeGreaterThan(0);
if (!customAssertions.skipUsage) {
expect((await result.usage)?.totalTokens).toBeGreaterThan(0);
}
});
});

// New separate error handling describe block
describe('Completion Model Error Handling:', () => {
it('should throw error on generate text attempt with invalid model ID', async () => {
const invalidModel = createCompletionModelFn('no-such-model');

try {
await generateText({
model: invalidModel,
prompt: 'This should fail',
});
// If we reach here, the test should fail
expect(true).toBe(false); // Force test to fail if no error is thrown
} catch (error) {
expect(error).toBeInstanceOf(APICallError);
errorValidator(error as APICallError);
}
});

it('should throw error on stream text attempt with invalid model ID', async () => {
const invalidModel = createCompletionModelFn('no-such-model');

try {
const result = streamText({
model: invalidModel,
Expand All @@ -444,6 +332,41 @@ export function createFeatureTestSuite({
}
});
});

describe.each(createModelObjects(models.embeddingModels))(
'Embedding Model: $modelId',
({ model }) => {
it('should generate single embedding', async () => {
const result = await embed({
model,
value: 'This is a test sentence for embedding.',
});

expect(Array.isArray(result.embedding)).toBe(true);
expect(result.embedding.length).toBeGreaterThan(0);
if (!customAssertions.skipUsage) {
expect(result.usage?.tokens).toBeGreaterThan(0);
}
});

it('should generate multiple embeddings', async () => {
const result = await embedMany({
model,
values: [
'First test sentence.',
'Second test sentence.',
'Third test sentence.',
],
});

expect(Array.isArray(result.embeddings)).toBe(true);
expect(result.embeddings.length).toBe(3);
if (!customAssertions.skipUsage) {
expect(result.usage?.tokens).toBeGreaterThan(0);
}
});
},
);
}
});
};
Expand Down
24 changes: 12 additions & 12 deletions examples/ai-core/src/e2e/fireworks.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@ import { createFeatureTestSuite } from './feature-test-suite';

createFeatureTestSuite({
name: 'Fireworks',
createChatModelFn: provider.chatModel,
createCompletionModelFn: provider.completionModel,
createEmbeddingModelFn: provider.textEmbeddingModel,
models: {
chat: [
'accounts/fireworks/models/firefunction-v2',
'accounts/fireworks/models/llama-v3p3-70b-instruct',
'accounts/fireworks/models/mixtral-8x7b-instruct',
'accounts/fireworks/models/qwen2p5-72b-instruct',
invalidModel: provider.chatModel('no-such-model'),
languageModels: [
provider.chatModel('accounts/fireworks/models/firefunction-v2'),
provider.chatModel('accounts/fireworks/models/llama-v3p3-70b-instruct'),
provider.chatModel('accounts/fireworks/models/mixtral-8x7b-instruct'),
provider.chatModel('accounts/fireworks/models/qwen2p5-72b-instruct'),
provider.completionModel(
'accounts/fireworks/models/llama-v3-8b-instruct',
),
provider.completionModel('accounts/fireworks/models/llama-v2-34b-code'),
],
completion: [
'accounts/fireworks/models/llama-v3-8b-instruct',
'accounts/fireworks/models/llama-v2-34b-code',
embeddingModels: [
provider.textEmbeddingModel('nomic-ai/nomic-embed-text-v1.5'),
],
embedding: ['nomic-ai/nomic-embed-text-v1.5'],
},
timeout: 10000,
customAssertions: {
Expand Down
28 changes: 13 additions & 15 deletions examples/ai-core/src/e2e/togetherai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,21 @@ import { createFeatureTestSuite } from './feature-test-suite';

createFeatureTestSuite({
name: 'TogetherAI',
createChatModelFn: provider.chatModel,
createCompletionModelFn: provider.completionModel,
createEmbeddingModelFn: provider.textEmbeddingModel,
models: {
chat: [
'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo',
'mistralai/Mistral-7B-Instruct-v0.1',
'google/gemma-2b-it',
'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo',
'mistralai/Mixtral-8x7B-Instruct-v0.1',
'Qwen/Qwen2.5-72B-Instruct-Turbo',
'databricks/dbrx-instruct',
invalidModel: provider.chatModel('no-such-model'),
languageModels: [
provider.chatModel('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'),
provider.chatModel('mistralai/Mistral-7B-Instruct-v0.1'),
provider.chatModel('google/gemma-2b-it'),
provider.chatModel('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo'),
provider.chatModel('mistralai/Mixtral-8x7B-Instruct-v0.1'),
provider.chatModel('Qwen/Qwen2.5-72B-Instruct-Turbo'),
provider.chatModel('databricks/dbrx-instruct'),
provider.completionModel('Qwen/Qwen2.5-Coder-32B-Instruct'),
],
completion: ['Qwen/Qwen2.5-Coder-32B-Instruct'],
embedding: [
'togethercomputer/m2-bert-80M-8k-retrieval',
'BAAI/bge-base-en-v1.5',
embeddingModels: [
provider.textEmbeddingModel('togethercomputer/m2-bert-80M-8k-retrieval'),
provider.textEmbeddingModel('BAAI/bge-base-en-v1.5'),
],
},
timeout: 10000,
Expand Down
Loading

0 comments on commit 0b57fe1

Please sign in to comment.