From 3f2707afc74d98972485fcf5e0ecc22562c4181d Mon Sep 17 00:00:00 2001
From: Fabian-Robert Stöter
Date: Wed, 13 Jan 2021 09:13:11 +0100
Subject: [PATCH] Add support for multilabel-classification datasets #11 (#30)

* clean up io, introduce utils

* add support for multi-label classification datasets, rename data generator
---
 README.md                    | 16 +++++-----
 gbif_dl/dataloaders/torch.py |  2 +-
 gbif_dl/generators/api.py    | 19 +++++++----
 gbif_dl/generators/dwca.py   | 14 +++++---
 gbif_dl/io.py                | 62 +++++++++++-------------------------
 gbif_dl/utils.py             | 55 ++++++++++++++++++++++++++++++++
 6 files changed, 106 insertions(+), 62 deletions(-)
 create mode 100644 gbif_dl/utils.py

diff --git a/README.md b/README.md
index b095d04..930db2e 100644
--- a/README.md
+++ b/README.md
@@ -57,7 +57,7 @@ generator:
 
 ```python
 import gbif_dl
-url_generator = gbif_dl.api.generate_urls(
+data_generator = gbif_dl.api.generate_urls(
     queries=queries,
     label="speciesKey",
 )
@@ -69,7 +69,7 @@ necessarily have to be part of the query attributes. The `label` is later used t
 Iterating over the generator now yields the media data returning a few thousand urls.
 
 ```python
-for i in url_generator:
+for i in data_generator:
     print(i)
 ```
 
@@ -103,7 +103,7 @@ Very often users won't be using all media downloads from a given query since thi
 In the following example, we will receive a balanced dataset assembled from `3 species * 2 datasets = 6 streams` and only get the minimum number of total samples from all 6 streams:
 
 ```python
-url_generator = gbif_dl.api.generate_urls(
+data_generator = gbif_dl.api.generate_urls(
     queries=queries,
     label="speciesKey",
     nb_samples=-1,
@@ -119,7 +119,7 @@ For other, more advanced, use-cases users can add more constraints:
 The following dataset consists of exactly 1000 samples for which the distribution of `speciesKey` is maintained from the full query of all samples. Furthermore, we only allow a maximum of 800 samples per species.
 
 ```python
-url_generator = gbifmediads.api.generate_urls(
+data_generator = gbif_dl.api.generate_urls(
     queries=queries,
     label="speciesKey",
     nb_samples=1000,
@@ -135,10 +135,10 @@ A url generator can also be created from a GBIF download link given a registered
 * `dwca_root_path`: Set root path where to store the DWCA zip files. Defaults to None,
 which results in the creation of a temporary directory. If the path and DWCA archive already exist, it will not be downloaded again.
 
-The following example creates a url_generator with the the same output class label as in the example above.
+The following example creates a data_generator with the same output class label as in the example above.
 
 ```python
-url_generator = gbif_dl.dwca.generate_urls(
+data_generator = gbif_dl.dwca.generate_urls(
     "10.15468/dl.vnm42s", dwca_root_path="dwcas", label="speciesKey"
 )
 ```
@@ -147,7 +147,7 @@
 Downloading from a url generator can simply be done by running.
 
 ```python
-gbif_dl.io.download(url_generator, root="my_dataset")
+gbif_dl.io.download(data_generator, root="my_dataset")
 ```
 
 The downloader provides very fast download speeds by using an async queue. Some fail-safe functionality is provided by setting the number of `retries`, default to 3.
@@ -160,7 +160,7 @@ The downloader provides very fast download speeds by using an async queue. Some
 
 ```python
 from gbif_dl.dataloaders.torch import GBIFImageDataset
-dataset = GBIFImageDataset(root='my_dataset', generator=url_generator, download=True)
+dataset = GBIFImageDataset(root='my_dataset', generator=data_generator, download=True)
 ```
 
 > ⚠️ Note that we do not provide train/validation/test splits of the dataset as this would be more useful to design specifically to the downstream task.
diff --git a/gbif_dl/dataloaders/torch.py b/gbif_dl/dataloaders/torch.py
index 92605ed..3455f03 100644
--- a/gbif_dl/dataloaders/torch.py
+++ b/gbif_dl/dataloaders/torch.py
@@ -10,7 +10,7 @@
 
 
 class GBIFImageDataset(torchvision.datasets.ImageFolder):
-    """GBIF Image Dataset
+    """GBIF Image Dataset for multi-class classification
 
     Args:
         root (str):
diff --git a/gbif_dl/generators/api.py b/gbif_dl/generators/api.py
index cab654f..7871c15 100644
--- a/gbif_dl/generators/api.py
+++ b/gbif_dl/generators/api.py
@@ -18,7 +18,7 @@ def gbif_query_generator(
     page_limit: int = 300,
     mediatype: str = 'StillImage',
-    label: str = 'speciesKey',
+    label: Optional[str] = None,
     *args, **kwargs
 ) -> MediaData:
     """Performs media queries GBIF yielding url and label
 
@@ -26,7 +26,9 @@ def gbif_query_generator(
     Args:
         page_limit (int, optional): GBIF api uses paging which can be modified. Defaults to 300.
         mediatype (str, optional): Sets GBIF mediatype. Defaults to 'StillImage'.
-        label (str, optional): Sets label. Defaults to 'speciesKey'.
+        label (str, optional): Output label name.
+            Defaults to `None`, which yields all metadata.
+
 
     Returns:
         str: [description]
@@ -57,10 +59,15 @@ def gbif_query_generator(
             media['identifier'].encode('utf-8')
         ).hexdigest()
 
+        if label is not None:
+            output_label = str(metadata.get(label))
+        else:
+            output_label = metadata
+
         yield {
             "url": media['identifier'],
             "basename": hashed_url,
-            "label": str(metadata.get(label))
+            "label": output_label,
         }
 
         if resp['endOfRecords']:
@@ -97,7 +104,7 @@ def dproduct(dicts):
 
 def generate_urls(
     queries: Dict,
-    label: str = "speciesKey",
+    label: Optional[str] = None,
     split_streams_by: Optional[Union[str, List]] = None,
     nb_samples_per_stream: Optional[int] = None,
     nb_samples: Optional[int] = None,
@@ -111,8 +118,8 @@ def generate_urls(
 
     Args:
         queries (Dict): dictionary of queries supported by the GBIF api
-        label (str, optional): label identfier, according to query api.
-            Defaults to "speciesKey".
+        label (str, optional): Output label name.
+            Defaults to `None`, which yields all metadata.
         nb_samples (int): Limit the total number of samples retrieved from the API.
            When set to -1 and `split_streams_by` is not `None`,
diff --git a/gbif_dl/generators/dwca.py b/gbif_dl/generators/dwca.py
index 0e5d64a..29228e7 100644
--- a/gbif_dl/generators/dwca.py
+++ b/gbif_dl/generators/dwca.py
@@ -58,10 +58,15 @@ def dwca_generator(
             url.encode('utf-8')
         ).hexdigest()
 
+        if label is not None:
+            output_label = str(row.data.get(gbifqualname + label))
+        else:
+            output_label = row.data
+
         yield {
             "url": url,
             "basename": hashed_url,
-            "label": str(row.data.get(gbifqualname + label))
+            "label": output_label,
         }
 
     if delete:
@@ -109,8 +114,8 @@ def _is_doi(identifier: str) -> bool:
 
 def generate_urls(
     identifier: str,
     dwca_root_path=None,
-    label: Optional[str] = "speciesKey",
-    mediatype: Optional[str] = "StillImage",
+    label: Optional[str] = None,
+    mediatype: Optional[str] = "StillImage",
     delete: Optional[bool] = False
 ):
@@ -120,7 +125,8 @@ def generate_urls(
         dwca_root_path (str, optional): Set root path where to store
             Darwin Core zip files. Defaults to None, which results in
            the creation of temporary directories
-        label (str): output label
+        label (str, optional): Output label name.
+            Defaults to `None`, which yields all metadata.
         mediatype (str, optional): Sets GBIF mediatype. Defaults to 'StillImage'.
            the creation of temporary directories.
         delete (bool, optional): Delete darwin core archive when finished.
diff --git a/gbif_dl/io.py b/gbif_dl/io.py
index 5bbf601..c7738cb 100644
--- a/gbif_dl/io.py
+++ b/gbif_dl/io.py
@@ -4,6 +4,11 @@
 from pathlib import Path
 from typing import AsyncGenerator, Callable, Generator, Union, Optional
 import sys
+import json
+import functools
+
+from attr import dataclass
+
 
 if sys.version_info >= (3, 8):
     from typing import TypedDict  # pylint: disable=no-name-in-module
@@ -18,7 +23,7 @@
 import aiostream
 from aiohttp_retry import RetryClient, ExponentialRetry
 from tqdm.asyncio import tqdm
-
+from .utils import watchdog, run_async
 
 class MediaData(TypedDict):
     """ Media dict representation received from api or dwca generators"""
@@ -54,8 +59,13 @@ async def download_single(
     """
     url = item['url']
 
-    # check for path
-    label_path = Path(root, item['label'])
+    # create subfolder when label is a single str
+    if isinstance(item['label'], str):
+        label_path = Path(root, item['label'])
+    # otherwise make it a flat file hierarchy
+    else:
+        label_path = Path(root)
+
     label_path.mkdir(parents=True, exist_ok=True)
 
     check_files_with_same_basename = label_path.glob(item['basename'] + "*")
@@ -91,6 +101,10 @@ async def download_single(
     async with aiofiles.open(file_path, "+wb") as f:
         await f.write(content)
 
+    if isinstance(item['label'], dict):
+        json_path = (label_path / item['basename']).with_suffix('.json')
+        async with aiofiles.open(json_path, mode='+w') as fp:
+            await fp.write(json.dumps(item['label']))
 
 async def download_queue(
     queue: asyncio.Queue,
@@ -133,8 +147,8 @@ async def download_queue(
 async def download_from_asyncgen(
     items: AsyncGenerator,
     root: str = "data",
-    tcp_connections: int = 256,
-    nb_workers: int = 256,
+    tcp_connections: int = 64,
+    nb_workers: int = 64,
     batch_size: int = 16,
     retries: int = 3,
     verbose: bool = False,
@@ -213,44 +227,6 @@ async def download_from_asyncgen(
             w.cancel()
 
 
-def get_or_create_eventloop():
-    try:
-        return asyncio.get_event_loop()
-    except RuntimeError as ex:
-        if "There is no current event loop in thread" in str(ex):
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-            return asyncio.get_event_loop()
-
-class RunThread(threading.Thread):
-    def __init__(self, func, args, kwargs):
-        self.func = func
-        self.args = args
-        self.kwargs = kwargs
-        super().__init__()
-
-    def run(self):
-        self.result = asyncio.run(self.func(*self.args, **self.kwargs))
-
-
-def run_async(func, *args, **kwargs):
-    """async wrapper to detect if asyncio loop is already running
-
-    This is useful when already running in async thread.
-    """
-    try:
-        loop = get_or_create_eventloop()
-    except RuntimeError:
-        loop = None
-    if loop and loop.is_running():
-        thread = RunThread(func, args, kwargs)
-        thread.start()
-        thread.join()
-        return thread.result
-    else:
-        return asyncio.run(func(*args, **kwargs))
-
-
 def download(
     items: Union[Generator, AsyncGenerator, Iterable],
     root: str = "data",
diff --git a/gbif_dl/utils.py b/gbif_dl/utils.py
new file mode 100644
index 0000000..c470c75
--- /dev/null
+++ b/gbif_dl/utils.py
@@ -0,0 +1,55 @@
+import asyncio
+import functools
+import threading
+
+def watchdog(afunc):
+    """Stops all tasks if there is an error"""
+    @functools.wraps(afunc)
+    async def run(*args, **kwargs):
+        try:
+            await afunc(*args, **kwargs)
+        except asyncio.CancelledError:
+            return
+        except Exception as err:
+            print(f'exception {err}')
+            asyncio.get_event_loop().stop()
+    return run
+
+
+def get_or_create_eventloop():
+    try:
+        return asyncio.get_event_loop()
+    except RuntimeError as ex:
+        if "There is no current event loop in thread" in str(ex):
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            return asyncio.get_event_loop()
+
+
+class RunThread(threading.Thread):
+    def __init__(self, func, args, kwargs):
+        self.func = func
+        self.args = args
+        self.kwargs = kwargs
+        super().__init__()
+
+    def run(self):
+        self.result = asyncio.run(self.func(*self.args, **self.kwargs))
+
+
+def run_async(func, *args, **kwargs):
+    """Async wrapper to detect if an asyncio loop is already running.
+
+    This is useful when already running in an async thread.
+    """
+    try:
+        loop = get_or_create_eventloop()
+    except RuntimeError:
+        loop = None
+    if loop and loop.is_running():
+        thread = RunThread(func, args, kwargs)
+        thread.start()
+        thread.join()
+        return thread.result
+    else:
+        return asyncio.run(func(*args, **kwargs))
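
Below is a minimal usage sketch of the behaviour this patch enables, following the README examples above. It is illustrative only: the `speciesKey` values are placeholders, and the output layout described in the comments is based on the changes to `gbif_dl/io.py` (flat files plus a JSON metadata sidecar whenever the label is not a single string).

```python
import gbif_dl

# Illustrative query; any attribute combination supported by the GBIF occurrence API works here.
queries = {
    "speciesKey": [5352251, 3189866],  # placeholder species keys
}

# With the new default label=None the generator yields the full occurrence metadata,
# e.g. {"url": ..., "basename": "<sha1 of url>", "label": {...all occurrence fields...}}
data_generator = gbif_dl.api.generate_urls(queries=queries, label=None)

# Because the label is a dict rather than a string, images are written flat into `root`
# and each image gets a "<basename>.json" sidecar holding its metadata, which can later
# be parsed into multi-label targets.
gbif_dl.io.download(data_generator, root="my_multilabel_dataset")
```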