Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upload task worker results to Backend API #22

Merged
merged 13 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@ dev/data/**
!dev/data/README.md
!dev/.env
id_rsa
*.json
12 changes: 3 additions & 9 deletions backend/src/mirrors_qa_backend/cli/mirrors.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import sys

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status
from mirrors_qa_backend.exceptions import MirrorsRequestError
from mirrors_qa_backend.extract import get_current_mirrors


def update_mirrors() -> None:
"""Update the list of active mirrors in the DB."""
logger.info("Updating mirrors list.")
try:
with Session.begin() as session:
results = create_or_update_mirror_status(session, get_current_mirrors())
except MirrorsRequestError as exc:
logger.info(f"error while updating mirrors: {exc}")
sys.exit(1)
with Session.begin() as session:
results = create_or_update_mirror_status(session, get_current_mirrors())
logger.info(
f"Updated mirrors list. Added {results.nb_mirrors_added} mirror(s), "
f"disabled {results.nb_mirrors_disabled} mirror(s)"
Expand Down
56 changes: 30 additions & 26 deletions backend/src/mirrors_qa_backend/cli/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.mirrors import get_enabled_mirrors
from mirrors_qa_backend.db.tests import create_test, expire_tests, list_tests
from mirrors_qa_backend.db.worker import get_idle_workers
from mirrors_qa_backend.enums import StatusEnum
Expand All @@ -16,6 +17,7 @@ def main(
):
while True:
with Session.begin() as session:
mirrors = get_enabled_mirrors(session)
# expire tests whose results have not been reported
expired_tests = expire_tests(
session,
Expand Down Expand Up @@ -44,36 +46,38 @@ def main(
f"No countries registered for idle worker {idle_worker.id}"
)
continue
for country in idle_worker.countries:
# While we have expired "unreported" tests, it is possible that
# a test for a country might still be PENDING as the interval
# for expiration and that of the scheduler might overlap.
# In such scenarios, we skip creating a test for that country.
pending_tests = list_tests(
session,
worker_id=idle_worker.id,
statuses=[StatusEnum.PENDING],
country_code=country.code,
# While we have expired "unreported" tests, it is possible that
# a test for a mirror might still be PENDING as the interval
# for expiration and that of the scheduler might overlap.
# In such scenarios, we skip creating a test for such workers.
pending_tests = list_tests(
session,
worker_id=idle_worker.id,
statuses=[StatusEnum.PENDING],
)

if pending_tests.nb_tests:
logger.info(
"Skipping creation of new test entries for "
f"{idle_worker.id} as {pending_tests.nb_tests} "
f"tests are still pending."
)
continue

if pending_tests.nb_tests:
# Create a test for each mirror from the countries the worker registered
for country in idle_worker.countries:
for mirror in mirrors:
new_test = create_test(
session=session,
worker=idle_worker,
country_code=country.code,
mirror=mirror,
)
logger.info(
"Skipping creation of new test entries for "
f"{idle_worker.id} as {pending_tests.nb_tests} "
f"tests are still pending for country {country.name}"
f"Created new test {new_test.id} for worker "
f"{idle_worker.id} in location {country.name} "
f"for mirror {mirror.id}"
)
continue

new_test = create_test(
session=session,
worker_id=idle_worker.id,
country_code=country.code,
status=StatusEnum.PENDING,
)
logger.info(
f"Created new test {new_test.id} for worker "
f"{idle_worker.id} in country {country.name}"
)

logger.info(f"Sleeping for {sleep_seconds} seconds.")
time.sleep(sleep_seconds)
79 changes: 51 additions & 28 deletions backend/src/mirrors_qa_backend/cli/worker.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,65 @@
import sys

import pycountry
from cryptography.hazmat.primitives import serialization

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.country import update_countries as update_db_countries
from mirrors_qa_backend.db.worker import create_worker as create_db_worker
from mirrors_qa_backend.db.worker import update_worker as update_db_worker


def get_country_mapping(country_codes: list[str]) -> dict[str, str]:
"""Fetch the country names from the country codes.

def create_worker(worker_id: str, private_key_data: bytes, country_codes: list[str]):
Maps the country code to the country name.
"""
country_mapping: dict[str, str] = {}
# Ensure all the countries are valid country codes
for country_code in country_codes:
if len(country_code) != 2: # noqa: PLR2004
logger.info(f"Country code '{country_code}' must be two characters long")
sys.exit(1)

if not pycountry.countries.get(alpha_2=country_code):
logger.info(f"'{country_code}' is not valid country code")
sys.exit(1)

try:
private_key = serialization.load_pem_private_key(
private_key_data, password=None
) # pyright: ignore[reportReturnType]
except Exception as exc:
logger.info(f"Unable to load private key: {exc}")
sys.exit(1)

try:
with Session.begin() as session:
create_db_worker(
session,
worker_id,
country_codes,
private_key, # pyright: ignore [reportGeneralTypeIssues, reportArgumentType]
raise ValueError(
f"Country code '{country_code}' must be two characters long"
)
except Exception as exc:
logger.info(f"error while creating worker: {exc}")
sys.exit(1)

if country := pycountry.countries.get(alpha_2=country_code):
country_mapping[country_code] = country.name
else:
raise ValueError(f"'{country_code}' is not valid country code")
return country_mapping


def create_worker(
worker_id: str, private_key_data: bytes, initial_country_codes: list[str]
):
"""Create a worker in the DB.

Assigns the countries for a worker to run tests from.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not clear enough. Maybe renaming to initial_country_codes would help

country_mapping = get_country_mapping(initial_country_codes)
private_key = serialization.load_pem_private_key(
private_key_data, password=None
) # pyright: ignore[reportReturnType]

with Session.begin() as session:
# Update the database with the countries in case those countries don't
# exist yet.
update_db_countries(session, country_mapping)
create_db_worker(
session,
worker_id,
initial_country_codes,
private_key, # pyright: ignore [reportGeneralTypeIssues, reportArgumentType]
)

logger.info(f"Created worker {worker_id} successfully")


def update_worker(worker_id: str, country_codes: list[str]):
"""Update worker's data.

Updates the ountries for a worker to run tests from.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo

You see you have a duplicated block to assign countries to a worker.
Given this is only useful for tests, maybe the create and update methods should not include it
and the cli command should call a new method after create/update instead. WDYT?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. Makes sense.

country_mapping = get_country_mapping(country_codes)
with Session.begin() as session:
update_db_countries(session, country_mapping)
update_db_worker(session, worker_id, country_codes)
38 changes: 36 additions & 2 deletions backend/src/mirrors_qa_backend/db/country.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,50 @@
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError
from mirrors_qa_backend.db.models import Country


def get_countries(session: OrmSession, *country_codes: str) -> list[Country]:
def get_countries(session: OrmSession, country_codes: list[str]) -> list[Country]:
"""Get countries with the provided country codes.

Gets all available countries if no country codes are provided.
"""
return list(
session.scalars(select(Country).where(Country.code.in_(country_codes))).all()
session.scalars(
select(Country).where(
(Country.code.in_(country_codes)) | (country_codes == [])
)
).all()
)


def get_country_or_none(session: OrmSession, country_code: str) -> Country | None:
return session.scalars(
select(Country).where(Country.code == country_code)
).one_or_none()


def get_country(session: OrmSession, country_code: str) -> Country:
if country := get_country_or_none(session, country_code):
return country
raise RecordDoesNotExistError(f"Country with code {country_code} does not exist.")


def create_country(
session: OrmSession, *, country_code: str, country_name: str
) -> Country:
"""Creates a new country in the database."""
session.execute(
insert(Country)
.values(code=country_code, name=country_name)
.on_conflict_do_nothing(index_elements=["code"])
)
return get_country(session, country_code)


def update_countries(session: OrmSession, country_mapping: dict[str, str]) -> None:
"""Updates the list of countries in the database."""
for country_code, country_name in country_mapping.items():
create_country(session, country_code=country_code, country_name=country_name)
53 changes: 25 additions & 28 deletions backend/src/mirrors_qa_backend/db/mirrors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession
from sqlalchemy.orm import selectinload

from mirrors_qa_backend import logger, schemas
from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.exceptions import EmptyMirrorsError
from mirrors_qa_backend.db.exceptions import EmptyMirrorsError, RecordDoesNotExistError
from mirrors_qa_backend.db.models import Mirror


@dataclass
Expand All @@ -24,7 +23,7 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int:
"""
nb_created = 0
for mirror in mirrors:
db_mirror = models.Mirror(
db_mirror = Mirror(
id=mirror.id,
base_url=mirror.base_url,
enabled=mirror.enabled,
Expand All @@ -38,20 +37,8 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int:
as_only=mirror.as_only,
other_countries=mirror.other_countries,
)
# Ensure the country exists for the mirror
country = session.scalars(
select(models.Country).where(models.Country.code == mirror.country.code)
).one_or_none()

if country is None:
country = models.Country(code=mirror.country.code, name=mirror.country.name)
session.add(country)

db_mirror.country = country
session.add(db_mirror)
logger.debug(
f"Registered new mirror: {db_mirror.id} for country: {country.name}"
)
logger.debug(f"Registered new mirror: {db_mirror.id}.")
nb_created += 1
return nb_created

Expand Down Expand Up @@ -79,9 +66,8 @@ def create_or_update_mirror_status(
# Map the id (hostname) of each mirror from the database for comparison
# against the id of mirrors in current_mirrors. To be used in determining
# if this mirror should be disabled
query = select(models.Mirror).options(selectinload(models.Mirror.country))
db_mirrors: dict[str, models.Mirror] = {
mirror.id: mirror for mirror in session.scalars(query).all()
db_mirrors: dict[str, Mirror] = {
mirror.id: mirror for mirror in session.scalars(select(Mirror)).all()
}

# Create any mirror that doesn't exist on the database
Expand All @@ -95,19 +81,30 @@ def create_or_update_mirror_status(
# exists in the list, re-enable it
for db_mirror_id, db_mirror in db_mirrors.items():
if db_mirror_id not in current_mirrors:
logger.debug(
f"Disabling mirror: {db_mirror.id} for "
f"country: {db_mirror.country.name}"
)
logger.debug(f"Disabling mirror: {db_mirror.id}")
db_mirror.enabled = False
session.add(db_mirror)
result.nb_mirrors_disabled += 1
elif not db_mirror.enabled: # re-enable mirror if it was disabled
logger.debug(
f"Re-enabling mirror: {db_mirror.id} for "
f"country: {db_mirror.country.name}"
)
logger.debug(f"Re-enabling mirror: {db_mirror.id}")
db_mirror.enabled = True
session.add(db_mirror)
result.nb_mirrors_added += 1
return result


def get_mirror(session: OrmSession, mirror_id: str) -> Mirror:
"""Get a mirror from the DB."""
mirror = session.scalars(select(Mirror).where(Mirror.id == mirror_id)).one_or_none()
if mirror is None:
raise RecordDoesNotExistError(f"Mirror with id: {mirror_id} does not exist.")
return mirror


def get_enabled_mirrors(session: OrmSession) -> list[Mirror]:
"""Get all the enabled mirrors from the DB"""
return list(
session.scalars(
select(Mirror).where(Mirror.enabled == True) # noqa: E712
).all()
)
Loading
Loading