-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
375 additions
and
413 deletions.
There are no files selected for viewing
152 changes: 87 additions & 65 deletions
152
po8klasie_fastapi/app/api/comparison/comparison_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,90 @@ | ||
from __future__ import annotations | ||
|
||
from enum import Enum | ||
from typing import List, TypedDict | ||
|
||
|
||
class ComparisonResultEnum(Enum): | ||
MATCH = "match" | ||
NEUTRAL = "neutral" | ||
|
||
|
||
class ComparisonItemT(TypedDict): | ||
value: int | str | ||
comparison_result: ComparisonResultEnum | ||
|
||
|
||
def is_iterable(x) -> bool: | ||
try: | ||
iter(x) | ||
return not isinstance(x, str) | ||
except TypeError: | ||
return False | ||
|
||
|
||
def find_intersection(institutions, property_key=None, getter_fn=None) -> set: | ||
def get_single_set(idx): | ||
value = None | ||
if not property_key and not getter_fn: | ||
raise Exception("No property key nor getter function specified") | ||
|
||
if property_key: | ||
institution = institutions[idx] | ||
value = getattr(institution, property_key, None) | ||
if getter_fn: | ||
value = getter_fn(institutions[idx].rspo) | ||
|
||
if is_iterable(value): | ||
return set(value) | ||
return {value} | ||
|
||
intersection = None | ||
for i in range(len(institutions)): | ||
if i == 0: | ||
intersection = get_single_set(i) | ||
intersection = intersection.intersection(get_single_set(i)) | ||
return intersection | ||
|
||
|
||
def get_comparison_result(is_in_intersection) -> ComparisonResultEnum: | ||
if is_in_intersection: | ||
return ComparisonResultEnum.MATCH | ||
return ComparisonResultEnum.NEUTRAL | ||
|
||
|
||
def get_comparison_item(value, intersection) -> ComparisonItemT | List[ComparisonItemT]: | ||
if not is_iterable(value): | ||
from dataclasses import dataclass | ||
|
||
from po8klasie_fastapi.app.api.comparison.schemas import ( | ||
ComparisonInstitutionDataSchema, | ||
ComparisonComparableDataSchema, | ||
ComparisonResultEnum, | ||
ComparisonInstitution, | ||
) | ||
|
||
|
||
@dataclass | ||
class ComparisonInternalStateItem: | ||
institution_data: ComparisonInstitutionDataSchema | ||
comparable_data: ComparisonComparableDataSchema | ||
|
||
|
||
def create_comparison_internal_state(institutions_to_compare): | ||
for institution in institutions_to_compare: | ||
parsed_institution_data = ComparisonInstitutionDataSchema.parse_institution( | ||
institution | ||
) | ||
parsed_comparable_data = ComparisonComparableDataSchema.parse_institution( | ||
institution | ||
) | ||
yield ComparisonInternalStateItem( | ||
institution_data=parsed_institution_data, | ||
comparable_data=parsed_comparable_data, | ||
) | ||
|
||
|
||
def compare(institutions_to_compare): | ||
internal_comparison_state = list( | ||
create_comparison_internal_state(institutions_to_compare) | ||
) | ||
comparable_fields = ComparisonComparableDataSchema.__fields__.keys() | ||
|
||
def get_field_value(comparison_state_item, field_name): | ||
value = getattr(comparison_state_item.comparable_data, field_name) | ||
if isinstance(value, list): | ||
return tuple(value) | ||
return value | ||
|
||
all_fields_values = { | ||
field: set( | ||
get_field_value(comparison_state_item, field) | ||
for comparison_state_item in internal_comparison_state | ||
) | ||
for field in comparable_fields | ||
} | ||
|
||
def get_single_field_comparison_result(field_name): | ||
is_field_value_common = len(all_fields_values[field_name]) == 1 | ||
if is_field_value_common: | ||
return ComparisonResultEnum.MATCH | ||
return ComparisonResultEnum.NEUTRAL | ||
|
||
def get_single_field_comparison(field_name, comparison_institution): | ||
return { | ||
"value": value, | ||
"comparison_result": get_comparison_result(value in intersection), | ||
"value": getattr( | ||
comparison_institution.comparable_data, field_name | ||
), | ||
"comparison_result": get_single_field_comparison_result(field_name), | ||
} | ||
|
||
return [ | ||
{ | ||
"value": value_item, | ||
"comparison_result": get_comparison_result(value_item in intersection), | ||
} | ||
for value_item in value | ||
] | ||
def get_iterable_field_comparison(field_name, comparison_institution): | ||
for item in getattr(comparison_institution.comparable_data, field_name): | ||
set_of_all_lists = all_fields_values[field_name] | ||
is_item_in_all_lists = all(item in list_for_field for list_for_field in set_of_all_lists) | ||
|
||
yield { | ||
"value": item, | ||
"comparison_result": ComparisonResultEnum.MATCH if is_item_in_all_lists else ComparisonResultEnum.NEUTRAL, | ||
} | ||
|
||
def is_field_iterable(field_name, comparison_institution): | ||
value = getattr(comparison_institution.comparable_data, field_name) | ||
return isinstance(value, list) | ||
|
||
for comparison_state_institution in internal_comparison_state: | ||
yield ComparisonInstitution( | ||
comparison={ | ||
field: ( | ||
list(get_iterable_field_comparison(field, comparison_state_institution)) | ||
if is_field_iterable(field, comparison_state_institution) | ||
else get_single_field_comparison(field, comparison_state_institution) | ||
) | ||
for field in comparable_fields | ||
}, | ||
**comparison_state_institution.institution_data.dict() | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,52 @@ | ||
from __future__ import annotations | ||
from enum import Enum | ||
|
||
from typing import List | ||
from pydantic import validator, create_model, Field | ||
from typing import Any | ||
|
||
from po8klasie_fastapi.app.api.comparison.comparison_utils import ComparisonResultEnum | ||
from po8klasie_fastapi.app.api.schemas import InstitutionSourcingSchemaMixin | ||
from po8klasie_fastapi.app.lib.router_utils import CamelCasedModel | ||
|
||
|
||
class ComparisonItemSchema(CamelCasedModel): | ||
value: int | str | ||
class ComparisonInstitutionDataSchema(CamelCasedModel, InstitutionSourcingSchemaMixin): | ||
rspo: str | ||
name: str | ||
|
||
|
||
class ComparisonComparableDataSchema(CamelCasedModel, InstitutionSourcingSchemaMixin): | ||
is_public: bool | ||
city: str | ||
available_languages: list[str] | ||
classes: list[str] | ||
|
||
@validator("classes", pre=True) | ||
def preprocess_classes(cls, classes): | ||
return ["-".join(class_.extended_subjects) for class_ in classes] | ||
|
||
|
||
class ComparisonResultEnum(Enum): | ||
MATCH = "match" | ||
NEUTRAL = "neutral" | ||
|
||
|
||
class ComparisonField(CamelCasedModel): | ||
value: Any | ||
comparison_result: ComparisonResultEnum | ||
|
||
|
||
class ComparisonItemsSchema(CamelCasedModel): | ||
is_public: ComparisonItemSchema | ||
city: ComparisonItemSchema | ||
available_languages: List[ComparisonItemSchema] | ||
classes: List[ComparisonItemSchema] | ||
ComparisonInstitutionResultSnakeCase = create_model( | ||
"ComparisonInstitutionResult", | ||
**{ | ||
field_name: (ComparisonField | list[ComparisonField], Field(title=field_name)) | ||
for field_name in ComparisonComparableDataSchema.__fields__.keys() | ||
} | ||
) | ||
|
||
|
||
class ComparisonInstitutionSchema(CamelCasedModel): | ||
name: str | ||
rspo: str | ||
comparison: ComparisonItemsSchema | ||
class ComparisonInstitutionResult( | ||
ComparisonInstitutionResultSnakeCase, CamelCasedModel | ||
): | ||
pass | ||
|
||
|
||
class ComparisonInstitution(ComparisonInstitutionDataSchema): | ||
comparison: ComparisonInstitutionResult |
Oops, something went wrong.