diff --git a/src/routers/replenishment.py b/src/routers/replenishment.py index 5a10e72..7b899a1 100644 --- a/src/routers/replenishment.py +++ b/src/routers/replenishment.py @@ -16,7 +16,12 @@ Page, ) from models import User -from schemas import ReplenishmentCreate, ReplenishmentModel, UserReplenishment +from schemas import ( + ReplenishmentCreate, + ReplenishmentUpdate, + ReplenishmentModel, + UserReplenishment, +) router = APIRouter( prefix="/replenishments", @@ -39,7 +44,7 @@ def update_replenishment( *, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), - replenishment: ReplenishmentCreate, + replenishment: ReplenishmentUpdate, replenishment_id: int, ) -> ReplenishmentModel: return services.update_replenishment( diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py index a59b541..83f8ce1 100644 --- a/src/schemas/__init__.py +++ b/src/schemas/__init__.py @@ -34,6 +34,7 @@ from .expense import ExpenseCreate, ExpenseUpdate, ExpenseModel, UserExpense from .replenishment import ( ReplenishmentCreate, + ReplenishmentUpdate, UserBalance, ReplenishmentModel, UserReplenishment, diff --git a/src/schemas/replenishment.py b/src/schemas/replenishment.py index a2ca635..5702e8a 100644 --- a/src/schemas/replenishment.py +++ b/src/schemas/replenishment.py @@ -9,6 +9,10 @@ class ReplenishmentCreate(BaseModel): descriptions: str +class ReplenishmentUpdate(ReplenishmentCreate): + time: datetime.datetime + + class ReplenishmentModel(ReplenishmentCreate): id: int time: datetime.datetime diff --git a/src/services/replenishment.py b/src/services/replenishment.py index 39344bf..d9249e9 100644 --- a/src/services/replenishment.py +++ b/src/services/replenishment.py @@ -9,7 +9,12 @@ from starlette.exceptions import HTTPException from models import Replenishment -from schemas import ReplenishmentCreate, ReplenishmentModel, UserReplenishment +from schemas import ( + ReplenishmentCreate, + ReplenishmentUpdate, + ReplenishmentModel, + UserReplenishment, +) def create_replenishment( @@ -31,8 +36,8 @@ def create_replenishment( def update_replenishment( - db: Session, user_id: int, replenishment: ReplenishmentCreate, replenishment_id: int -): + db: Session, user_id: int, replenishment: ReplenishmentUpdate, replenishment_id: int +) -> ReplenishmentModel: try: db.query(Replenishment).filter_by(id=replenishment_id, user_id=user_id).one() except exc.NoResultFound: diff --git a/tests/test_endpoints/test_replenishment_e.py b/tests/test_endpoints/test_replenishment_e.py index 0885867..b904cb5 100644 --- a/tests/test_endpoints/test_replenishment_e.py +++ b/tests/test_endpoints/test_replenishment_e.py @@ -3,7 +3,7 @@ from unittest.mock import Mock from dependencies import oauth -from schemas import ReplenishmentCreate +from schemas import ReplenishmentCreate, ReplenishmentUpdate from tests.conftest import async_return, client from tests.factories import ReplenishmentFactory, UserFactory @@ -45,21 +45,25 @@ def test_create_replenishment(self) -> None: def test_update_replenishment(self) -> None: replenishment = ReplenishmentFactory(user_id=self.user.id) - date_update_replenishment = ReplenishmentCreate( - descriptions="descriptions", amount=999.9 + time = "2018-08-03T10:51:42" + date_update_replenishment = ReplenishmentUpdate( + descriptions="descriptions", + amount=999.9, + time=time, ) data = client.put( f"/replenishments/{replenishment.id}/", json={ "descriptions": date_update_replenishment.descriptions, "amount": date_update_replenishment.amount, + "time": time, }, ) replenishments_data = { "id": data.json()["id"], "descriptions": date_update_replenishment.descriptions, "amount": date_update_replenishment.amount, - "time": data.json()["time"], + "time": time, "user": {"id": self.user.id, "login": self.user.login}, } assert data.status_code == 200 diff --git a/tests/test_services/test_replenishment_s.py b/tests/test_services/test_replenishment_s.py index 82becb5..7b3ef4d 100644 --- a/tests/test_services/test_replenishment_s.py +++ b/tests/test_services/test_replenishment_s.py @@ -4,7 +4,7 @@ from starlette.exceptions import HTTPException from models import Replenishment -from schemas import ReplenishmentCreate +from schemas import ReplenishmentCreate, ReplenishmentUpdate from services import ( create_replenishment, read_replenishments, @@ -31,17 +31,16 @@ def test_create_replenishment(session) -> None: def test_update_replenishment(session) -> None: user = UserFactory() replenishment = ReplenishmentFactory(user_id=user.id) - date_update_replenishment = ReplenishmentCreate( - descriptions="descriptions", amount=999.9 + time = datetime.datetime.now() + date_update_replenishment = ReplenishmentUpdate( + descriptions="descriptions", amount=999.9, time=time ) data = update_replenishment( session, user.id, date_update_replenishment, replenishment.id ) assert data.descriptions == date_update_replenishment.descriptions assert float(data.amount) == date_update_replenishment.amount - assert data.time.strftime("%Y-%m-%d %H:%M") == datetime.datetime.utcnow().strftime( - "%Y-%m-%d %H:%M" - ) + assert data.time == time assert data.user.id == user.id