add sample size to train test split
truonghm committed Sep 12, 2023
1 parent 082d649 commit 9eec7c8
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions scripts/split_train_test.py
@@ -6,16 +6,7 @@
 from dask_ml.model_selection import train_test_split
 
 
-def read_parquet_and_split(parquet_dir_paths):
-    """
-    Reads multiple parquet directories into a Dask DataFrame and splits it into train and test sets.
-
-    Parameters:
-    - parquet_dir_paths: list of str, paths to the parquet directories
-
-    Returns:
-    - X_train, X_test: Dask DataFrames, the training and test sets
-    """
+def read_parquet_and_split(parquet_dir_paths, sample_size):
     ddfs = []
     for path in parquet_dir_paths:
         with ProgressBar():
@@ -25,8 +16,18 @@ def read_parquet_and_split(parquet_dir_paths):
     # Combine all the Dask DataFrames
     combined_ddf = dd.concat(ddfs)
 
-    # Split the data into training and test sets
-    X_train, X_test = train_test_split(combined_ddf, test_size=0.2, shuffle=True, random_state=42)
+    # Apply stratified sampling if sample_size is less than 1
+    if sample_size < 1.0:
+        total_length = len(combined_ddf)
+        sample_length = int(total_length * sample_size)
+        # Assuming the label column is named 'label'
+        combined_ddf = combined_ddf.sample(frac=sample_size, random_state=42, replace=False).compute()
+        combined_ddf = dd.from_pandas(combined_ddf, npartitions=combined_ddf.npartitions)
+
+    # Split the data into training and test sets (Assuming the label column is named 'label')
+    X_train, X_test = train_test_split(
+        combined_ddf, test_size=0.2, shuffle=True, random_state=42, stratify=combined_ddf["label"]
+    )
 
     return X_train, X_test

@@ -40,12 +41,15 @@ def list_of_strings(arg):
 parser.add_argument("-i", "--inputs", help="Parquet directories to read", required=True, type=list_of_strings)
 parser.add_argument("-o", "--output", help="Output directory", required=True)
 parser.add_argument("-p", "--prefix", help="Prefix for the output files", required=True)
+parser.add_argument("-ss", "--sample-size", help="Sample size", required=False, type=float, default=1.0)
 
 args = parser.parse_args()
 parquet_dir_paths = args.inputs
 output_dir = args.output
 prefix = args.prefix
+sample_size = args.sample_size
 
-X_train, X_test = read_parquet_and_split(parquet_dir_paths)
+X_train, X_test = read_parquet_and_split(parquet_dir_paths, sample_size)
 
 with ProgressBar():
     X_train.repartition(npartitions=1).to_parquet(
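
Note: in the sampling block above, .compute() returns a plain pandas DataFrame, and pandas DataFrames have no npartitions attribute, so the committed line dd.from_pandas(combined_ddf, npartitions=combined_ddf.npartitions) raises AttributeError whenever sample_size < 1.0. The total_length and sample_length variables are also assigned but never used, and len(combined_ddf) forces an extra pass over the data. A minimal corrected sketch of that block, keeping the commit's names and intent:

    # Down-sample when sample_size < 1.0 (sketch, not the committed code)
    if sample_size < 1.0:
        # Read the partition count before .compute() materializes to pandas
        npartitions = combined_ddf.npartitions
        sampled_pdf = combined_ddf.sample(frac=sample_size, random_state=42, replace=False).compute()
        combined_ddf = dd.from_pandas(sampled_pdf, npartitions=npartitions)

Since Dask's sample(frac=...) is itself lazy, the compute()/from_pandas round trip could also be dropped entirely, keeping combined_ddf = combined_ddf.sample(frac=sample_size, random_state=42) as an out-of-core Dask DataFrame.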
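Note also that the comment says "stratified sampling", but DataFrame.sample(frac=...) draws uniformly at random, so class proportions are preserved only in expectation. If the sample must preserve per-class proportions exactly, one option is a per-label sample through pandas; a sketch assuming, as the commit does, a label column named 'label' and a combined frame small enough to materialize:

    import dask.dataframe as dd

    def stratified_sample(combined_ddf, sample_size, seed=42):
        """Hypothetical helper (not in the commit): draw the same fraction from each class."""
        npartitions = combined_ddf.npartitions
        sampled_pdf = (
            combined_ddf.compute()  # materialize to pandas
            .groupby("label")  # one group per label class
            .sample(frac=sample_size, random_state=seed)  # pandas >= 1.1
        )
        return dd.from_pandas(sampled_pdf, npartitions=npartitions)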

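For reference, a hypothetical invocation after this change (directory names are placeholders, and list_of_strings presumably splits the comma-separated --inputs value):

    python scripts/split_train_test.py \
        --inputs parquet_dir_1,parquet_dir_2 \
        --output output_dir \
        --prefix split \
        --sample-size 0.1

With --sample-size 0.1, roughly 10% of the combined rows are kept before the 80/20 train/test split; the default of 1.0 skips sampling entirely.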