added debugging to train.py #88

Open · wants to merge 1 commit into base: main
106 changes: 73 additions & 33 deletions train.py
@@ -11,15 +11,84 @@
import argparse
import json
from typing import Tuple, Optional, Union

import re
from nltk.corpus import stopwords
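NLTK does not ship the stopwords corpus by default, so the stop-word lookup in the new __init__ below will raise a LookupError on a fresh environment (such as a new Kaggle kernel) until the corpus has been downloaded once. A minimal guard, not part of this diff, could look like:

import nltk
from nltk.corpus import stopwords

try:
    stopwords.words('english')      # raises LookupError if the corpus is missing
except LookupError:
    nltk.download('stopwords')      # one-time download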

class MappingType(Enum):
    MLP = 'mlp'
    Transformer = 'transformer'



class ClipCocoDataset(Dataset):
    def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2",
                 normalize_prefix=False):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        self.stop_words = set(stopwords.words('english'))
        # Load data
        with open(data_path, 'rb') as f:
            all_data = pickle.load(f)
        print("Data size is %0d" % len(all_data["clip_embedding"]))
        sys.stdout.flush()

        self.prefixes = all_data["clip_embedding"]
        captions_raw = all_data["captions"]
        self.image_ids = [caption["image_id"] for caption in captions_raw]
        self.captions = [caption['caption'] for caption in captions_raw]

        # Tokenized captions file path
        tokens_file_path = f"/kaggle/working/{os.path.basename(data_path).split('.')[0]}_{gpt2_type}_tokens.pkl"

        if os.path.isfile(tokens_file_path):
            print("Loading tokenized captions from pickle file...")
            with open(tokens_file_path, 'rb') as f:
                self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f)
        else:
            print("Tokenizing captions and saving to pickle file...")
            self.captions_tokens = []
            self.caption2embedding = []
            max_seq_len = 0

            for caption in captions_raw:
                processed_caption = self.preprocess_caption(caption['caption'])
                tokens = torch.tensor(self.tokenizer.encode(processed_caption), dtype=torch.int64)
                self.captions_tokens.append(tokens)
                self.caption2embedding.append(caption["clip_embedding"])  # Should be an index
                max_seq_len = max(max_seq_len, tokens.shape[0])

            # Save tokenized captions to pickle file
            with open(tokens_file_path, 'wb') as f:
                pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f)

        # Validate indices
        valid_indices = [i for i, idx in enumerate(self.caption2embedding) if idx < len(self.prefixes)]
        if len(valid_indices) < len(self.caption2embedding):
            print(f"Found {len(self.caption2embedding) - len(valid_indices)} invalid indices. Filtering out invalid captions.")
            self.captions_tokens = [self.captions_tokens[i] for i in valid_indices]
            self.caption2embedding = [self.caption2embedding[i] for i in valid_indices]
            self.image_ids = [self.image_ids[i] for i in valid_indices]
            self.captions = [self.captions[i] for i in valid_indices]

        # Compute max sequence length based on tokenized data
        all_len = torch.tensor([len(tokens) for tokens in self.captions_tokens]).float()
        self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max()))
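The cap above allows up to ten standard deviations over the mean token length but never exceeds the longest tokenized caption, so it only bites when a few captions are extreme outliers. A standalone sketch of the same heuristic, with a helper name that is illustrative and not part of the diff:

import torch

def capped_max_seq_len(token_lists):
    # Hypothetical helper mirroring the computation in __init__ above.
    lengths = torch.tensor([len(t) for t in token_lists]).float()
    return min(int(lengths.mean() + lengths.std() * 10), int(lengths.max()))

# capped_max_seq_len([[1, 2, 3], [4, 5], [6, 7, 8, 9]]) -> 4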
    def preprocess_caption(self, caption: str) -> str:
        """
        Preprocesses the caption by normalizing case, removing special characters,
        redundant white spaces, and stopwords.
        """
        # Convert to lowercase
        caption = caption.lower()
        # Remove special characters
        caption = re.sub(r"[^\w\s]", "", caption)  # Retain only letters, digits, and spaces
        # Remove redundant white spaces
        caption = re.sub(r"\s+", " ", caption).strip()
        # Remove stopwords
        words = caption.split()
        caption = " ".join([word for word in words if word not in self.stop_words])
        return caption
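For reference, a worked example of the preprocessing above, reproduced standalone so it can run without constructing the dataset (the caption string is illustrative): punctuation is stripped by the [^\w\s] substitution and NLTK's English stop words ("a", "is", "on", "the") are dropped.

import re
from nltk.corpus import stopwords

stop_words = set(stopwords.words('english'))
caption = "A dog is sitting on the grass."
caption = re.sub(r"[^\w\s]", "", caption.lower())    # "a dog is sitting on the grass"
caption = re.sub(r"\s+", " ", caption).strip()
print(" ".join(w for w in caption.split() if w not in stop_words))  # dog sitting grass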
    def __len__(self) -> int:
        return len(self.captions_tokens)

@@ -32,49 +101,20 @@ def pad_tokens(self, item: int):
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
            self.captions_tokens[item] = tokens
        mask = tokens.ge(0)  # mask is zero where we out of sequence
        mask = tokens.ge(0)  # Mask is zero where we are out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # adding prefix mask
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # Adding prefix mask
        return tokens, mask
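A minimal standalone sketch of the masking step above, assuming (as the tokens.ge(0) check implies) that the elided padding code fills the tail of a short caption with -1 sentinels; prefix_length and the token ids are illustrative, not taken from the diff.

import torch

prefix_length = 2
tokens = torch.tensor([464, 3290, 318, -1, -1], dtype=torch.int64)  # padded caption
mask = tokens.ge(0)                                   # False over the -1 padding positions
tokens[~mask] = 0                                     # replace the sentinels with token id 0
mask = torch.cat((torch.ones(prefix_length), mask.float()), dim=0)  # prepend an all-ones prefix mask
# tokens -> tensor([ 464, 3290,  318,    0,    0])
# mask   -> tensor([1., 1., 1., 1., 1., 0., 0.])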

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
        tokens, mask = self.pad_tokens(item)
        prefix = self.prefixes[self.caption2embedding[item]]
        prefix = self.prefixes[self.caption2embedding[item]]  # Use index to get embedding
        if self.normalize_prefix:
            prefix = prefix.float()
            prefix = prefix / prefix.norm(2, -1)
        return tokens, mask, prefix
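When normalize_prefix is set, prefix.norm(2, -1) is the L2 norm over the embedding dimension, so each CLIP prefix is rescaled to unit length before it reaches the mapping network. A tiny illustration with a stand-in vector (not a real CLIP embedding):

import torch

prefix = torch.tensor([3.0, 4.0])        # stand-in for a CLIP embedding
prefix = prefix / prefix.norm(2, -1)     # L2 norm is 5.0
print(prefix)                            # tensor([0.6000, 0.8000])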

    def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2",
                 normalize_prefix=False):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        with open(data_path, 'rb') as f:
            all_data = pickle.load(f)
        print("Data size is %0d" % len(all_data["clip_embedding"]))
        sys.stdout.flush()
        self.prefixes = all_data["clip_embedding"]
        captions_raw = all_data["captions"]
        self.image_ids = [caption["image_id"] for caption in captions_raw]
        self.captions = [caption['caption'] for caption in captions_raw]
        if os.path.isfile(f"{data_path[:-4]}_tokens.pkl"):
            with open(f"{data_path[:-4]}_tokens.pkl", 'rb') as f:
                self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f)
        else:
            self.captions_tokens = []
            self.caption2embedding = []
            max_seq_len = 0
            for caption in captions_raw:
                self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64))
                self.caption2embedding.append(caption["clip_embedding"])
                max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])
            # self.max_seq_len = max_seq_len
            with open(f"{data_path[:-4]}_tokens.pkl", 'wb') as f:
                pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f)
        all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float()
        self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max()))


class MLP(nn.Module):
Expand Down