Skip to content

Commit

Permalink
Refactor API to use schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
micorix committed May 22, 2023
1 parent b6b9dc9 commit 69f9831
Show file tree
Hide file tree
Showing 11 changed files with 375 additions and 413 deletions.
152 changes: 87 additions & 65 deletions po8klasie_fastapi/app/api/comparison/comparison_utils.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,90 @@
from __future__ import annotations

from enum import Enum
from typing import List, TypedDict


class ComparisonResultEnum(Enum):
MATCH = "match"
NEUTRAL = "neutral"


class ComparisonItemT(TypedDict):
value: int | str
comparison_result: ComparisonResultEnum


def is_iterable(x) -> bool:
try:
iter(x)
return not isinstance(x, str)
except TypeError:
return False


def find_intersection(institutions, property_key=None, getter_fn=None) -> set:
def get_single_set(idx):
value = None
if not property_key and not getter_fn:
raise Exception("No property key nor getter function specified")

if property_key:
institution = institutions[idx]
value = getattr(institution, property_key, None)
if getter_fn:
value = getter_fn(institutions[idx].rspo)

if is_iterable(value):
return set(value)
return {value}

intersection = None
for i in range(len(institutions)):
if i == 0:
intersection = get_single_set(i)
intersection = intersection.intersection(get_single_set(i))
return intersection


def get_comparison_result(is_in_intersection) -> ComparisonResultEnum:
if is_in_intersection:
return ComparisonResultEnum.MATCH
return ComparisonResultEnum.NEUTRAL


def get_comparison_item(value, intersection) -> ComparisonItemT | List[ComparisonItemT]:
if not is_iterable(value):
from dataclasses import dataclass

from po8klasie_fastapi.app.api.comparison.schemas import (
ComparisonInstitutionDataSchema,
ComparisonComparableDataSchema,
ComparisonResultEnum,
ComparisonInstitution,
)


@dataclass
class ComparisonInternalStateItem:
institution_data: ComparisonInstitutionDataSchema
comparable_data: ComparisonComparableDataSchema


def create_comparison_internal_state(institutions_to_compare):
for institution in institutions_to_compare:
parsed_institution_data = ComparisonInstitutionDataSchema.parse_institution(
institution
)
parsed_comparable_data = ComparisonComparableDataSchema.parse_institution(
institution
)
yield ComparisonInternalStateItem(
institution_data=parsed_institution_data,
comparable_data=parsed_comparable_data,
)


def compare(institutions_to_compare):
internal_comparison_state = list(
create_comparison_internal_state(institutions_to_compare)
)
comparable_fields = ComparisonComparableDataSchema.__fields__.keys()

def get_field_value(comparison_state_item, field_name):
value = getattr(comparison_state_item.comparable_data, field_name)
if isinstance(value, list):
return tuple(value)
return value

all_fields_values = {
field: set(
get_field_value(comparison_state_item, field)
for comparison_state_item in internal_comparison_state
)
for field in comparable_fields
}

def get_single_field_comparison_result(field_name):
is_field_value_common = len(all_fields_values[field_name]) == 1
if is_field_value_common:
return ComparisonResultEnum.MATCH
return ComparisonResultEnum.NEUTRAL

def get_single_field_comparison(field_name, comparison_institution):
return {
"value": value,
"comparison_result": get_comparison_result(value in intersection),
"value": getattr(
comparison_institution.comparable_data, field_name
),
"comparison_result": get_single_field_comparison_result(field_name),
}

return [
{
"value": value_item,
"comparison_result": get_comparison_result(value_item in intersection),
}
for value_item in value
]
def get_iterable_field_comparison(field_name, comparison_institution):
for item in getattr(comparison_institution.comparable_data, field_name):
set_of_all_lists = all_fields_values[field_name]
is_item_in_all_lists = all(item in list_for_field for list_for_field in set_of_all_lists)

yield {
"value": item,
"comparison_result": ComparisonResultEnum.MATCH if is_item_in_all_lists else ComparisonResultEnum.NEUTRAL,
}

def is_field_iterable(field_name, comparison_institution):
value = getattr(comparison_institution.comparable_data, field_name)
return isinstance(value, list)

for comparison_state_institution in internal_comparison_state:
yield ComparisonInstitution(
comparison={
field: (
list(get_iterable_field_comparison(field, comparison_state_institution))
if is_field_iterable(field, comparison_state_institution)
else get_single_field_comparison(field, comparison_state_institution)
)
for field in comparable_fields
},
**comparison_state_institution.institution_data.dict()
)
76 changes: 6 additions & 70 deletions po8klasie_fastapi/app/api/comparison/router.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,22 @@
from typing import List

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from po8klasie_fastapi.app.api.comparison.comparison_utils import (
find_intersection,
get_comparison_item,
)
from po8klasie_fastapi.app.api.comparison.schemas import ComparisonInstitutionSchema
from po8klasie_fastapi.app.api.institution.router import (
school_router_secondary_school_entities,
compare,
)
from po8klasie_fastapi.app.api.comparison.schemas import ComparisonInstitution
from po8klasie_fastapi.app.institution.models import (
SecondarySchoolInstitution,
query_secondary_school_institutions,
)
from po8klasie_fastapi.app.institution_classes.models import (
SecondarySchoolInstitutionClass,
query_current_classes,
query_institutions,
)
from po8klasie_fastapi.db.db import get_db

