Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added watch() to Drizzle and Kysely integrations #414

Merged
merged 7 commits into from
Dec 2, 2024
Merged
5 changes: 5 additions & 0 deletions .changeset/calm-baboons-worry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@powersync/drizzle-driver': minor
---

Added `watch()` function to support watched queries. This function invokes `execute()` on the Drizzle query which improves support for complex queries such as those which are relational.
63 changes: 60 additions & 3 deletions packages/drizzle-driver/src/sqlite/db.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import { AbstractPowerSyncDatabase, QueryResult } from '@powersync/common';
import {
AbstractPowerSyncDatabase,
QueryResult,
runOnSchemaChange,
SQLWatchOptions,
WatchHandler
} from '@powersync/common';
import { Query } from 'drizzle-orm';
import { DefaultLogger } from 'drizzle-orm/logger';
import {
createTableRelationsHelpers,
Expand All @@ -13,6 +20,8 @@ import { SQLiteAsyncDialect } from 'drizzle-orm/sqlite-core/dialect';
import type { DrizzleConfig } from 'drizzle-orm/utils';
import { PowerSyncSQLiteSession, PowerSyncSQLiteTransactionConfig } from './sqlite-session';

type WatchQuery = { toSQL(): Query; execute(): Promise<any> };

export interface PowerSyncSQLiteDatabase<TSchema extends Record<string, unknown> = Record<string, never>>
extends BaseSQLiteDatabase<'async', QueryResult, TSchema> {
transaction<T>(
Expand All @@ -21,13 +30,15 @@ export interface PowerSyncSQLiteDatabase<TSchema extends Record<string, unknown>
) => Promise<T>,
config?: PowerSyncSQLiteTransactionConfig
): Promise<T>;

watch(query: WatchQuery, handler?: WatchHandler, options?: SQLWatchOptions): void;
}

