Skip to content

Commit

Permalink
Don't use sets of AnyUrl
Browse files Browse the repository at this point in the history
AnyUrl from pydantic is no longer hashable.
  • Loading branch information
CasperWA committed Nov 25, 2024
1 parent db2efab commit 5d60d8b
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 20 deletions.
10 changes: 9 additions & 1 deletion optimade_gateway/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,15 @@ def unique_base_urls(cls, value: list[LinksResource]) -> list[LinksResource]:
)

db_base_urls = [_.attributes.base_url for _ in value]
unique_base_urls = set(db_base_urls)

unique_base_urls = []
for base_url in db_base_urls:
if base_url is None:
continue
if base_url in unique_base_urls:
continue
unique_base_urls.append(base_url)

if len(db_base_urls) == len(unique_base_urls):
return value

Expand Down
4 changes: 2 additions & 2 deletions optimade_gateway/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Search(BaseModel):
] = set()

optimade_urls: Annotated[
set[AnyUrl],
list[AnyUrl],
Field(
description=(
"A list of OPTIMADE base URLs. If a versioned base URL is supplied it "
Expand All @@ -52,7 +52,7 @@ class Search(BaseModel):
"the server logic."
),
),
] = set()
] = []

endpoint: Annotated[
str,
Expand Down
6 changes: 3 additions & 3 deletions optimade_gateway/mongo/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ async def afind(
else:
single_entry = isinstance(params, SingleEntryQueryParams)

response_fields = criteria.pop("fields", self.all_fields)
response_fields: set[str] = criteria.pop("fields", self.all_fields)

results, data_returned, more_data_available = await self._arun_db_query(
criteria=criteria,
Expand All @@ -264,8 +264,8 @@ async def afind(
include_fields = (
response_fields - self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS
)
bad_optimade_fields = set()
bad_provider_fields = set()
bad_optimade_fields: set[str] = set()
bad_provider_fields: set[str] = set()
for field in include_fields:
if field not in self.resource_mapper.ALL_ATTRIBUTES:
if field.startswith("_"):
Expand Down
6 changes: 3 additions & 3 deletions optimade_gateway/queries/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class in `optimade`, which defines the standard entry listing endpoint query
the gateway. To be known they need to be registered with the gateway
(currently not possible).
optimade_urls (set[AnyUrl]): A list of OPTIMADE base URLs. If a versioned base
optimade_urls (list[AnyUrl]): A list of OPTIMADE base URLs. If a versioned base
URL is supplied it will be used as is, as long as it represents a supported
version. If an un-versioned base URL, standard version negotiation will be
conducted to get the versioned base URL, which will be used as long as it
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
),
] = set(),
optimade_urls: Annotated[
set[AnyUrl],
list[AnyUrl],
Query(
description=(
"A unique list of OPTIMADE base URLs. If a versioned base URL is "
Expand All @@ -71,7 +71,7 @@ def __init__(
"which will be used as long as it represents a supported version."
),
),
] = set(),
] = [],
endpoint: Annotated[
str,
Query(
Expand Down
25 changes: 15 additions & 10 deletions optimade_gateway/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,22 @@ async def post_search(request: Request, search: Search) -> QueriesResponseSingle
# NOTE: It may be that the final list of base URLs (`base_urls`) contains the same
# provider(s), but with differring base URLS, if, for example, a versioned base URL
# is supplied.
base_urls: set[AnyUrl] = set()
base_urls: list[AnyUrl] = []

if search.database_ids:
databases = await databases_collection.get_multiple(
filter={"id": {"$in": await clean_python_types(search.database_ids)}}
)
base_urls |= {
get_resource_attribute(database, "attributes.base_url")
for database in databases
if get_resource_attribute(database, "attributes.base_url") is not None
}
base_urls.extend(
[
get_resource_attribute(database, "attributes.base_url")
for database in databases
if get_resource_attribute(database, "attributes.base_url") is not None
]
)

if search.optimade_urls:
base_urls |= {_ for _ in search.optimade_urls if _ is not None}
base_urls.extend([_ for _ in search.optimade_urls if _ is not None])

if not base_urls:
msg = "No (valid) OPTIMADE URLs with:"
Expand Down Expand Up @@ -128,10 +130,13 @@ async def post_search(request: Request, search: Search) -> QueriesResponseSingle

elif len(databases) < len(base_urls):
# There are unregistered databases, i.e., databases not in the local collection
current_base_urls: set[AnyUrl] = {
current_base_urls: list[AnyUrl] = [
get_resource_attribute(database, "attributes.base_url")
for database in databases
}
]
diff_base_urls = [
base_url for base_url in base_urls if base_url not in current_base_urls
]
databases.extend(
[
LinksResource(
Expand All @@ -150,7 +155,7 @@ async def post_search(request: Request, search: Search) -> QueriesResponseSingle
homepage=None,
),
)
for url in base_urls - current_base_urls
for url in diff_base_urls
]
)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def random_gateway() -> dict:
"""Get a random gateway currently in the MongoDB"""
from optimade_gateway.mongo.database import MONGO_DB

gateway_ids = set()
gateway_ids: set[str] = set()
async for gateway in MONGO_DB["gateways"].find(
filter={}, projection={"id": True, "_id": False}
):
Expand Down

0 comments on commit 5d60d8b

Please sign in to comment.