Skip to content

Commit

Permalink
Add Async Factory implementation for SqlAlchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
nadege committed Jul 23, 2023
1 parent a4d3a9c commit 03dc0f2
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 4 deletions.
16 changes: 16 additions & 0 deletions factory/alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,19 @@ def _save(cls, model_class, session, args, kwargs):
elif session_persistence == SESSION_PERSISTENCE_COMMIT:
session.commit()
return obj


class SQLAlchemyModelAsyncFactory(SQLAlchemyModelFactory, base.AsyncFactory):
"""Async Factory for SQLAlchemy models. """

class Meta:
abstract = True

@classmethod
async def _create_model_async(cls, model_class, *args, **kwargs):
session = cls._meta.sqlalchemy_session
async with session.begin():
model = model_class(**kwargs)
session.add(model)

return model
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ dev =
flake8
isort
Pillow
SQLAlchemy
SQLAlchemy[asyncio]
sqlalchemy_utils
mongoengine
databases[sqlite]
wheel>=0.32.0
tox
zest.releaser[recommended]
Expand Down
21 changes: 20 additions & 1 deletion tests/alchemyapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""Helpers for testing SQLAlchemy apps."""
import os

from sqlalchemy import Column, Integer, Unicode, create_engine
from sqlalchemy import Boolean, Column, Integer, Unicode, create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker

try:
Expand All @@ -28,14 +29,24 @@
pg_host = os.environ.get('POSTGRES_HOST', 'localhost')
pg_port = os.environ.get('POSTGRES_PORT', '5432')
engine_name = f'postgresql+psycopg2://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_database}'
async_engine_name = None
else:
engine_name = 'sqlite://'
async_engine_name = 'sqlite+aiosqlite://'

session = scoped_session(sessionmaker())
engine = create_engine(engine_name)
session.configure(bind=engine)
Base = declarative_base()

if not async_engine_name:
async_engine = None
async_session = None
else:
async_engine = create_async_engine(async_engine_name)
async_sessionmaker = sessionmaker(async_engine, expire_on_commit=False, class_=AsyncSession)
async_session = async_sessionmaker()


class StandardModel(Base):
__tablename__ = 'StandardModelTable'
Expand Down Expand Up @@ -72,3 +83,11 @@ class SpecialFieldModel(Base):

id = Column(Integer(), primary_key=True)
session = Column(Unicode(20))


class NoteModel(Base):
__tablename__ = "NoteTable"

id = Column(Integer(), primary_key=True)
text = Column(Unicode(20))
completed = Column(Boolean(), default=False)
61 changes: 60 additions & 1 deletion tests/test_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

"""Tests for factory_boy/SQLAlchemy interactions."""

import asyncio
import unittest
from unittest import mock

Expand All @@ -13,7 +14,7 @@
from sqlalchemy_utils import create_database, database_exists, drop_database

import factory
from factory.alchemy import SQLAlchemyModelFactory
from factory.alchemy import SQLAlchemyModelAsyncFactory, SQLAlchemyModelFactory

from .alchemyapp import models

Expand Down Expand Up @@ -337,3 +338,61 @@ class Meta:

get_or_created_child = SpecialFieldWithGetOrCreateFactory()
self.assertEqual(get_or_created_child.session, "")


class NoteFactory(SQLAlchemyModelAsyncFactory):
class Meta:
model = models.NoteModel
sqlalchemy_session = models.async_session

text = factory.Sequence(lambda n: f"Text {n}")
completed = factory.Faker('boolean')


if models.async_engine:

class SQLAlchemyAsyncTestCase(unittest.TestCase):

def setUp(self):
super().setUp()

async def asyncSetUp():
async with models.async_engine.begin() as conn:
await conn.run_sync(models.Base.metadata.drop_all)
await conn.run_sync(models.Base.metadata.create_all)

NoteFactory.reset_sequence(0)

asyncio.run(asyncSetUp())

def tearDown(self):

async def asyncTearDown():

async with models.async_engine.begin() as conn:
await conn.run_sync(models.Base.metadata.drop_all)

asyncio.run(asyncTearDown())

def test_build(self):
note = NoteFactory.build()
self.assertEqual('Text 0', note.text)
self.assertIn(note.completed, [True, False])
self.assertIsNone(note.id)

def test_creation(self):

async def test():

note = await NoteFactory.create_async()
self.assertEqual('Text 0', note.text)
self.assertIn(note.completed, [True, False])
self.assertIsNotNone(note.id)

statement = sqlalchemy.select(
sqlalchemy.func.count(models.NoteModel.id)
).where(models.NoteModel.text == "Text 0")
count = await models.async_session.scalar(statement)
assert count == 1

asyncio.run(test())
3 changes: 2 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ passenv =
POSTGRES_HOST
POSTGRES_DATABASE
deps =
alchemy: SQLAlchemy
alchemy: SQLAlchemy[asyncio]
alchemy: sqlalchemy_utils
alchemy-sqlite: databases[sqlite]
mongo: mongoengine
django{32,40,41,main}: Pillow
django32: Django>=3.2,<3.3
Expand Down

0 comments on commit 03dc0f2

Please sign in to comment.