export function wrapPowerSyncWithDrizzle<TSchema extends Record<string, unknown> = Record<string, never>>(
db: AbstractPowerSyncDatabase,
config: DrizzleConfig<TSchema> = {}
): PowerSyncSQLiteDatabase<TSchema> {
const dialect = new SQLiteAsyncDialect({casing: config.casing});
const dialect = new SQLiteAsyncDialect({ casing: config.casing });
let logger;
if (config.logger === true) {
logger = new DefaultLogger();
Expand All @@ -48,5 +59,51 @@ export function wrapPowerSyncWithDrizzle<TSchema extends Record<string, unknown>
const session = new PowerSyncSQLiteSession(db, dialect, schema, {
logger
});
return new BaseSQLiteDatabase('async', dialect, session, schema) as PowerSyncSQLiteDatabase<TSchema>;

const watch = (query: WatchQuery, handler?: WatchHandler, options?: SQLWatchOptions): void => {
const { onResult, onError = (e: Error) => {} } = handler ?? {};
if (!onResult) {
throw new Error('onResult is required');
}

const watchQuery = async (abortSignal: AbortSignal) => {
try {
const toSql = query.toSQL();
const resolvedTables = await db.resolveTables(toSql.sql, toSql.params, options);

// Fetch initial data
const result = await query.execute();
onResult(result);

db.onChangeWithCallback(
{
onChange: async () => {
try {
const result = await query.execute();
onResult(result);
} catch (error: any) {
onError(error);
}
},
onError
},
{
...(options ?? {}),
tables: resolvedTables,
// Override the abort signal since we intercept it
signal: abortSignal
}
);
} catch (error: any) {
onError(error);
}
};

runOnSchemaChange(watchQuery, db, options);
};

const baseDatabase = new BaseSQLiteDatabase('async', dialect, session, schema) as PowerSyncSQLiteDatabase<TSchema>;
return Object.assign(baseDatabase, {
watch: (query: WatchQuery, handler?: WatchHandler, options?: SQLWatchOptions) => watch(query, handler, options)
});
}
283 changes: 283 additions & 0 deletions packages/drizzle-driver/tests/sqlite/watch.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
import { AbstractPowerSyncDatabase, column, Schema, Table } from '@powersync/common';
import { PowerSyncDatabase } from '@powersync/web';
import { count, eq, sql } from 'drizzle-orm';
import { integer, sqliteTable, text, uniqueIndex } from 'drizzle-orm/sqlite-core';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import * as SUT from '../../src/sqlite/db';

vi.useRealTimers();

const assetsPs = new Table(
{
created_at: column.text,
make: column.text,
model: column.text,
serial_number: column.text,
quantity: column.integer,
user_id: column.text,
customer_id: column.text,
description: column.text
},
{ indexes: { makemodel: ['make, model'] } }
);

const customersPs = new Table({
name: column.text,
email: column.text
});

const PsSchema = new Schema({ assets: assetsPs, customers: customersPs });

const assets = sqliteTable(
'assets',
{
id: text('id'),
created_at: text('created_at'),
make: text('make'),
model: text('model'),
serial_number: text('serial_number'),
quantity: integer('quantity'),
user_id: text('user_id'),
customer_id: text('customer_id'),
description: text('description')
},
(table) => ({
makemodelIndex: uniqueIndex('makemodel').on(table.make, table.model)
})
);

const customers = sqliteTable('customers', {
id: text('id'),
name: text('name'),
email: text('email')
});

const DrizzleSchema = { assets, customers };

/**
* There seems to be an issue with Vitest browser mode's setTimeout and
* fake timer functionality.
* e.g. calling:
* await new Promise<void>((resolve) => setTimeout(resolve, 10));
* waits for 1 second instead of 10ms.
* Setting this to 1 second as a work around.
*/
const throttleDuration = 1000;

describe('Watch Tests', () => {
let powerSyncDb: AbstractPowerSyncDatabase;
let db: SUT.PowerSyncSQLiteDatabase<typeof DrizzleSchema>;

beforeEach(async () => {
powerSyncDb = new PowerSyncDatabase({
database: {
dbFilename: 'test.db'
},
schema: PsSchema
});
db = SUT.wrapPowerSyncWithDrizzle(powerSyncDb, { schema: DrizzleSchema, logger: { logQuery: () => {} } });

await powerSyncDb.init();
});

afterEach(async () => {
await powerSyncDb.disconnectAndClear();
});

it('watch outside throttle limits', async () => {
const abortController = new AbortController();

const updatesCount = 2;
let receivedUpdatesCount = 0;

/**
* Promise which resolves once we received the same amount of update
* notifications as there are inserts.
*/
const receivedUpdates = new Promise<void>((resolve) => {
const onUpdate = () => {
receivedUpdatesCount++;

if (receivedUpdatesCount == updatesCount) {
abortController.abort();
resolve();
}
};

const query = db
.select({ count: count() })
.from(assets)
.innerJoin(customers, eq(customers.id, assets.customer_id));

db.watch(query, { onResult: onUpdate }, { signal: abortController.signal, throttleMs: throttleDuration });
});

for (let updateCount = 0; updateCount < updatesCount; updateCount++) {
await db
.insert(assets)
.values({
id: sql`uuid()`,
make: 'test',
customer_id: sql`uuid()`
})
.execute();

// Wait the throttle duration, ensuring a watch update for each insert
await new Promise<void>((resolve) => setTimeout(resolve, throttleDuration));
}

await receivedUpdates;
expect(receivedUpdatesCount).equals(updatesCount);
});

it('watch inside throttle limits', async () => {
const abortController = new AbortController();

const updatesCount = 5;
let receivedUpdatesCount = 0;

const onUpdate = () => {
receivedUpdatesCount++;
};
const query = db.select({ count: count() }).from(assets).innerJoin(customers, eq(customers.id, assets.customer_id));
db.watch(query, { onResult: onUpdate }, { signal: abortController.signal, throttleMs: throttleDuration });

// Create the inserts as fast as possible
for (let updateCount = 0; updateCount < updatesCount; updateCount++) {
await db
.insert(assets)
.values({
id: sql`uuid()`,
make: 'test',
customer_id: sql`uuid()`
})
.execute();
}

await new Promise<void>((resolve) => setTimeout(resolve, throttleDuration * 2));
abortController.abort();

// There should be one initial result plus one throttled result
expect(receivedUpdatesCount).equals(2);
});

it('should only watch tables inside query', async () => {
const assetsAbortController = new AbortController();

let receivedAssetsUpdatesCount = 0;
const onWatchAssets = () => {
receivedAssetsUpdatesCount++;
};

const queryAssets = db.select({ count: count() }).from(assets);

db.watch(
queryAssets,
{ onResult: onWatchAssets },
{
signal: assetsAbortController.signal
}
);

const customersAbortController = new AbortController();

let receivedCustomersUpdatesCount = 0;
const onWatchCustomers = () => {
receivedCustomersUpdatesCount++;
};

const queryCustomers = db.select({ count: count() }).from(customers);
db.watch(
queryCustomers,
{ onResult: onWatchCustomers },
{
signal: customersAbortController.signal
}
);

// Ensures insert doesn't form part of initial result
await new Promise<void>((resolve) => setTimeout(resolve, throttleDuration));

await db
.insert(assets)
.values({
id: sql`uuid()`,
make: 'test',
customer_id: sql`uuid()`
})
.execute();

await new Promise<void>((resolve) => setTimeout(resolve, throttleDuration * 2));
assetsAbortController.abort();
customersAbortController.abort();

// There should be one initial result plus one throttled result
expect(receivedAssetsUpdatesCount).equals(2);

// Only the initial result should have yielded.
expect(receivedCustomersUpdatesCount).equals(1);
});

it('should handle watch onError', async () => {
const abortController = new AbortController();
const onResult = () => {}; // no-op
let receivedErrorCount = 0;

const receivedError = new Promise<void>(async (resolve) => {
const onError = () => {
receivedErrorCount++;
resolve();
};

const query = db
.select({
id: sql`fakeFunction()` // Simulate an error with invalid function
})
.from(assets);

db.watch(query, { onResult, onError }, { signal: abortController.signal, throttleMs: throttleDuration });
});
abortController.abort();

await receivedError;
expect(receivedErrorCount).equals(1);
});

it('should throttle watch overflow', async () => {
const overflowAbortController = new AbortController();
const updatesCount = 25;

let receivedWithManagedOverflowCount = 0;
const firstResultReceived = new Promise<void>((resolve) => {
const onResultOverflow = () => {
if (receivedWithManagedOverflowCount === 0) {
resolve();
}
receivedWithManagedOverflowCount++;
};
const query = db.select({ count: count() }).from(assets);
db.watch(query, { onResult: onResultOverflow }, { signal: overflowAbortController.signal, throttleMs: 1 });
});

await firstResultReceived;

// Perform a large number of inserts to trigger overflow
for (let i = 0; i < updatesCount; i++) {
db.insert(assets)
.values({
id: sql`uuid()`,
make: 'test',
customer_id: sql`uuid()`
})
.execute();
}

await new Promise<void>((resolve) => setTimeout(resolve, 1 * throttleDuration));

overflowAbortController.abort();

// This fluctuates between 3 and 4 based on timing, but should never be 25
expect(receivedWithManagedOverflowCount).greaterThan(2);
expect(receivedWithManagedOverflowCount).toBeLessThanOrEqual(4);
});
});