Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract tabular dataset loading #274

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 22 additions & 63 deletions podium/datasets/arrow_tabular_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import csv
import itertools
import os
import pickle
Expand All @@ -11,7 +10,8 @@
from podium.field import Field, unpack_fields

from .dataset import Dataset, DatasetBase
from .example_factory import Example, ExampleFactory
from .example_factory import Example
from .tabular_dataset import load_tabular_file


try:
Expand Down Expand Up @@ -205,9 +205,9 @@ def from_tabular_file(
cache_path: str = None,
data_types: Dict[str, Tuple[pa.DataType, pa.DataType]] = None,
chunk_size=10_000,
line2example=None,
skip_header: bool = False,
delimiter=None,
csv_reader_params: Dict = None,
csv_reader_params: Dict = {},
mttk marked this conversation as resolved.
Show resolved Hide resolved
) -> "ArrowDataset":
"""
Loads a tabular file format (csv, tsv, json) as an ArrowDataset.
Expand Down Expand Up @@ -256,15 +256,18 @@ def from_tabular_file(
Maximum number of examples to be loaded before dumping to the on-disk cache
file. Use lower number if memory usage is an issue while loading.

line2example : callable
The function mapping from a file line to Fields.
In case your dataset is not in one of the standardized formats,
you can provide a function which performs a custom split for
each input line.

skip_header : bool
Whether to skip the first line of the input file.
If format is CSV/TSV and 'fields' is a dict, then skip_header
must be False and the data file must have a header.
Default is False.
delimiter: str
Delimiter used to separate columns in a row.
If set to None, the default delimiter for the given format will
be used.

csv_reader_params : Dict
Parameters to pass to the csv reader. Only relevant when
format is csv or tsv.
Expand All @@ -276,62 +279,18 @@ def from_tabular_file(
ArrowDataset
ArrowDataset instance containing the examples from the tabular file.
"""
format = format.lower()
csv_reader_params = {} if csv_reader_params is None else csv_reader_params

with open(os.path.expanduser(path), encoding="utf8") as f:
if format in {"csv", "tsv"}:
delimiter = "," if format == "csv" else "\t"
reader = csv.reader(f, delimiter=delimiter, **csv_reader_params)
elif format == "json":
reader = f
else:
raise ValueError(f"Invalid format: {format}")

if skip_header:
if format == "json":
raise ValueError(
f"When using a {format} file, skip_header must be False."
)
elif format in {"csv", "tsv"} and isinstance(fields, dict):
raise ValueError(
f"When using a dict to specify fields with a {format} "
"file, skip_header must be False and the file must "
"have a header."
)

# skipping the header
next(reader)

# if format is CSV/TSV and fields is a dict, transform it to a list
if format in {"csv", "tsv"} and isinstance(fields, dict):
# we need a header to know the column names
header = next(reader)

# columns not present in the fields dict are ignored (None)
fields = [fields.get(column, None) for column in header]

# fields argument is the same for all examples
# fromlist is used for CSV/TSV because csv_reader yields data rows as
# lists, not strings
example_factory = ExampleFactory(fields)
make_example_function = {
"json": example_factory.from_json,
"csv": example_factory.from_list,
"tsv": example_factory.from_list,
}

make_example = make_example_function[format]

# map each line from the reader to an example
example_iterator = map(make_example, reader)
return ArrowDataset.from_examples(
fields,
example_iterator,
cache_path=cache_path,
data_types=data_types,
chunk_size=chunk_size,
)
example_generator = load_tabular_file(
path, fields, format, line2example, skip_header, csv_reader_params
)

return ArrowDataset.from_examples(
fields,
example_generator,
cache_path=cache_path,
data_types=data_types,
chunk_size=chunk_size,
)

