diff --git a/.gitignore b/.gitignore index 4909849..e79e35b 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,4 @@ dev/data/** !dev/data/README.md !dev/.env id_rsa +*.json diff --git a/backend/src/mirrors_qa_backend/cli/mirrors.py b/backend/src/mirrors_qa_backend/cli/mirrors.py index 0c931af..3e14fa6 100644 --- a/backend/src/mirrors_qa_backend/cli/mirrors.py +++ b/backend/src/mirrors_qa_backend/cli/mirrors.py @@ -1,20 +1,14 @@ -import sys - from mirrors_qa_backend import logger from mirrors_qa_backend.db import Session from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status -from mirrors_qa_backend.exceptions import MirrorsRequestError from mirrors_qa_backend.extract import get_current_mirrors def update_mirrors() -> None: + """Update the list of active mirrors in the DB.""" logger.info("Updating mirrors list.") - try: - with Session.begin() as session: - results = create_or_update_mirror_status(session, get_current_mirrors()) - except MirrorsRequestError as exc: - logger.info(f"error while updating mirrors: {exc}") - sys.exit(1) + with Session.begin() as session: + results = create_or_update_mirror_status(session, get_current_mirrors()) logger.info( f"Updated mirrors list. 
Added {results.nb_mirrors_added} mirror(s), " f"disabled {results.nb_mirrors_disabled} mirror(s)" diff --git a/backend/src/mirrors_qa_backend/cli/scheduler.py b/backend/src/mirrors_qa_backend/cli/scheduler.py index e15ce45..13f411a 100644 --- a/backend/src/mirrors_qa_backend/cli/scheduler.py +++ b/backend/src/mirrors_qa_backend/cli/scheduler.py @@ -3,6 +3,7 @@ from mirrors_qa_backend import logger from mirrors_qa_backend.db import Session +from mirrors_qa_backend.db.mirrors import get_enabled_mirrors from mirrors_qa_backend.db.tests import create_test, expire_tests, list_tests from mirrors_qa_backend.db.worker import get_idle_workers from mirrors_qa_backend.enums import StatusEnum @@ -16,6 +17,7 @@ def main( ): while True: with Session.begin() as session: + mirrors = get_enabled_mirrors(session) # expire tests whose results have not been reported expired_tests = expire_tests( session, @@ -44,36 +46,38 @@ def main( f"No countries registered for idle worker {idle_worker.id}" ) continue - for country in idle_worker.countries: - # While we have expired "unreported" tests, it is possible that - # a test for a country might still be PENDING as the interval - # for expiration and that of the scheduler might overlap. - # In such scenarios, we skip creating a test for that country. - pending_tests = list_tests( - session, - worker_id=idle_worker.id, - statuses=[StatusEnum.PENDING], - country_code=country.code, + # While we have expired "unreported" tests, it is possible that + # a test for a mirror might still be PENDING as the interval + # for expiration and that of the scheduler might overlap. + # In such scenarios, we skip creating a test for such workers. + pending_tests = list_tests( + session, + worker_id=idle_worker.id, + statuses=[StatusEnum.PENDING], + ) + + if pending_tests.nb_tests: + logger.info( + "Skipping creation of new test entries for " + f"{idle_worker.id} as {pending_tests.nb_tests} " + f"tests are still pending." 
) + continue - if pending_tests.nb_tests: + # Create a test for each mirror from the countries the worker registered + for country in idle_worker.countries: + for mirror in mirrors: + new_test = create_test( + session=session, + worker=idle_worker, + country_code=country.code, + mirror=mirror, + ) logger.info( - "Skipping creation of new test entries for " - f"{idle_worker.id} as {pending_tests.nb_tests} " - f"tests are still pending for country {country.name}" + f"Created new test {new_test.id} for worker " + f"{idle_worker.id} in location {country.name} " + f"for mirror {mirror.id}" ) - continue - - new_test = create_test( - session=session, - worker_id=idle_worker.id, - country_code=country.code, - status=StatusEnum.PENDING, - ) - logger.info( - f"Created new test {new_test.id} for worker " - f"{idle_worker.id} in country {country.name}" - ) logger.info(f"Sleeping for {sleep_seconds} seconds.") time.sleep(sleep_seconds) diff --git a/backend/src/mirrors_qa_backend/cli/worker.py b/backend/src/mirrors_qa_backend/cli/worker.py index bb3e443..6377881 100644 --- a/backend/src/mirrors_qa_backend/cli/worker.py +++ b/backend/src/mirrors_qa_backend/cli/worker.py @@ -1,42 +1,65 @@ -import sys - import pycountry from cryptography.hazmat.primitives import serialization from mirrors_qa_backend import logger from mirrors_qa_backend.db import Session +from mirrors_qa_backend.db.country import update_countries as update_db_countries from mirrors_qa_backend.db.worker import create_worker as create_db_worker +from mirrors_qa_backend.db.worker import update_worker as update_db_worker + +def get_country_mapping(country_codes: list[str]) -> dict[str, str]: + """Fetch the country names from the country codes. -def create_worker(worker_id: str, private_key_data: bytes, country_codes: list[str]): + Maps the country code to the country name. 
+ """ + country_mapping: dict[str, str] = {} # Ensure all the countries are valid country codes for country_code in country_codes: if len(country_code) != 2: # noqa: PLR2004 - logger.info(f"Country code '{country_code}' must be two characters long") - sys.exit(1) - - if not pycountry.countries.get(alpha_2=country_code): - logger.info(f"'{country_code}' is not valid country code") - sys.exit(1) - - try: - private_key = serialization.load_pem_private_key( - private_key_data, password=None - ) # pyright: ignore[reportReturnType] - except Exception as exc: - logger.info(f"Unable to load private key: {exc}") - sys.exit(1) - - try: - with Session.begin() as session: - create_db_worker( - session, - worker_id, - country_codes, - private_key, # pyright: ignore [reportGeneralTypeIssues, reportArgumentType] + raise ValueError( + f"Country code '{country_code}' must be two characters long" ) - except Exception as exc: - logger.info(f"error while creating worker: {exc}") - sys.exit(1) + + if country := pycountry.countries.get(alpha_2=country_code): + country_mapping[country_code] = country.name + else: + raise ValueError(f"'{country_code}' is not valid country code") + return country_mapping + + +def create_worker( + worker_id: str, private_key_data: bytes, initial_country_codes: list[str] +): + """Create a worker in the DB. + + Assigns the countries for a worker to run tests from. + """ + country_mapping = get_country_mapping(initial_country_codes) + private_key = serialization.load_pem_private_key( + private_key_data, password=None + ) # pyright: ignore[reportReturnType] + + with Session.begin() as session: + # Update the database with the countries in case those countries don't + # exist yet. 
+ update_db_countries(session, country_mapping) + create_db_worker( + session, + worker_id, + initial_country_codes, + private_key, # pyright: ignore [reportGeneralTypeIssues, reportArgumentType] + ) logger.info(f"Created worker {worker_id} successfully") + + +def update_worker(worker_id: str, country_codes: list[str]): + """Update worker's data. + + Updates the countries for a worker to run tests from. + """ + country_mapping = get_country_mapping(country_codes) + with Session.begin() as session: + update_db_countries(session, country_mapping) + update_db_worker(session, worker_id, country_codes) diff --git a/backend/src/mirrors_qa_backend/db/country.py b/backend/src/mirrors_qa_backend/db/country.py index a3c173c..f085c11 100644 --- a/backend/src/mirrors_qa_backend/db/country.py +++ b/backend/src/mirrors_qa_backend/db/country.py @@ -1,12 +1,22 @@ from sqlalchemy import select +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session as OrmSession +from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError from mirrors_qa_backend.db.models import Country -def get_countries(session: OrmSession, *country_codes: str) -> list[Country]: +def get_countries(session: OrmSession, country_codes: list[str]) -> list[Country]: + """Get countries with the provided country codes. + + Gets all available countries if no country codes are provided. 
+ """ return list( - session.scalars(select(Country).where(Country.code.in_(country_codes))).all() + session.scalars( + select(Country).where( + (Country.code.in_(country_codes)) | (country_codes == []) + ) + ).all() ) @@ -14,3 +24,27 @@ def get_country_or_none(session: OrmSession, country_code: str) -> Country | Non return session.scalars( select(Country).where(Country.code == country_code) ).one_or_none() + + +def get_country(session: OrmSession, country_code: str) -> Country: + if country := get_country_or_none(session, country_code): + return country + raise RecordDoesNotExistError(f"Country with code {country_code} does not exist.") + + +def create_country( + session: OrmSession, *, country_code: str, country_name: str +) -> Country: + """Creates a new country in the database.""" + session.execute( + insert(Country) + .values(code=country_code, name=country_name) + .on_conflict_do_nothing(index_elements=["code"]) + ) + return get_country(session, country_code) + + +def update_countries(session: OrmSession, country_mapping: dict[str, str]) -> None: + """Updates the list of countries in the database.""" + for country_code, country_name in country_mapping.items(): + create_country(session, country_code=country_code, country_name=country_name) diff --git a/backend/src/mirrors_qa_backend/db/mirrors.py b/backend/src/mirrors_qa_backend/db/mirrors.py index 904672d..fa2454f 100644 --- a/backend/src/mirrors_qa_backend/db/mirrors.py +++ b/backend/src/mirrors_qa_backend/db/mirrors.py @@ -2,11 +2,10 @@ from sqlalchemy import select from sqlalchemy.orm import Session as OrmSession -from sqlalchemy.orm import selectinload from mirrors_qa_backend import logger, schemas -from mirrors_qa_backend.db import models -from mirrors_qa_backend.db.exceptions import EmptyMirrorsError +from mirrors_qa_backend.db.exceptions import EmptyMirrorsError, RecordDoesNotExistError +from mirrors_qa_backend.db.models import Mirror @dataclass @@ -24,7 +23,7 @@ def create_mirrors(session: OrmSession, 
mirrors: list[schemas.Mirror]) -> int: """ nb_created = 0 for mirror in mirrors: - db_mirror = models.Mirror( + db_mirror = Mirror( id=mirror.id, base_url=mirror.base_url, enabled=mirror.enabled, @@ -38,20 +37,8 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int: as_only=mirror.as_only, other_countries=mirror.other_countries, ) - # Ensure the country exists for the mirror - country = session.scalars( - select(models.Country).where(models.Country.code == mirror.country.code) - ).one_or_none() - - if country is None: - country = models.Country(code=mirror.country.code, name=mirror.country.name) - session.add(country) - - db_mirror.country = country session.add(db_mirror) - logger.debug( - f"Registered new mirror: {db_mirror.id} for country: {country.name}" - ) + logger.debug(f"Registered new mirror: {db_mirror.id}.") nb_created += 1 return nb_created @@ -79,9 +66,8 @@ def create_or_update_mirror_status( # Map the id (hostname) of each mirror from the database for comparison # against the id of mirrors in current_mirrors. 
To be used in determining # if this mirror should be disabled - query = select(models.Mirror).options(selectinload(models.Mirror.country)) - db_mirrors: dict[str, models.Mirror] = { - mirror.id: mirror for mirror in session.scalars(query).all() + db_mirrors: dict[str, Mirror] = { + mirror.id: mirror for mirror in session.scalars(select(Mirror)).all() } # Create any mirror that doesn't exist on the database @@ -95,19 +81,30 @@ def create_or_update_mirror_status( # exists in the list, re-enable it for db_mirror_id, db_mirror in db_mirrors.items(): if db_mirror_id not in current_mirrors: - logger.debug( - f"Disabling mirror: {db_mirror.id} for " - f"country: {db_mirror.country.name}" - ) + logger.debug(f"Disabling mirror: {db_mirror.id}") db_mirror.enabled = False session.add(db_mirror) result.nb_mirrors_disabled += 1 elif not db_mirror.enabled: # re-enable mirror if it was disabled - logger.debug( - f"Re-enabling mirror: {db_mirror.id} for " - f"country: {db_mirror.country.name}" - ) + logger.debug(f"Re-enabling mirror: {db_mirror.id}") db_mirror.enabled = True session.add(db_mirror) result.nb_mirrors_added += 1 return result + + +def get_mirror(session: OrmSession, mirror_id: str) -> Mirror: + """Get a mirror from the DB.""" + mirror = session.scalars(select(Mirror).where(Mirror.id == mirror_id)).one_or_none() + if mirror is None: + raise RecordDoesNotExistError(f"Mirror with id: {mirror_id} does not exist.") + return mirror + + +def get_enabled_mirrors(session: OrmSession) -> list[Mirror]: + """Get all the enabled mirrors from the DB""" + return list( + session.scalars( + select(Mirror).where(Mirror.enabled == True) # noqa: E712 + ).all() + ) diff --git a/backend/src/mirrors_qa_backend/db/models.py b/backend/src/mirrors_qa_backend/db/models.py index 125095e..e3b5ff0 100644 --- a/backend/src/mirrors_qa_backend/db/models.py +++ b/backend/src/mirrors_qa_backend/db/models.py @@ -4,7 +4,14 @@ from ipaddress import IPv4Address from uuid import UUID -from sqlalchemy 
import DateTime, Enum, ForeignKey, String, UniqueConstraint, text +from sqlalchemy import ( + DateTime, + Enum, + ForeignKey, + String, + UniqueConstraint, + text, +) from sqlalchemy.dialects.postgresql import ARRAY, INET from sqlalchemy.orm import ( DeclarativeBase, @@ -47,7 +54,19 @@ class Base(MappedAsDataclass, DeclarativeBase): pass +class WorkerCountry(Base): + """Association table for many-to-many relationships between worker and country.""" + + __tablename__ = "worker_country" + worker_id: Mapped[str] = mapped_column(ForeignKey("worker.id"), primary_key=True) + country_code: Mapped[str] = mapped_column( + ForeignKey("country.code"), primary_key=True + ) + + class Country(Base): + """Country where a worker runs tests for a mirror.""" + __tablename__ = "country" code: Mapped[str] = mapped_column( @@ -56,16 +75,13 @@ class Country(Base): name: Mapped[str] # full name of the country (in English) - worker_id: Mapped[str | None] = mapped_column(ForeignKey("worker.id"), init=False) - worker: Mapped[Worker | None] = relationship(back_populates="countries", init=False) - mirrors: Mapped[list[Mirror]] = relationship( - back_populates="country", + workers: Mapped[list[Worker]] = relationship( + back_populates="countries", init=False, - cascade="all, delete-orphan", + secondary=WorkerCountry.__table__, + repr=False, ) - tests: Mapped[list[Test]] = relationship(back_populates="country", init=False) - __table_args__ = (UniqueConstraint("name", "code"),) @@ -86,11 +102,11 @@ class Mirror(Base): as_only: Mapped[bool | None] = mapped_column(default=None) other_countries: Mapped[list[str] | None] = mapped_column(default=None) - country_code: Mapped[str] = mapped_column( - ForeignKey("country.code"), - init=False, + tests: Mapped[list[Test]] = relationship( + back_populates="mirror", init=False, repr=False ) - country: Mapped[Country] = relationship(back_populates="mirrors", init=False) + + __table_args__ = (UniqueConstraint("base_url"),) class Worker(Base): @@ -98,15 +114,22 
@@ class Worker(Base): id: Mapped[str] = mapped_column(primary_key=True) # RSA public key in PKCS8 format for generating access tokens required # to make requests to the web server - pubkey_pkcs8: Mapped[str] - pubkey_fingerprint: Mapped[str] + pubkey_pkcs8: Mapped[str] = mapped_column(repr=False) + pubkey_fingerprint: Mapped[str] = mapped_column(repr=False) last_seen_on: Mapped[datetime.datetime] = mapped_column( default_factory=datetime.datetime.now ) - countries: Mapped[list[Country]] = relationship(back_populates="worker", init=False) + countries: Mapped[list[Country]] = relationship( + back_populates="workers", + init=False, + secondary=WorkerCountry.__table__, + cascade="all, delete", + ) - tests: Mapped[list[Test]] = relationship(back_populates="worker", init=False) + tests: Mapped[list[Test]] = relationship( + back_populates="worker", init=False, repr=False + ) class Test(Base): @@ -128,25 +151,30 @@ class Test(Base): ), default=StatusEnum.PENDING, ) + # Base URL of the mirror which the test will be run + mirror_url: Mapped[str | None] = mapped_column( + ForeignKey("mirror.base_url"), init=False, default=None + ) error: Mapped[str | None] = mapped_column(default=None) isp: Mapped[str | None] = mapped_column(default=None) ip_address: Mapped[IPv4Address | None] = mapped_column(default=None) # autonomous system based on IP asn: Mapped[str | None] = mapped_column(default=None) - country_code: Mapped[str | None] = mapped_column( - ForeignKey("country.code"), - init=False, - default=None, - ) - location: Mapped[str | None] = mapped_column(default=None) # city based on IP - latency: Mapped[int | None] = mapped_column(default=None) # milliseconds + # country to run the test from + country_code: Mapped[str | None] = mapped_column(default=None) + city: Mapped[str | None] = mapped_column(default=None) # city based on IP + latency: Mapped[float | None] = mapped_column(default=None) # milliseconds download_size: Mapped[int | None] = mapped_column(default=None) # bytes - 
duration: Mapped[int | None] = mapped_column(default=None) # seconds + duration: Mapped[float | None] = mapped_column(default=None) # seconds speed: Mapped[float | None] = mapped_column(default=None) # bytes per second worker_id: Mapped[str | None] = mapped_column( ForeignKey("worker.id"), init=False, default=None ) - worker: Mapped[Worker | None] = relationship(back_populates="tests", init=False) + worker: Mapped[Worker | None] = relationship( + back_populates="tests", init=False, repr=False + ) - country: Mapped[Country | None] = relationship(back_populates="tests", init=False) + mirror: Mapped[Mirror | None] = relationship( + back_populates="tests", init=False, repr=False + ) diff --git a/backend/src/mirrors_qa_backend/db/tests.py b/backend/src/mirrors_qa_backend/db/tests.py index e8e9013..a9325df 100644 --- a/backend/src/mirrors_qa_backend/db/tests.py +++ b/backend/src/mirrors_qa_backend/db/tests.py @@ -6,9 +6,8 @@ from sqlalchemy import UnaryExpression, asc, desc, func, select, update from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend.db import models -from mirrors_qa_backend.db.country import get_country_or_none from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError +from mirrors_qa_backend.db.models import Mirror, Test, Worker from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum from mirrors_qa_backend.settings import Settings @@ -18,11 +17,11 @@ class TestListResult: """Result of query to list tests from the database.""" nb_tests: int - tests: list[models.Test] + tests: list[Test] def filter_test( - test: models.Test, + test: Test, *, worker_id: str | None = None, country_code: str | None = None, @@ -42,10 +41,11 @@ def filter_test( return True -def get_test(session: OrmSession, test_id: UUID) -> models.Test | None: - return session.scalars( - select(models.Test).where(models.Test.id == test_id) - ).one_or_none() +def get_test(session: OrmSession, test_id: UUID) -> Test: + test = 
session.scalars(select(Test).where(Test.id == test_id)).one_or_none() + if test is None: + raise RecordDoesNotExistError(f"Test with id: {test_id} does not exist.") + return test def list_tests( @@ -85,11 +85,11 @@ def list_tests( # its default in the database which translates to a SQL true i.e we don't # filter based on this argument. query = ( - select(func.count().over().label("total_records"), models.Test) + select(func.count().over().label("total_records"), Test) .where( - (models.Test.worker_id == worker_id) | (worker_id is None), - (models.Test.country_code == country_code) | (country_code is None), - (models.Test.status.in_(statuses)), + (Test.worker_id == worker_id) | (worker_id is None), + (Test.country_code == country_code) | (country_code is None), + (Test.status.in_(statuses)), ) .order_by(*order_by) .offset((page_num - 1) * page_size) @@ -99,52 +99,43 @@ def list_tests( result = TestListResult(nb_tests=0, tests=[]) for total_records, test in session.execute(query).all(): + # Because the SQL window function returns the total_records + # for every row, assign that value to the nb_tests result.nb_tests = total_records result.tests.append(test) return result -def create_or_update_test( +def update_test( session: OrmSession, - test_id: UUID | None = None, + test_id: UUID, *, - worker_id: str | None = None, status: StatusEnum = StatusEnum.PENDING, error: str | None = None, ip_address: IPv4Address | None = None, asn: str | None = None, - country_code: str | None = None, - location: str | None = None, - latency: int | None = None, + city: str | None = None, + latency: float | None = None, download_size: int | None = None, - duration: int | None = None, + duration: float | None = None, speed: float | None = None, started_on: datetime.datetime | None = None, -) -> models.Test: - """Create a test if test_id is None or update the test with test_id""" - if test_id is None: - test = models.Test() - else: - test = get_test(session, test_id) - if test is None: - 
raise RecordDoesNotExistError(f"Test with id: {test_id} does not exist.") - + isp: str | None = None, +) -> Test: + test = get_test(session, test_id) # If a value is provided, it takes precedence over the default value of the model - test.worker_id = worker_id if worker_id else test.worker_id test.status = status test.error = error if error else test.error test.ip_address = ip_address if ip_address else test.ip_address test.asn = asn if asn else test.asn - test.country = ( - get_country_or_none(session, country_code) if country_code else test.country - ) - test.location = location if location else test.location + test.city = city if city else test.city test.latency = latency if latency else test.latency test.download_size = download_size if download_size else test.download_size test.duration = duration if duration else test.duration test.speed = speed if speed else test.speed test.started_on = started_on if started_on else test.started_on + test.isp = isp if isp else test.isp session.add(test) session.flush() @@ -155,51 +146,31 @@ def create_or_update_test( def create_test( session: OrmSession, *, - worker_id: str | None = None, - status: StatusEnum = StatusEnum.PENDING, - error: str | None = None, - ip_address: IPv4Address | None = None, - asn: str | None = None, - country_code: str | None = None, - location: str | None = None, - latency: int | None = None, - download_size: int | None = None, - duration: int | None = None, - speed: float | None = None, - started_on: datetime.datetime | None = None, -) -> models.Test: - return create_or_update_test( - session, - test_id=None, - worker_id=worker_id, - status=status, - error=error, - ip_address=ip_address, - asn=asn, - country_code=country_code, - location=location, - latency=latency, - download_size=download_size, - duration=duration, - speed=speed, - started_on=started_on, - ) + worker: Worker, + mirror: Mirror, + country_code: str, +) -> Test: + test = Test(status=StatusEnum.PENDING, country_code=country_code) + 
test.worker = worker + test.mirror = mirror + + session.add(test) + session.flush() + return test -def expire_tests( - session: OrmSession, interval: datetime.timedelta -) -> list[models.Test]: +def expire_tests(session: OrmSession, interval: datetime.timedelta) -> list[Test]: """Change the status of PENDING tests created before the interval to MISSED""" end = datetime.datetime.now() - interval begin = datetime.datetime.fromtimestamp(0) return list( session.scalars( - update(models.Test) + update(Test) .where( - models.Test.requested_on.between(begin, end), - models.Test.status == StatusEnum.PENDING, + Test.requested_on.between(begin, end), + Test.status == StatusEnum.PENDING, ) .values(status=StatusEnum.MISSED) - .returning(models.Test) + .returning(Test) ).all() ) diff --git a/backend/src/mirrors_qa_backend/db/worker.py b/backend/src/mirrors_qa_backend/db/worker.py index 354bb6f..ba66692 100644 --- a/backend/src/mirrors_qa_backend/db/worker.py +++ b/backend/src/mirrors_qa_backend/db/worker.py @@ -10,14 +10,23 @@ serialize_public_key, ) from mirrors_qa_backend.db.country import get_countries -from mirrors_qa_backend.db.exceptions import DuplicatePrimaryKeyError +from mirrors_qa_backend.db.exceptions import ( + DuplicatePrimaryKeyError, + RecordDoesNotExistError, +) from mirrors_qa_backend.db.models import Worker -def get_worker(session: OrmSession, worker_id: str) -> Worker | None: +def get_worker_or_none(session: OrmSession, worker_id: str) -> Worker | None: return session.scalars(select(Worker).where(Worker.id == worker_id)).one_or_none() +def get_worker(session: OrmSession, worker_id: str) -> Worker: + if worker := get_worker_or_none(session, worker_id): + return worker + raise RecordDoesNotExistError(f"Worker with id: {worker_id} does not exist.") + + def create_worker( session: OrmSession, worker_id: str, @@ -25,7 +34,7 @@ def create_worker( private_key: RSAPrivateKey, ) -> Worker: """Creates a worker using RSA private key.""" - if get_worker(session, 
worker_id) is not None: + if get_worker_or_none(session, worker_id) is not None: raise DuplicatePrimaryKeyError(f"A worker with id {worker_id} already exists.") public_key = generate_public_key(private_key) @@ -35,15 +44,27 @@ def create_worker( pubkey_pkcs8=public_key_pkcs8, pubkey_fingerprint=get_public_key_fingerprint(public_key), ) - session.add(worker) - for db_country in get_countries(session, *country_codes): - db_country.worker_id = worker_id - session.add(db_country) + update_worker_countries(session, worker, country_codes) + + return worker + +def update_worker_countries( + session: OrmSession, worker: Worker, country_codes: list[str] +) -> Worker: + worker.countries = get_countries(session, country_codes) + session.add(worker) return worker +def update_worker( + session: OrmSession, worker_id: str, country_codes: list[str] +) -> Worker: + worker = get_worker(session, worker_id) + return update_worker_countries(session, worker, country_codes) + + def get_workers_last_seen_in_range( session: OrmSession, begin: datetime.datetime, end: datetime.datetime ) -> list[Worker]: @@ -59,7 +80,7 @@ def get_workers_last_seen_in_range( def get_idle_workers(session: OrmSession, interval: datetime.timedelta) -> list[Worker]: end = datetime.datetime.now() - interval - begin = datetime.datetime(1970, 1, 1) + begin = datetime.datetime.fromtimestamp(0) return get_workers_last_seen_in_range(session, begin, end) diff --git a/backend/src/mirrors_qa_backend/entrypoint.py b/backend/src/mirrors_qa_backend/entrypoint.py index 0ca73ef..3f68316 100644 --- a/backend/src/mirrors_qa_backend/entrypoint.py +++ b/backend/src/mirrors_qa_backend/entrypoint.py @@ -8,16 +8,17 @@ from mirrors_qa_backend.__about__ import __version__ from mirrors_qa_backend.cli.mirrors import update_mirrors from mirrors_qa_backend.cli.scheduler import main as start_scheduler -from mirrors_qa_backend.cli.worker import create_worker +from mirrors_qa_backend.cli.worker import create_worker, update_worker from 
mirrors_qa_backend.settings.scheduler import SchedulerSettings UPDATE_MIRRORS_CLI = "update-mirrors" CREATE_WORKER_CLI = "create-worker" +UPDATE_WORKER_CLI = "update-worker" SCHEDULER_CLI = "scheduler" def main(): - # The program is split into a number of sub-commands which each sbu-command + # The program is split into a number of sub-commands with each sub-command # performing different function and requring different different kinds of # command line arguments parser = argparse.ArgumentParser() @@ -65,19 +66,22 @@ def main(): metavar="duration", ) - create_worker_cli = subparsers.add_parser( - CREATE_WORKER_CLI, help="Create a new worker." - ) - create_worker_cli.add_argument( + # Parser for holding shared arguments for worker sub-commands + worker_parser = argparse.ArgumentParser(add_help=False) + worker_parser.add_argument( "worker_id", help="ID of the worker.", metavar="worker-id" ) - create_worker_cli.add_argument( + worker_parser.add_argument( "--countries", help="Comma-seperated country codes each in ISO 3166-1 alpha-2 format.", type=lambda countries: countries.split(","), dest="countries", metavar="codes", ) + + create_worker_cli = subparsers.add_parser( + CREATE_WORKER_CLI, help="Create a new worker.", parents=[worker_parser] + ) create_worker_cli.add_argument( "private_key_file", metavar="private-key-file", @@ -87,24 +91,45 @@ def main(): help="RSA private key file (default: stdin).", ) + subparsers.add_parser( + UPDATE_WORKER_CLI, help="Update a worker", parents=[worker_parser] + ) + args = parser.parse_args() if args.verbose: logger.setLevel(logging.DEBUG) if args.cli_name == UPDATE_MIRRORS_CLI: - return update_mirrors() + try: + update_mirrors() + except Exception as exc: + logger.error(f"error while updating mirrors: {exc!s}") + sys.exit(1) elif args.cli_name == SCHEDULER_CLI: - return start_scheduler( + start_scheduler( args.scheduler_sleep_seconds, args.expire_tests_since, args.workers_since, ) elif args.cli_name == CREATE_WORKER_CLI: - return 
create_worker( - args.worker_id, - bytes(args.private_key_file.read(), encoding="ascii"), - args.countries if args.countries else [], - ) + try: + create_worker( + args.worker_id, + bytes(args.private_key_file.read(), encoding="ascii"), + args.countries if args.countries else [], + ) + except Exception as exc: + logger.error(f"error while creating worker: {exc!s}") + sys.exit(1) + elif args.cli_name == UPDATE_WORKER_CLI: + try: + update_worker( + args.worker_id, + args.countries if args.countries else [], + ) + except Exception as exc: + logger.error(f"error while updating worker: {exc!s}") + sys.exit(1) else: args.print_help() diff --git a/backend/src/mirrors_qa_backend/exceptions.py b/backend/src/mirrors_qa_backend/exceptions.py index 0c2f85c..f3d704a 100644 --- a/backend/src/mirrors_qa_backend/exceptions.py +++ b/backend/src/mirrors_qa_backend/exceptions.py @@ -17,3 +17,9 @@ class PEMPublicKeyLoadError(Exception): """Unable to deserialize a public key from PEM encoded data""" pass + + +class LocationsRequestError(RequestException): + """A network error occurred while fetching locations from the locations URL""" + + pass diff --git a/backend/src/mirrors_qa_backend/extract.py b/backend/src/mirrors_qa_backend/extract.py index ee8dc63..be81bed 100644 --- a/backend/src/mirrors_qa_backend/extract.py +++ b/backend/src/mirrors_qa_backend/extract.py @@ -1,12 +1,11 @@ from typing import Any from urllib.parse import urlsplit -import pycountry import requests from bs4 import BeautifulSoup, NavigableString from bs4.element import Tag -from mirrors_qa_backend import logger, schemas +from mirrors_qa_backend import schemas from mirrors_qa_backend.exceptions import MirrorsExtractError, MirrorsRequestError from mirrors_qa_backend.settings import Settings @@ -54,22 +53,11 @@ def is_country_row(tag: Tag) -> bool: ).netloc # pyright: ignore [reportUnknownMemberType] if hostname in Settings.MIRRORS_EXCLUSION_LIST: continue - country_name = row.find("img").next_sibling.text.strip() - 
try: - country: Any = pycountry.countries.search_fuzzy(country_name)[0] - except LookupError: - logger.error(f"Could not get information for country: {country_name}") - continue - else: - mirrors.append( - schemas.Mirror( - id=hostname, - base_url=base_url, - enabled=True, - country=schemas.Country( - code=country.alpha_2.lower(), - name=country.name, - ), - ) + mirrors.append( + schemas.Mirror( + id=hostname, + base_url=base_url, + enabled=True, ) + ) return mirrors diff --git a/backend/src/mirrors_qa_backend/migrations/versions/091da85970cb_add_mirror_base_url_to_test.py b/backend/src/mirrors_qa_backend/migrations/versions/091da85970cb_add_mirror_base_url_to_test.py new file mode 100644 index 0000000..a8f5c24 --- /dev/null +++ b/backend/src/mirrors_qa_backend/migrations/versions/091da85970cb_add_mirror_base_url_to_test.py @@ -0,0 +1,38 @@ +"""add mirror base url to test + +Revision ID: 091da85970cb +Revises: 88e49e681048 +Create Date: 2024-07-05 13:54:38.206076 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "091da85970cb" +down_revision = "88e49e681048" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint(op.f("uq_mirror_base_url"), "mirror", ["base_url"]) + op.add_column("test", sa.Column("mirror_url", sa.String(), nullable=True)) + op.create_foreign_key( + op.f("fk_test_mirror_url_mirror"), + "test", + "mirror", + ["mirror_url"], + ["base_url"], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint(op.f("fk_test_mirror_url_mirror"), "test", type_="foreignkey") + op.drop_column("test", "mirror_url") + op.drop_constraint(op.f("uq_mirror_base_url"), "mirror", type_="unique") + # ### end Alembic commands ### diff --git a/backend/src/mirrors_qa_backend/migrations/versions/17d587447299_remove_locations_table.py b/backend/src/mirrors_qa_backend/migrations/versions/17d587447299_remove_locations_table.py new file mode 100644 index 0000000..6f93d3f --- /dev/null +++ b/backend/src/mirrors_qa_backend/migrations/versions/17d587447299_remove_locations_table.py @@ -0,0 +1,46 @@ +"""remove locations table + +Revision ID: 17d587447299 +Revises: 1e455d030d80 +Create Date: 2024-07-17 04:39:53.486384 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "17d587447299" +down_revision = "1e455d030d80" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("location") + op.drop_constraint("fk_mirror_country_code_country", "mirror", type_="foreignkey") + op.drop_column("mirror", "country_code") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column( + "mirror", + sa.Column("country_code", sa.VARCHAR(), autoincrement=False, nullable=False), + ) + op.create_foreign_key( + "fk_mirror_country_code_country", + "mirror", + "country", + ["country_code"], + ["code"], + ) + op.create_table( + "location", + sa.Column("code", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint("code", name="pk_location"), + ) + # ### end Alembic commands ### diff --git a/backend/src/mirrors_qa_backend/migrations/versions/1e455d030d80_set_up_many_to_many_table_for_worker_.py b/backend/src/mirrors_qa_backend/migrations/versions/1e455d030d80_set_up_many_to_many_table_for_worker_.py new file mode 100644 index 0000000..897374a --- /dev/null +++ b/backend/src/mirrors_qa_backend/migrations/versions/1e455d030d80_set_up_many_to_many_table_for_worker_.py @@ -0,0 +1,54 @@ +"""set up many-to-many table for worker and country + +Revision ID: 1e455d030d80 +Revises: 40f7e2866f5d +Create Date: 2024-07-15 13:33:11.070559 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "1e455d030d80" +down_revision = "40f7e2866f5d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "worker_country", + sa.Column("worker_id", sa.String(), nullable=False), + sa.Column("country_code", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["country_code"], + ["country.code"], + name=op.f("fk_worker_country_country_code_country"), + ), + sa.ForeignKeyConstraint( + ["worker_id"], + ["worker.id"], + name=op.f("fk_worker_country_worker_id_worker"), + ), + sa.PrimaryKeyConstraint( + "worker_id", "country_code", name=op.f("pk_worker_country") + ), + ) + op.drop_constraint("fk_country_worker_id_worker", "country", type_="foreignkey") + op.drop_column("country", "worker_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "country", + sa.Column("worker_id", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "fk_country_worker_id_worker", "country", "worker", ["worker_id"], ["id"] + ) + op.drop_table("worker_country") + # ### end Alembic commands ### diff --git a/backend/src/mirrors_qa_backend/migrations/versions/40f7e2866f5d_add_location_for_tests.py b/backend/src/mirrors_qa_backend/migrations/versions/40f7e2866f5d_add_location_for_tests.py new file mode 100644 index 0000000..28e8f93 --- /dev/null +++ b/backend/src/mirrors_qa_backend/migrations/versions/40f7e2866f5d_add_location_for_tests.py @@ -0,0 +1,35 @@ +"""add location for tests + +Revision ID: 40f7e2866f5d +Revises: e1f384f65fcc +Create Date: 2024-07-08 11:50:01.976211 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "40f7e2866f5d" +down_revision = "e1f384f65fcc" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "location", + sa.Column("code", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("code", name=op.f("pk_location")), + ) + op.alter_column("test", "location", new_column_name="city") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("test", "city", new_column_name="location") + op.drop_table("location") + # ### end Alembic commands ### diff --git a/backend/src/mirrors_qa_backend/migrations/versions/e1f384f65fcc_unlink_country_code_from_test_to_country.py b/backend/src/mirrors_qa_backend/migrations/versions/e1f384f65fcc_unlink_country_code_from_test_to_country.py new file mode 100644 index 0000000..2ba274c --- /dev/null +++ b/backend/src/mirrors_qa_backend/migrations/versions/e1f384f65fcc_unlink_country_code_from_test_to_country.py @@ -0,0 +1,29 @@ +"""unlink country_code from test to country + +Revision ID: e1f384f65fcc +Revises: e6deb0a57a6a +Create Date: 2024-07-05 17:40:39.503348 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "e1f384f65fcc" +down_revision = "e6deb0a57a6a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("fk_test_country_code_country", "test", type_="foreignkey") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_foreign_key( + "fk_test_country_code_country", "test", "country", ["country_code"], ["code"] + ) + # ### end Alembic commands ### diff --git a/backend/src/mirrors_qa_backend/migrations/versions/e6deb0a57a6a_use_floating_point_number_for_speed_.py b/backend/src/mirrors_qa_backend/migrations/versions/e6deb0a57a6a_use_floating_point_number_for_speed_.py new file mode 100644 index 0000000..ba6eed8 --- /dev/null +++ b/backend/src/mirrors_qa_backend/migrations/versions/e6deb0a57a6a_use_floating_point_number_for_speed_.py @@ -0,0 +1,54 @@ +"""use floating-point number for speed metrics + +Revision ID: e6deb0a57a6a +Revises: 091da85970cb +Create Date: 2024-07-05 17:27:39.442279 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "e6deb0a57a6a" +down_revision = "091da85970cb" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "test", + "latency", + existing_type=sa.INTEGER(), + type_=sa.Float(), + existing_nullable=True, + ) + op.alter_column( + "test", + "duration", + existing_type=sa.INTEGER(), + type_=sa.Float(), + existing_nullable=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "test", + "duration", + existing_type=sa.Float(), + type_=sa.INTEGER(), + existing_nullable=True, + ) + op.alter_column( + "test", + "latency", + existing_type=sa.Float(), + type_=sa.INTEGER(), + existing_nullable=True, + ) + # ### end Alembic commands ### diff --git a/backend/src/mirrors_qa_backend/routes/auth.py b/backend/src/mirrors_qa_backend/routes/auth.py index b8e0950..26ae002 100644 --- a/backend/src/mirrors_qa_backend/routes/auth.py +++ b/backend/src/mirrors_qa_backend/routes/auth.py @@ -7,6 +7,7 @@ from mirrors_qa_backend import logger from mirrors_qa_backend.cryptography import verify_signed_message +from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError from mirrors_qa_backend.db.worker import get_worker from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError from mirrors_qa_backend.routes.dependencies import DbSession @@ -56,9 +57,10 @@ def authenticate_worker( ) # verify worker with worker_id exists in database - db_worker = get_worker(session, worker_id) - if db_worker is None: - raise UnauthorizedError() + try: + db_worker = get_worker(session, worker_id) + except RecordDoesNotExistError as exc: + raise UnauthorizedError() from exc # verify signature of message with worker's public keys try: diff --git a/backend/src/mirrors_qa_backend/routes/dependencies.py b/backend/src/mirrors_qa_backend/routes/dependencies.py index 963c811..eecf7a6 100644 --- a/backend/src/mirrors_qa_backend/routes/dependencies.py +++ b/backend/src/mirrors_qa_backend/routes/dependencies.py @@ -10,6 +10,7 @@ from mirrors_qa_backend import schemas from mirrors_qa_backend.db import gen_dbsession, models +from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError from mirrors_qa_backend.db.tests import get_test as db_get_test from mirrors_qa_backend.db.worker import get_worker from mirrors_qa_backend.routes.http_errors import NotFoundError, UnauthorizedError @@ -40,9 +41,10 @@ def get_current_worker( # At this point, we know 
that the JWT is all OK and we can # trust the data in it. We extract the worker_id from the claims - db_worker = get_worker(session, claims.subject) - if db_worker is None: - raise UnauthorizedError() + try: + db_worker = get_worker(session, claims.subject) + except RecordDoesNotExistError as exc: + raise UnauthorizedError() from exc return db_worker @@ -51,9 +53,10 @@ def get_current_worker( def get_test(session: DbSession, test_id: Annotated[UUID4, Path()]) -> models.Test: """Fetches the test specified in the request.""" - test = db_get_test(session, test_id) - if test is None: - raise NotFoundError(f"Test with id {test_id} does not exist.") + try: + test = db_get_test(session, test_id) + except RecordDoesNotExistError as exc: + raise NotFoundError(f"{exc!s}") from exc return test diff --git a/backend/src/mirrors_qa_backend/routes/tests.py b/backend/src/mirrors_qa_backend/routes/tests.py index a463368..fa47a64 100644 --- a/backend/src/mirrors_qa_backend/routes/tests.py +++ b/backend/src/mirrors_qa_backend/routes/tests.py @@ -4,8 +4,8 @@ from fastapi import status as status_codes from mirrors_qa_backend import schemas -from mirrors_qa_backend.db.tests import create_or_update_test from mirrors_qa_backend.db.tests import list_tests as db_list_tests +from mirrors_qa_backend.db.tests import update_test as update_test_model from mirrors_qa_backend.db.worker import update_worker_last_seen from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum from mirrors_qa_backend.routes.dependencies import ( @@ -88,19 +88,19 @@ def update_test( ) -> Test: data = update.model_dump(exclude_unset=True) body = schemas.UpdateTestModel().model_copy(update=data) - updated_test = create_or_update_test( + updated_test = update_test_model( session, test_id=test.id, - worker_id=current_worker.id, status=body.status, error=body.error, ip_address=body.ip_address, asn=body.asn, - country_code=body.country_code, - location=body.location, + city=body.city, 
latency=body.latency, download_size=body.download_size, duration=body.duration, + isp=body.isp, + started_on=body.started_on, speed=body.speed, ) update_worker_last_seen(session, current_worker) diff --git a/backend/src/mirrors_qa_backend/schemas.py b/backend/src/mirrors_qa_backend/schemas.py index 33583b5..19aefaf 100644 --- a/backend/src/mirrors_qa_backend/schemas.py +++ b/backend/src/mirrors_qa_backend/schemas.py @@ -12,11 +12,6 @@ class BaseModel(pydantic.BaseModel): model_config = ConfigDict(use_enum_values=True, from_attributes=True) -class Country(BaseModel): - code: str # two-letter country codes as defined in ISO 3166-1 - name: str # full name of country (in English) - - class Mirror(BaseModel): id: str # hostname of a mirror URL base_url: str @@ -30,7 +25,6 @@ class Mirror(BaseModel): region_only: bool | None = None as_only: bool | None = None other_countries: list[str] | None = None - country: Country class UpdateTestModel(BaseModel): @@ -39,11 +33,10 @@ class UpdateTestModel(BaseModel): isp: str | None = None ip_address: IPv4Address | None = None asn: str | None = None - country_code: str | None = None - location: str | None = None - latency: int | None = None + city: str | None = None + latency: float | None = None download_size: int | None = None - duration: int | None = None + duration: float | None = None speed: float | None = None status: StatusEnum = StatusEnum.PENDING @@ -51,6 +44,8 @@ class UpdateTestModel(BaseModel): class Test(UpdateTestModel): id: UUID4 requested_on: datetime.datetime + country_code: str | None = None # country to run the test from + mirror_url: str | None # base url of the mirror to run the test class Paginator(BaseModel): diff --git a/backend/src/mirrors_qa_backend/serializer.py b/backend/src/mirrors_qa_backend/serializer.py index fd4d907..621dfdf 100644 --- a/backend/src/mirrors_qa_backend/serializer.py +++ b/backend/src/mirrors_qa_backend/serializer.py @@ -13,11 +13,12 @@ def serialize_test(test: models.Test) -> 
schemas.Test: ip_address=test.ip_address, asn=test.asn, country_code=test.country_code, - location=test.location, + city=test.city, latency=test.latency, download_size=test.download_size, duration=test.duration, speed=test.speed, + mirror_url=test.mirror_url, ) @@ -35,5 +36,4 @@ def serialize_mirror(mirror: models.Mirror) -> schemas.Mirror: region_only=mirror.region_only, as_only=mirror.as_only, other_countries=mirror.other_countries, - country=schemas.Country(code=mirror.country.code, name=mirror.country.name), ) diff --git a/backend/tests/cli/__init__.py b/backend/tests/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/cli/test_worker.py b/backend/tests/cli/test_worker.py new file mode 100644 index 0000000..d495e9d --- /dev/null +++ b/backend/tests/cli/test_worker.py @@ -0,0 +1,24 @@ +from contextlib import nullcontext as does_not_raise +from typing import Any + +import pytest + +from mirrors_qa_backend.cli.worker import create_worker + + +@pytest.mark.parametrize( + ["worker_id", "country_codes", "expectation"], + [ + ("test", ["ng", "ca", "fr"], does_not_raise()), + ("test", ["invalid country code"], pytest.raises(ValueError)), + ("test", ["zz", "vv"], pytest.raises(ValueError)), + ], +) +def test_create_worker( + private_key_data: bytes, + worker_id: str, + country_codes: list[str], + expectation: Any, +): + with expectation: + create_worker(worker_id, private_key_data, country_codes) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 2ca6a5d..53bbed2 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -4,7 +4,6 @@ from typing import Any import paramiko -import pycountry import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -13,14 +12,15 @@ from faker.providers import DynamicProvider from sqlalchemy.orm import Session as OrmSession +from mirrors_qa_backend import schemas from 
mirrors_qa_backend.cryptography import sign_message from mirrors_qa_backend.db import Session -from mirrors_qa_backend.db.country import get_country_or_none -from mirrors_qa_backend.db.models import Base, Country, Test, Worker +from mirrors_qa_backend.db.models import Base, Mirror, Test, Worker from mirrors_qa_backend.enums import StatusEnum +from mirrors_qa_backend.serializer import serialize_mirror -@pytest.fixture +@pytest.fixture(autouse=True) def dbsession() -> Generator[OrmSession, None, None]: with Session.begin() as session: # Ensure we are starting with an empty database @@ -61,7 +61,10 @@ def data_gen(faker: Faker) -> Faker: @pytest.fixture def tests( - dbsession: OrmSession, data_gen: Faker, worker: Worker, request: Any + dbsession: OrmSession, + data_gen: Faker, + worker: Worker, + request: Any, ) -> list[Test]: """Adds tests to the database using the num_test mark.""" mark = request.node.get_closest_marker("num_tests") @@ -78,18 +81,7 @@ def tests( selected_country_code = ( country_code if country_code else data_gen.test_country_code() ) - if country := get_country_or_none(dbsession, selected_country_code): - test.country = country - else: - country = Country( - code=selected_country_code.lower(), - name=pycountry.countries.get( - alpha_2=selected_country_code - ).name, # pyright: ignore [reportOptionalMemberAccess] - ) - dbsession.add(country) - test.country = country - + test.country_code = selected_country_code test.worker = worker dbsession.add(test) @@ -143,3 +135,46 @@ def x_sshauth_signature(private_key: RSAPrivateKey, auth_message: str) -> str: """Sign a message using RSA private key and encode it in base64""" signature = sign_message(private_key, bytes(auth_message, encoding="ascii")) return base64.b64encode(signature).decode() + + +@pytest.fixture +def db_mirror(dbsession: OrmSession) -> Mirror: + mirror = Mirror( + id="mirror-sites-in.mblibrary.info", + base_url="https://mirror-sites-in.mblibrary.info/mirror-sites/download.kiwix.org/", + 
enabled=True, + region=None, + asn=None, + score=None, + latitude=None, + longitude=None, + country_only=None, + region_only=None, + as_only=None, + other_countries=None, + ) + dbsession.add(mirror) + return mirror + + +@pytest.fixture +def schema_mirror(db_mirror: Mirror) -> schemas.Mirror: + return serialize_mirror(db_mirror) + + +@pytest.fixture +def new_schema_mirror() -> schemas.Mirror: + return schemas.Mirror( + id="mirrors.dotsrc.org", + base_url="https://mirrors.dotsrc.org/kiwix/", + enabled=True, + region=None, + asn=None, + score=None, + latitude=None, + longitude=None, + country_only=None, + region_only=None, + as_only=None, + other_countries=None, + ) diff --git a/backend/tests/db/__init__.py b/backend/tests/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/db/test_mirrors.py b/backend/tests/db/test_mirrors.py index 62f138d..c40d489 100644 --- a/backend/tests/db/test_mirrors.py +++ b/backend/tests/db/test_mirrors.py @@ -9,53 +9,6 @@ from mirrors_qa_backend.serializer import serialize_mirror -@pytest.fixture(scope="session") -def db_mirror() -> models.Mirror: - mirror = models.Mirror( - id="mirror-sites-in.mblibrary.info", - base_url="https://mirror-sites-in.mblibrary.info/mirror-sites/download.kiwix.org/", - enabled=True, - region=None, - asn=None, - score=None, - latitude=None, - longitude=None, - country_only=None, - region_only=None, - as_only=None, - other_countries=None, - ) - mirror.country = models.Country(code="in", name="India") - return mirror - - -@pytest.fixture(scope="session") -def schema_mirror(db_mirror: models.Mirror) -> schemas.Mirror: - return serialize_mirror(db_mirror) - - -@pytest.fixture(scope="session") -def new_schema_mirror() -> schemas.Mirror: - return schemas.Mirror( - id="mirrors.dotsrc.org", - base_url="https://mirrors.dotsrc.org/kiwix/", - enabled=True, - region=None, - asn=None, - score=None, - latitude=None, - longitude=None, - country_only=None, - region_only=None, - as_only=None, - 
other_countries=None, - country=schemas.Country( - code="dk", - name="Denmark", - ), - ) - - def test_db_empty(dbsession: OrmSession): assert count_from_stmt(dbsession, select(models.Country)) == 0 @@ -76,10 +29,8 @@ def test_raises_empty_mirrors_error(dbsession: OrmSession): def test_register_new_mirror( dbsession: OrmSession, schema_mirror: schemas.Mirror, - db_mirror: models.Mirror, new_schema_mirror: schemas.Mirror, ): - dbsession.add(db_mirror) result = create_or_update_mirror_status( dbsession, [schema_mirror, new_schema_mirror] ) @@ -88,26 +39,19 @@ def test_register_new_mirror( def test_disable_old_mirror( dbsession: OrmSession, - db_mirror: models.Mirror, + db_mirror: models.Mirror, # noqa: ARG001 [pytest fixture that saves a mirror] new_schema_mirror: schemas.Mirror, ): - dbsession.add(db_mirror) result = create_or_update_mirror_status(dbsession, [new_schema_mirror]) assert result.nb_mirrors_disabled == 1 -def test_no_mirrors_disabled( - dbsession: OrmSession, db_mirror: models.Mirror, schema_mirror: schemas.Mirror -): - dbsession.add(db_mirror) +def test_no_mirrors_disabled(dbsession: OrmSession, schema_mirror: schemas.Mirror): result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_disabled == 0 -def test_no_mirrors_added( - dbsession: OrmSession, db_mirror: models.Mirror, schema_mirror: schemas.Mirror -): - dbsession.add(db_mirror) +def test_no_mirrors_added(dbsession: OrmSession, schema_mirror: schemas.Mirror): result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_added == 0 @@ -130,7 +74,6 @@ def test_re_enable_existing_mirror( as_only=None, other_countries=None, ) - db_mirror.country = models.Country(code="dk", name="Denmark") dbsession.add(db_mirror) # Update the status of the mirror diff --git a/backend/tests/db/test_tests.py b/backend/tests/db/test_tests.py index 4adfbc5..c0adfbe 100644 --- a/backend/tests/db/test_tests.py +++ b/backend/tests/db/test_tests.py @@ -1,4 +1,5 
@@ import datetime +import uuid from ipaddress import IPv4Address import pytest @@ -6,24 +7,48 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.db import models +from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError from mirrors_qa_backend.db.tests import ( - create_or_update_test, + create_test, expire_tests, filter_test, get_test, list_tests, + update_test, ) from mirrors_qa_backend.enums import StatusEnum +def test_test_does_not_exist(dbsession: OrmSession): + test_id = uuid.uuid4() + with pytest.raises(RecordDoesNotExistError): + get_test(dbsession, test_id) + + +def test_create_test( + dbsession: OrmSession, db_mirror: models.Mirror, worker: models.Worker +): + test_location = "ng" + test = create_test( + dbsession, + worker=worker, + country_code=test_location, + mirror=db_mirror, + ) + assert test.worker_id == worker.id + assert test.status == StatusEnum.PENDING + assert test.mirror_url == db_mirror.base_url + assert test.country_code == test_location + + @pytest.mark.num_tests(1) def test_get_test(dbsession: OrmSession, tests: list[models.Test]): test = tests[0] result = get_test(dbsession, test.id) - assert result is not None assert result.id == test.id +@pytest.mark.num_tests(1, status=StatusEnum.PENDING, country_code="us") @pytest.mark.parametrize( ["worker_id", "country_code", "statuses", "expected"], [ @@ -36,13 +61,13 @@ def test_get_test(dbsession: OrmSession, tests: list[models.Test]): ) def test_basic_filter( *, - dbsession: OrmSession, + tests: list[models.Test], worker_id: str | None, country_code: str | None, statuses: list[StatusEnum] | None, expected: bool, ): - test = create_or_update_test(dbsession, status=StatusEnum.PENDING) + test = tests[0] assert ( filter_test( test, worker_id=worker_id, country_code=country_code, statuses=statuses @@ -97,7 +122,7 @@ def test_update_test(dbsession: OrmSession, tests: list[models.Test], data_gen: "started_on": data_gen.date_time(datetime.UTC), "latency": latency, 
} - updated_test = create_or_update_test(dbsession, test_id, **update_values) # type: ignore + updated_test = update_test(dbsession, test_id, **update_values) # type: ignore for key, value in update_values.items(): if hasattr(updated_test, key): assert getattr(updated_test, key) == value diff --git a/backend/tests/db/test_worker.py b/backend/tests/db/test_worker.py index 8f52c8c..81843ae 100644 --- a/backend/tests/db/test_worker.py +++ b/backend/tests/db/test_worker.py @@ -1,8 +1,10 @@ +import pytest from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend.db.models import Country -from mirrors_qa_backend.db.worker import create_worker +from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError +from mirrors_qa_backend.db.models import Country, Worker +from mirrors_qa_backend.db.worker import create_worker, get_worker def test_create_worker(dbsession: OrmSession, private_key: RSAPrivateKey): @@ -24,3 +26,12 @@ def test_create_worker(dbsession: OrmSession, private_key: RSAPrivateKey): assert len(new_worker.countries) == len(countries) assert "BEGIN PUBLIC KEY" in new_worker.pubkey_pkcs8 assert "END PUBLIC KEY" in new_worker.pubkey_pkcs8 + + +def test_worker_does_not_exist(dbsession: OrmSession): + with pytest.raises(RecordDoesNotExistError): + get_worker(dbsession, "does not exist") + + +def test_get_worker(dbsession: OrmSession, worker: Worker): + assert get_worker(dbsession, worker.id).id == worker.id diff --git a/dev/README.md b/dev/README.md index 2a27504..2fb7850 100644 --- a/dev/README.md +++ b/dev/README.md @@ -5,8 +5,8 @@ This docker-compose configuration to be used **only** for development purpose. ### backend This container is a backend web server, linked to its database. -It provides a command `mirrors-qa-backend` to simplify tasks like updating of mirrors. -Run `mirrors-qa-backend --help` in the container to see the sub-commands and options. 
+It provides sub-commands to simplify tasks like updating of mirrors. +Run `mirrors-qa-backend --help` in the container to see the various sub-commands and options. ### postgresqldb @@ -64,8 +64,10 @@ docker compose --profile worker up --build The key name `id_rsa` is used as a bind mount in the compose file. - Assuming the backend service is up (`docker compose up backend`), create a worker and assign them a list of countries to test for. - If an existing worker is responsible for those countries, they are assigned to the new worker. In this example, we create a worker - named `test`. + If no countries are provided, all available countries in the DB will be assigned to the worker. You can update the countries using `mirrors-qa-backend update-worker`. + + In this example, we create a worker named `test` to test for mirrors in France, United States and Canada using the private key file + named `id_rsa`. ```sh docker exec -i mirrors-qa-backend mirrors-qa-backend create-worker --countries=us,fr,ca test < ./id_rsa ``` @@ -114,6 +116,7 @@ The `backend` code houses the `scheduler` and the `RESTful API`. 
The following e - `WIREGUARD_HEALTHCHECK_TIMEOUT_SECONDS` - `WIREGUARD_HEALTHCHECK_RETRIES` - `TASK_WORKER_IMAGE` +- `TEST_FILE_PATH`: location of file to run download speed test ## task-worker diff --git a/dev/docker-compose.yaml b/dev/docker-compose.yaml index 8cf6e34..9137b01 100644 --- a/dev/docker-compose.yaml +++ b/dev/docker-compose.yaml @@ -14,6 +14,8 @@ services: interval: 10s timeout: 5s retries: 3 + ports: + - 5432 backend: depends_on: postgresdb: @@ -35,16 +37,16 @@ services: retries: 3 scheduler: depends_on: - postgresdb: + backend: condition: service_healthy build: context: ../backend container_name: mirrors-qa-scheduler environment: POSTGRES_URI: postgresql+psycopg://mirrors_qa:mirrors_qa@postgresdb:5432/mirrors_qa - IDLE_WORKER_DURATION: 1d - SCHEDULER_SLEEP_DURATION: 1m - EXPIRE_TEST_DURATION: 2m + IDLE_WORKER_DURATION: 10m + SCHEDULER_SLEEP_DURATION: 30m + EXPIRE_TEST_DURATION: 1h DEBUG: true command: mirrors-qa-backend scheduler worker-manager: @@ -67,6 +69,7 @@ services: - BACKEND_API_URI=http://backend - SLEEP_DURATION=5m - TASK_WORKER_IMAGE=mirrors-qa-task-worker + - TEST_FILE_PATH=/zim/wikipedia/speedtest_en_blob-mini_2024-05.zim command: mirrors-qa-manager --verbose ${WORKER_ID} task-worker: build: diff --git a/worker/manager/pyproject.toml b/worker/manager/pyproject.toml index 6d39a37..78c2320 100644 --- a/worker/manager/pyproject.toml +++ b/worker/manager/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "cryptography==42.0.8", "paramiko==3.4.0", "humanfriendly==10.0", + "pycountry==24.6.1", ] license = {text = "GPL-3.0-or-later"} classifiers = [ diff --git a/worker/manager/src/mirrors_qa_manager/settings.py b/worker/manager/src/mirrors_qa_manager/settings.py index 8696f55..974d885 100644 --- a/worker/manager/src/mirrors_qa_manager/settings.py +++ b/worker/manager/src/mirrors_qa_manager/settings.py @@ -56,4 +56,8 @@ class Settings: WIREGUARD_HEALTHCHECK_RETRIES = int( getenv("WIREGUARD_HEALTHCHECK_RETRIES", default=3) ) - TASK_WORKER_IMAGE 
= getenv("TASK_WORKER_IMAGE", mandatory=True) + TASK_WORKER_IMAGE: str = getenv("TASK_WORKER_IMAGE", mandatory=True) + TEST_FILE_PATH: str = getenv( + "TEST_FILE_PATH", + default="/zim/wikipedia/speedtest_en_blob_2024-05.zim", + ) diff --git a/worker/manager/src/mirrors_qa_manager/worker.py b/worker/manager/src/mirrors_qa_manager/worker.py index c11c252..f88d620 100644 --- a/worker/manager/src/mirrors_qa_manager/worker.py +++ b/worker/manager/src/mirrors_qa_manager/worker.py @@ -1,5 +1,6 @@ # pyright: strict, reportMissingTypeStubs=false, reportUnknownMemberType=false, reportOptionalSubscript=false, reportUnknownVariableType=false, reportUnknownArgumentType=false import datetime +import json import shutil import signal import sys @@ -9,6 +10,7 @@ from typing import Any from urllib.parse import urlencode +import pycountry from docker.models.containers import Container from docker.types import Mount @@ -77,6 +79,7 @@ def __init__(self, worker_id: str) -> None: self.wg_up_cmd = ["wg-quick", "up", self.wg_interface] self.task_container_names = set() + # location of the test file from the mirror's root self.auth_credentials: None | AuthCredentials = None @@ -167,7 +170,11 @@ def start_wireguard_container(self, image_name: str) -> Container: ) def start_task_container( - self, image_name: str, container_name: str, output_filename: str + self, + image_name: str, + container_name: str, + output_filename: str, + test_file_url: str, ) -> Container: mounts = [ Mount("/data", str(self.get_host_fpath(self.instance_dir)), type="bind") @@ -181,7 +188,7 @@ def start_task_container( }, mounts=mounts, network_mode=f"container:{Settings.WIREGUARD_CONTAINER_NAME}", - command=["mirrors-qa-task", f"--output={output_filename}"], + command=["mirrors-qa-task", test_file_url, f"--output={output_filename}"], ) def query_api( @@ -191,15 +198,14 @@ def query_api( *, payload: dict[str, Any] | None = None, ) -> dict[str, Any]: - if self.auth_credentials is None: self.auth_credentials = 
authenticate(self.private_key, self.worker_id) - if self.auth_credentials.expires_in < datetime.datetime.now(): + if self.auth_credentials.expires_in <= datetime.datetime.now(): self.auth_credentials = authenticate(self.private_key, self.worker_id) req_headers = { - "Authorization": f"Bearer: {self.auth_credentials.access_token}", + "Authorization": f"Bearer {self.auth_credentials.access_token}", } return query_backend_api( endpoint, @@ -208,6 +214,16 @@ def query_api( payload=payload, ) + def merge_data( + self, *, ip_data: dict[str, Any], metrics_data: dict[str, Any] + ) -> dict[str, Any]: + return { + **metrics_data, + "ip_address": ip_data["ip"], + "city": ip_data["city"], + "isp": ip_data["organization"], + } + def fetch_tests(self) -> list[dict[str, Any]]: logger.debug("Fetching tasks from backend API") @@ -223,6 +239,10 @@ def sleep(self) -> None: logger.info(f"Sleeping for {Settings.SLEEP_SECONDS}s") time.sleep(Settings.SLEEP_SECONDS) + def get_country_code(self, country_name: str) -> str: + country: Any = pycountry.countries.search_fuzzy(country_name)[0] + return country.alpha_2.lower() + def run(self) -> None: logger.info("Starting worker manager.") # Start the wireguard network container @@ -250,7 +270,7 @@ def run(self) -> None: except FileNotFoundError: logger.error( f"Could not find {country_code}.conf for " - f"{test_id}. Skipping." + f"test {test_id}. Skipping test." ) continue except Exception: @@ -307,18 +327,12 @@ def run(self) -> None: # Perform another healthcheck to ensure traffic can go # through. - # TODO: Use the result from the healthcheck call to - # populate the IP-related data as the task container - # doesn't know anything about its network. - # Could also be used to validate that the country config - # is actually for this country in case the 'host' - # wrongly names a config file. 
logger.info( "Checking if traffic can pass through wireguard interface " f"for test {test_id}, country: {country_code}" ) try: - exec_command( + healthcheck_result = exec_command( self.docker, Settings.WIREGUARD_CONTAINER_NAME, self.wg_healthcheck_cmd, @@ -330,6 +344,20 @@ def run(self) -> None: ) continue + # Ensure the country that this IP belongs to is the same as the + # requested country code. + ip_data = json.loads(healthcheck_result.output.decode("utf-8")) + ip_country_code = self.get_country_code(ip_data["country"]) + + if ip_country_code != country_code: + logger.warning( + "Test expects configuration file for " + f"{country_code}, got {ip_country_code} from host. " + f"Skipping test {test_id} due to wrong " + "configuration file." + ) + continue + # Start container for the task task_container_name = f"task-worker-{test_id}" # It is possible that a container with the existing name already @@ -346,15 +374,21 @@ def run(self) -> None: logger.info( f"Starting container {task_container_name!r} for " - f"processing {test_id}" + f"processing test {test_id}" ) output_fpath = self.instance_dir / f"{test_id}.json" + test_file_url = ( + test["mirror_url"].rstrip("/") + + "/" + + Settings.TEST_FILE_PATH.lstrip("/") + ) try: self.task_container_names.add(task_container_name) self.start_task_container( Settings.TASK_WORKER_IMAGE, task_container_name, output_filename=output_fpath.name, + test_file_url=test_file_url, ) except Exception as exc: logger.error( @@ -373,12 +407,27 @@ def run(self) -> None: else: self.task_container_names.remove(task_container_name) - results = output_fpath.read_bytes() - logger.info(f"Got results from test {test_id}: {results}") - # TODO: Merge the IP data from the healthcheck cmd and the resutls - # an dupload to the Backend API - logger.info(f"Uploading results for {test_id}") - output_fpath.unlink() + results = output_fpath.read_text() + logger.info( + f"Successfully retrieved metrics results for test {test_id}" + ) + payload = 
self.merge_data( + ip_data=ip_data, + metrics_data=json.loads(results), + ) + logger.info(f"Uploading results for {test_id} to Backend API") + try: + self.query_api( + f"/tests/{test_id}", method="patch", payload=payload + ) + except Exception as exc: + logger.error( + f"error while uploading results to Backend API: {exc!s}" + ) + continue + finally: + output_fpath.unlink() + logger.info(f"Uploaded results for {test_id} to Backend API") except Exception as exc: logger.error(f"error while processing tasks {exc!s}") diff --git a/worker/task/pyproject.toml b/worker/task/pyproject.toml index d5f34c1..eba5b53 100644 --- a/worker/task/pyproject.toml +++ b/worker/task/pyproject.toml @@ -13,6 +13,7 @@ authors = [ keywords = ["mirrors"] dependencies = [ "requests==2.32.3", + "humanfriendly==10.0", ] license = {text = "GPL-3.0-or-later"} classifiers = [ diff --git a/worker/task/src/mirrors_qa_task/__init__.py b/worker/task/src/mirrors_qa_task/__init__.py index 23a49e1..00b7247 100644 --- a/worker/task/src/mirrors_qa_task/__init__.py +++ b/worker/task/src/mirrors_qa_task/__init__.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass from mirrors_qa_task.settings import Settings @@ -9,3 +10,14 @@ handler = logging.StreamHandler() handler.setFormatter(logging.Formatter("[%(asctime)s: %(levelname)s] %(message)s")) logger.addHandler(handler) + + +@dataclass +class Metrics: + started_on: str # ISO formatted datetime + status: str # SUCCEEDED|ERRORED + error: str | None # error reason + latency: float # average ping result to netloc of URL + download_size: int # number of bytes of downloaded file + duration: float # number of seconds to complete download + speed: float # bytes per second of the download diff --git a/worker/task/src/mirrors_qa_task/entrypoint.py b/worker/task/src/mirrors_qa_task/entrypoint.py index d386725..59c980b 100644 --- a/worker/task/src/mirrors_qa_task/entrypoint.py +++ b/worker/task/src/mirrors_qa_task/entrypoint.py @@ -1,12 +1,12 @@ import 
argparse +import json import logging -import sys - -import requests +from dataclasses import asdict from mirrors_qa_task import logger from mirrors_qa_task.__about__ import __version__ from mirrors_qa_task.settings import Settings +from mirrors_qa_task.worker import get_download_metrics def main(): @@ -14,22 +14,21 @@ def main(): parser.add_argument( "--output", "-O", help="Name of file to write results", default="output.json" ) - parser.add_argument("--version", help="Show version and exit", action="store_true") parser.add_argument( "-v", "--verbose", help="Show verbose output", action="store_true" ) + parser.add_argument( + "--version", + help="Show version and exit.", + action="version", + version="%(prog)s " + __version__, + ) + parser.add_argument("url", help="URL of speedtest file.") args = parser.parse_args() - if args.version: + if args.verbose: logger.setLevel(logging.DEBUG) - if args.version: - print(f"Mirrors QA Worker Task: {__version__}") # noqa: T201 - sys.exit(0) + metrics = asdict(get_download_metrics(args.url)) + (Settings.WORKDIR / args.output).write_text(json.dumps(metrics)) - resp = requests.get( - "https://am.i.mullvad.net/json", - timeout=Settings.REQUESTS_TIMEOUT_SECONDS, - ) - resp.raise_for_status() - (Settings.WORKDIR / args.output).write_bytes(resp.content) logger.info(f"Saved data to {args.output}") diff --git a/worker/task/src/mirrors_qa_task/settings.py b/worker/task/src/mirrors_qa_task/settings.py index e450cfe..daf11a3 100644 --- a/worker/task/src/mirrors_qa_task/settings.py +++ b/worker/task/src/mirrors_qa_task/settings.py @@ -2,6 +2,8 @@ from pathlib import Path from typing import Any +from humanfriendly import parse_size, parse_timespan + def getenv(key: str, *, mandatory: bool = False, default: Any = None) -> Any: value = os.getenv(key, default=default) @@ -13,8 +15,14 @@ def getenv(key: str, *, mandatory: bool = False, default: Any = None) -> Any: class Settings: - """Worker task configuration""" + """Task worker 
def fmt(size: float) -> str:
    """Format *size*, a number of bytes, as a human-readable binary size."""
    return humanfriendly.format_size(size, binary=True)


def get_download_metrics(
    url: str,
    *,
    timeout: int = Settings.REQUESTS_TIMEOUT_SECONDS,
    chunk_size: int = Settings.CHUNK_SIZE,
    retries: int = Settings.REQUESTS_MAX_RETRIES,
    interval: float = Settings.REQUESTS_RETRY_SECONDS,
) -> Metrics:
    """Download *url* and measure how the download went.

    The body is streamed chunk by chunk and discarded; only the measurements
    are kept. When ``requests`` reports an error the download is attempted
    again, up to *retries* additional times, sleeping
    ``interval * attempt_number`` seconds between two attempts.

    Args:
        url: URL of the file to download.
        timeout: timeout, in seconds, handed to ``requests.get``.
        chunk_size: number of bytes requested per streamed chunk.
        retries: number of additional attempts made after a failed one.
        interval: base number of seconds to wait between two attempts.

    Returns:
        A ``Metrics`` with status ``SUCCEEDED`` for the first attempt that
        completes; otherwise a ``Metrics`` with status ``ERRORED`` carrying
        the message of the last error and zeroed measurements.
    """
    url_parts = urlsplit(url)
    filename = Path(url_parts.path).name
    logger.info(f"Downloading {filename} from {url_parts.netloc}")

    error_message = None
    max_attempts = retries + 1  # the initial attempt plus the retries
    for attempt in range(1, max_attempts + 1):
        started_on = datetime.datetime.now(datetime.UTC).isoformat()
        try:
            start = time.perf_counter()
            # The context manager releases the streamed connection even when
            # the status check or the body transfer fails.
            with requests.get(
                url=url,
                timeout=timeout,
                stream=True,
                headers={"User-Agent": Settings.USER_AGENT},
            ) as resp:
                resp.raise_for_status()

                latency = resp.elapsed.total_seconds()
                # 0 means the server did not announce the size of the file.
                expected_size = int(resp.headers.get("Content-Length", "0"))
                downloaded = 0

                for data in resp.iter_content(chunk_size=chunk_size):
                    downloaded += len(data)
                    if expected_size > 0:
                        percentage = downloaded * 100 / expected_size
                        progress = (
                            f"{percentage:.2f}%: "
                            f"{fmt(downloaded)}/{fmt(expected_size)}"
                        )
                    else:
                        # Without an announced total no percentage can be
                        # computed; dividing by it would raise.
                        progress = fmt(downloaded)
                    print(  # noqa: T201
                        f"\r{filename}: {progress}",
                        flush=True,
                        end="",
                    )
                print("\n")  # noqa: T201

                duration = time.perf_counter() - start

            # Fall back to the bytes actually received when the server did
            # not send a Content-Length header.
            size = expected_size if expected_size > 0 else downloaded
            speed = size / duration if duration > 0 else 0.0
            logger.info(
                f"Downloaded {fmt(size)} in {duration:.2f} seconds, "
                f"speed: {fmt(speed)}/s"
            )

            return Metrics(
                started_on=started_on,
                status="SUCCEEDED",
                error=None,
                latency=latency,
                download_size=size,
                duration=duration,
                speed=speed,
            )
        except RequestException as exc:
            error_message = str(exc)
            logger.warning(
                "error while getting download metrics "
                f"(attempt: {attempt}): {exc!s}"
            )
            # Waiting is only useful when another attempt follows.
            if attempt < max_attempts:
                time.sleep(interval * attempt)

    return Metrics(
        started_on=datetime.datetime.now(datetime.UTC).isoformat(),
        status="ERRORED",
        error=error_message,
        latency=0.0,
        download_size=0,
        duration=0.0,
        speed=0.0,
    )