comparison_router = APIRouter()

comparison_router_secondary_school_entities = [*school_router_secondary_school_entities]


def group_classes_by_rspo(db, rspos: List[str]):
classes_by_rspo = {}
classes = (
query_current_classes(db)
.with_entities(
SecondarySchoolInstitutionClass.extended_subjects,
SecondarySchoolInstitutionClass.class_name,
SecondarySchoolInstitutionClass.institution_rspo,
)
.filter(SecondarySchoolInstitutionClass.institution_rspo.in_(rspos))
.all()
)
for class_ in classes:
rspo = class_.institution_rspo
extended_subjects = "-".join(class_.extended_subjects)
if rspo in classes_by_rspo and extended_subjects not in classes_by_rspo[rspo]:
classes_by_rspo[rspo].append(extended_subjects)
else:
classes_by_rspo[rspo] = [extended_subjects]
return classes_by_rspo


@comparison_router.get("/", response_model=List[ComparisonInstitutionSchema])
@comparison_router.get("/", response_model=list[ComparisonInstitution])
def route_comparison(
rspo: List[str] = Query(default=[]), db: Session = Depends(get_db)
):
Expand All @@ -59,41 +26,10 @@ def route_comparison(
)

institutions: List[SecondarySchoolInstitution] = (
query_secondary_school_institutions(db, school_router_secondary_school_entities)
.filter(SecondarySchoolInstitution.rspo.in_(rspo))
.all()
query_institutions(db).filter(SecondarySchoolInstitution.rspo.in_(rspo)).all()
)

if len(institutions) != len(rspo):
raise HTTPException(status_code=404, detail="School(s) not found")

rspos = [institution.rspo for institution in institutions]
classes_by_rspo = group_classes_by_rspo(db, rspos)

city_intersection = find_intersection(institutions, property_key="city")
is_public_intersection = find_intersection(institutions, property_key="is_public")
available_languages_intersection = find_intersection(
institutions, property_key="available_languages"
)
classes_intersection = find_intersection(
institutions, getter_fn=lambda rspo: classes_by_rspo.get(rspo)
)

for institution in institutions:
yield {
"name": institution.name,
"rspo": institution.rspo,
"comparison": {
"is_public": get_comparison_item(
institution.is_public, is_public_intersection
),
"city": get_comparison_item(institution.city, city_intersection),
"available_languages": get_comparison_item(
institution.available_languages,
available_languages_intersection,
),
"classes": get_comparison_item(
classes_by_rspo.get(institution.rspo, []), classes_intersection
),
},
}
return compare(institutions)
56 changes: 42 additions & 14 deletions po8klasie_fastapi/app/api/comparison/schemas.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,52 @@
from __future__ import annotations
from enum import Enum

from typing import List
from pydantic import validator, create_model, Field
from typing import Any

from po8klasie_fastapi.app.api.comparison.comparison_utils import ComparisonResultEnum
from po8klasie_fastapi.app.api.schemas import InstitutionSourcingSchemaMixin
from po8klasie_fastapi.app.lib.router_utils import CamelCasedModel


class ComparisonItemSchema(CamelCasedModel):
value: int | str
class ComparisonInstitutionDataSchema(CamelCasedModel, InstitutionSourcingSchemaMixin):
rspo: str
name: str


class ComparisonComparableDataSchema(CamelCasedModel, InstitutionSourcingSchemaMixin):
is_public: bool
city: str
available_languages: list[str]
classes: list[str]

@validator("classes", pre=True)
def preprocess_classes(cls, classes):
return ["-".join(class_.extended_subjects) for class_ in classes]


class ComparisonResultEnum(Enum):
MATCH = "match"
NEUTRAL = "neutral"


class ComparisonField(CamelCasedModel):
value: Any
comparison_result: ComparisonResultEnum


class ComparisonItemsSchema(CamelCasedModel):
is_public: ComparisonItemSchema
city: ComparisonItemSchema
available_languages: List[ComparisonItemSchema]
classes: List[ComparisonItemSchema]
ComparisonInstitutionResultSnakeCase = create_model(
"ComparisonInstitutionResult",
**{
field_name: (ComparisonField | list[ComparisonField], Field(title=field_name))
for field_name in ComparisonComparableDataSchema.__fields__.keys()
}
)


class ComparisonInstitutionSchema(CamelCasedModel):
name: str
rspo: str
comparison: ComparisonItemsSchema
class ComparisonInstitutionResult(
ComparisonInstitutionResultSnakeCase, CamelCasedModel
):
pass


class ComparisonInstitution(ComparisonInstitutionDataSchema):
comparison: ComparisonInstitutionResult
Loading

0 comments on commit 69f9831

Please sign in to comment.