@staticmethod
def _schema_to_data_types(
Expand Down
96 changes: 57 additions & 39 deletions podium/datasets/tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,48 +83,66 @@ def __init__(
If format is "JSON" and skip_header is True.
"""

format = format.lower()

with open(os.path.expanduser(path), encoding="utf8") as f:

# Skip header prior to custom line2example in case
# the header is in a different format so we don't
# cause an error.
if skip_header:
if format == "json":
raise ValueError(
f"When using a {format} file, skip_header \
must be False."
)
elif format in {"csv", "tsv", "custom"} and isinstance(fields, dict):
raise ValueError(
f"When using a dict to specify fields with a {format} "
"file, skip_header must be False and the file must "
"have a header."
)

# skip the header
next(f)

if line2example is not None:
reader = (line2example(line) for line in f)
format = "custom"
elif format in {"csv", "tsv"}:
delimiter = "," if format == "csv" else "\t"
reader = csv.reader(f, delimiter=delimiter, **csv_reader_params)
elif format == "json":
reader = f
else:
raise ValueError(f"Invalid format: {format}")

# create a list of examples
examples = create_examples(reader, format, fields)

# create a Dataset with lists of examples and fields
examples = load_tabular_file(
path, fields, format, line2example, skip_header, csv_reader_params
)
# Make the examples concrete here by casting to list
examples = list(examples)

super(TabularDataset, self).__init__(examples, fields, **kwargs)
self.finalize_fields()


def load_tabular_file(path, fields, format, line2example, skip_header, csv_reader_params):
    """
    Lazily loads a tabular file (csv, tsv, json, or a custom line format)
    and yields the Examples created from its rows.

    Parameters
    ----------
    path : str
        Path to the tabular file. A leading ``~`` is expanded.
    fields : dict(str, Field) or list(Field)
        Fields used to construct an Example from each row.
    format : str
        File format: "csv", "tsv" or "json" (case-insensitive). Ignored
        when ``line2example`` is given (the format becomes "custom").
    line2example : callable or None
        Optional custom mapping from a raw file line to a list of values,
        for datasets not in one of the standardized formats.
    skip_header : bool
        Whether to skip the first line of the file. Must be False for json,
        and for csv/tsv/custom when ``fields`` is a dict.
    csv_reader_params : dict
        Extra keyword arguments forwarded to ``csv.reader``. Only relevant
        for the csv/tsv formats.

    Yields
    ------
    Example
        One Example per data row of the file.

    Raises
    ------
    ValueError
        If the format is invalid or ``skip_header`` is used with an
        incompatible format/fields combination.
    """
    # A custom line2example overrides the declared format. The effective
    # format must be decided HERE: initialize_tabular_reader reassigns its
    # own local `format`, which is never visible to create_examples, so
    # without this line custom rows would be parsed as the original format.
    if line2example is not None:
        format = "custom"

    with open(os.path.expanduser(path), encoding="utf8") as f:
        reader = initialize_tabular_reader(
            f, format, fields, line2example, skip_header, csv_reader_params
        )
        # create_examples returns a lazy iterator; yielding from inside the
        # `with` keeps the file open for the lifetime of the generator.
        yield from create_examples(reader, format, fields)


def initialize_tabular_reader(
file, format, fields, line2example, skip_header, csv_reader_params
):

format = format.lower()

# Skip header prior to custom line2example in case
# the header is in a different format so we don't
# cause an error.
if skip_header:
if format == "json":
raise ValueError(
f"When using a {format} file, skip_header \
must be False."
)
elif format in {"csv", "tsv", "custom"} and isinstance(fields, dict):
raise ValueError(
f"When using a dict to specify fields with a {format} "
"file, skip_header must be False and the file must "
"have a header."
)

# skip the header
next(file)

if line2example is not None:
reader = (line2example(line) for line in file)
format = "custom"
mttk marked this conversation as resolved.
Show resolved Hide resolved
elif format in {"csv", "tsv"}:
delimiter = "," if format == "csv" else "\t"
reader = csv.reader(file, delimiter=delimiter, **csv_reader_params)
elif format == "json":
reader = file
else:
raise ValueError(f"Invalid format: {format}")

return reader


def create_examples(reader, format, fields):
mttk marked this conversation as resolved.
Show resolved Hide resolved
"""
Creates a list of examples from the given line reader and fields (see
Expand Down Expand Up @@ -178,4 +196,4 @@ def create_examples(reader, format, fields):
# map each line from the reader to an example
examples = map(make_example, reader)

return list(examples)
return examples