Commit
Showing 15 changed files with 176 additions and 74 deletions.
3 files renamed without changes.
Modified: the train/test split script. The dask-based parquet splitter is replaced by a pandas/scikit-learn splitter that reads directories of goodjs/badjs text files:

@@ -1,62 +1,97 @@
 import argparse
 import os
 
-import dask.dataframe as dd
-from dask.diagnostics import ProgressBar
-from dask_ml.model_selection import train_test_split
-
-
-def read_parquet_and_split(parquet_dir_paths, sample_size):
-    ddfs = []
-    for path in parquet_dir_paths:
-        with ProgressBar():
-            ddf = dd.read_parquet(path)
-        ddfs.append(ddf)
-
-    # Combine all the Dask DataFrames
-    combined_ddf = dd.concat(ddfs)
-
-    # Apply stratified sampling if sample_size is less than 1
-    if sample_size < 1.0:
-        total_length = len(combined_ddf)
-        sample_length = int(total_length * sample_size)
-        # Assuming the label column is named 'label'
-        combined_ddf = combined_ddf.sample(frac=sample_size, random_state=42, replace=False).compute()
-        combined_ddf = dd.from_pandas(combined_ddf, npartitions=combined_ddf.npartitions)
-
-    # Split the data into training and test sets (Assuming the label column is named 'label')
-    X_train, X_test = train_test_split(
-        combined_ddf, test_size=0.2, shuffle=True, random_state=42, stratify=combined_ddf["label"]
-    )
+import pandas as pd
+from sklearn.model_selection import train_test_split
 
-    return X_train, X_test
+TRAIN_SET = "train_set.csv"
+TEST_SET = "test_set.csv"
 
 
 def list_of_strings(arg):
     return arg.split(",")
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Split parquet data into train and test sets")
-    parser.add_argument("-i", "--inputs", help="Parquet directories to read", required=True, type=list_of_strings)
-    parser.add_argument("-o", "--output", help="Output directory", required=True)
-    parser.add_argument("-p", "--prefix", help="Prefix for the output files", required=True)
-    parser.add_argument("-ss", "--sample-size", help="Sample size", required=False, type=float, default=1.0)
+def get_files_from_subdir(dir_path):
+    return [
+        os.path.join(dir_path, fname) for fname in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, fname))
+    ]
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Split text data into train and test sets")
+    parser.add_argument("-i", "--inputs", help="Directories to read", required=True, type=list_of_strings)
+    parser.add_argument(
+        "-o", "--output", help="Output directory to contain the train and test set", required=True, type=str
+    )
+    parser.add_argument("-ss", "--sample-size", help="Size of the sample to use (0.0 to 1.0)", type=float, default=0.2)
+    parser.add_argument(
+        "-ts", "--train-size", help="Size of the training set from sample (0.0 to 1.0)", type=float, default=0.8
+    )
 
     args = parser.parse_args()
-    parquet_dir_paths = args.inputs
-    output_dir = args.output
-    prefix = args.prefix
-    sample_size = args.sample_size
-
-    X_train, X_test = read_parquet_and_split(parquet_dir_paths, sample_size)
+    print(f"Inputs : {args.inputs}")
+    print(f"Sample size : {args.sample_size*100} %")
+    print(f"Train size : {args.train_size*100} %")
+    print(f"Test size : {round(1.0 - args.train_size, 1)*100} %")
 
-    with ProgressBar():
-        X_train.repartition(npartitions=1).to_parquet(
-            os.path.join(output_dir, f"{prefix}_train_data.parquet"),
-        )
-        X_test.repartition(npartitions=1).to_parquet(
-            os.path.join(output_dir, f"{prefix}_test_data.parquet"),
-        )
+    good_files = []
+    bad_files = []
+    labels = []
+
+    for input_dir in args.inputs:
+        goodjs_dir = os.path.join(input_dir, "goodjs")
+        badjs_dir = os.path.join(input_dir, "badjs")
+
+        if os.path.exists(goodjs_dir) and os.path.exists(badjs_dir):
+            goodjs_files = get_files_from_subdir(goodjs_dir)
+            badjs_files = get_files_from_subdir(badjs_dir)
+
+            good_files.extend(goodjs_files)
+            bad_files.extend(badjs_files)
+
+            labels.extend(["goodjs"] * len(goodjs_files))
+            labels.extend(["badjs"] * len(badjs_files))
+        else:
+            print(f"Skipping {input_dir} as it doesn't contain both 'goodjs' and 'badjs' directories.")
+
+    all_files = good_files + bad_files
+    total = len(all_files)
+    print(f"# before sampling : {total}")
+    print(f"# goodjs : {len(good_files)}")
+    print(f"# badjs : {len(bad_files)}")
+    # Sample from the data if necessary
+    if args.sample_size < 1.0:
+        sample_size = int(len(all_files) * args.sample_size)
+        all_files, _, labels, _ = train_test_split(
+            all_files, labels, train_size=sample_size, stratify=labels, random_state=42
+        )
 
-    print("Train and test data have been saved.")
+    good_files_after_sample = [all_files[i] for i in range(len(all_files)) if labels[i] == "goodjs"]
+    bad_files_after_sample = [all_files[i] for i in range(len(all_files)) if labels[i] == "badjs"]
+    print(f"# after sampling : {len(all_files)}")
+    print(f"# goodjs sampled : {len(good_files_after_sample)}")
+    print(f"# badjs sampled : {len(bad_files_after_sample)}")
+
+    # Split the data
+    X_train, X_test, y_train, y_test = train_test_split(
+        all_files, labels, train_size=args.train_size, stratify=labels, random_state=42
+    )
+
+    output_dir = args.output
+
+    train = pd.DataFrame({"file": X_train, "label": y_train})
+    test = pd.DataFrame({"file": X_test, "label": y_test})
+    train_path = os.path.join(output_dir, TRAIN_SET)
+    test_path = os.path.join(output_dir, TEST_SET)
+
+    train.to_csv(train_path, index=False)
+    test.to_csv(test_path, index=False)
+    print(f"Train set size : {len(train)}")
+    print(f"Test set size : {len(test)}")
+    print(f"Output : [{train_path}, {test_path}]")
+
+
+if __name__ == "__main__":
+    main()
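The new script samples and splits in two stages, both stratified on the goodjs/badjs label. A minimal sketch of that pattern, assuming only scikit-learn; the file names and the 80/20 label ratio below are invented for illustration:

from sklearn.model_selection import train_test_split

# Hypothetical corpus: 100 files, 80 goodjs and 20 badjs.
files = [f"file_{i}.js" for i in range(100)]
labels = ["goodjs"] * 80 + ["badjs"] * 20

# Stage 1: draw a stratified 20% sample. train_size accepts a float
# fraction or an absolute count (the script above passes a count).
sample_files, _, sample_labels, _ = train_test_split(
    files, labels, train_size=0.2, stratify=labels, random_state=42
)

# Stage 2: split the sample 80/20 into train and test; stratify keeps
# the goodjs/badjs ratio the same in both sets.
X_train, X_test, y_train, y_test = train_test_split(
    sample_files, sample_labels, train_size=0.8, stratify=sample_labels, random_state=42
)
print(len(X_train), len(X_test))  # 16 4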
6 files renamed without changes.
Added: a new script that expands a labeled parquet file into goodjs/badjs text files:

@@ -0,0 +1,52 @@
+import argparse
+import hashlib
+import os
+
+import dask.dataframe as dd
+from dask.diagnostics import ProgressBar
+from tqdm import tqdm
+
+
+def create_text_files(parquet_path, root_dir):
+    with ProgressBar():
+        # Read the parquet file
+        ddf = dd.read_parquet(parquet_path)
+
+        # Compute to bring into memory (use this cautiously)
+        df = ddf.compute()
+
+    # Create root directory named after the parquet file
+    root_path = os.path.join(root_dir, os.path.basename(parquet_path).replace(".parquet", ""))
+    os.makedirs(root_path, exist_ok=True)
+
+    # Create subdirectories
+    goodjs_path = os.path.join(root_path, "goodjs")
+    badjs_path = os.path.join(root_path, "badjs")
+    os.makedirs(goodjs_path, exist_ok=True)
+    os.makedirs(badjs_path, exist_ok=True)
+
+    # Create text files
+    for _, row in tqdm(df.iterrows(), total=len(df)):
+        label = row["label"]
+        content = row["content"]
+        hash_value = hashlib.sha256(content.encode()).hexdigest()
+
+        if label == "good":
+            file_path = os.path.join(goodjs_path, hash_value)
+        else:
+            file_path = os.path.join(badjs_path, hash_value)
+
+        with open(file_path, "w") as f:
+            f.write(content)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Create text files from a parquet file")
+    parser.add_argument("-i", "--input", help="Input parquet file path", required=True)
+    parser.add_argument("-o", "--output", help="Output root directory path", required=True)
+
+    args = parser.parse_args()
+    parquet_path = args.input
+    root_dir = args.output
+
+    create_text_files(parquet_path, root_dir)
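This script names each output file after the SHA-256 hex digest of its content, so identical scripts collapse into a single file. A minimal sketch of that naming scheme; the directory name, content, and label below are placeholders:

import hashlib
import os

content = "console.log('hello');"  # placeholder row content
label = "good"                     # placeholder row label

# Same dispatch as create_text_files above: "good" -> goodjs/, else badjs/.
subdir = "goodjs" if label == "good" else "badjs"
file_name = hashlib.sha256(content.encode()).hexdigest()
print(os.path.join("dataset", subdir, file_name))
# dataset/goodjs/<64-char hex digest>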
2 files renamed without changes.