Tiny numerical differences, Weight updates not perfectly matching #16

Open
Ar-Kareem opened this issue Nov 28, 2022 · 2 comments

Ar-Kareem commented Nov 28, 2022

Hi, thanks for this amazing library.

I noticed one tiny issue: the final weights of the model are slightly different when training with multiple sub-batches per step vs. one big batch per step. I'm not sure whether such numerical differences are expected when using this library.

I'm using CLIP with a contrastive loss. Here's my quick experimental code, which I made sure to run multiple times; it produces exactly the same output each time:
(Note: I'm using CLIP with 151 million parameters and a dataset of only 32 samples for experimental purposes.)

model1 = train_clip_normally(epochs=1, batch_size=16)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 0.3163

Above we see that training with two sub-batches of 8 vs. one batch of 16 leaves a tiny difference between the weights of the two models.

model1 = train_clip_normally(epochs=1, batch_size=16)
model2 = train_clip_gradcache(epochs=1, batch_size=16, batches_per_backward=1)
print(calc_model_param_difference(model1, model2)) # RETURN: 0

Above we see that the models are identical when GradCache performs a backward pass after every batch.

model1 = train_clip_gradcache(epochs=1, batch_size=4, batches_per_backward=4)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 0.3105

Above we see that the difference still exists between two runs with different GradCache sub-batch sizes.
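
For context, here is a minimal snippet, independent of GradCache and of the training code below, showing that merely changing the order of a float32 reduction already produces a small nonzero discrepancy, which is plausibly the same kind of effect that splitting a step into sub-batches introduces:

import torch

torch.manual_seed(0)
g = torch.randn(10_000_000)  # stand-in for a flattened gradient tensor

full = g.sum()                        # one big reduction
half_a, half_b = g.chunk(2)
split = half_a.sum() + half_b.sum()   # two partial reductions, then combined

# float32 addition is not associative, so different reduction orders round
# differently; the gap is typically tiny but nonzero.
print((full - split).abs().item())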

However, this library is still working amazingly: if I compare it with training normally at the maximum batch size that fits in my GPU, I get a huge difference (which is expected, and exactly why I need this library), as seen below:

model1 = train_clip_normally(epochs=1, batch_size=8)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 363.2708

Below is my code, in case the problem is on my end:

def train_clip_normally(epochs, batch_size):
    dl = torch.utils.data.DataLoader(d, batch_size=batch_size, shuffle=False)
    model = MyCLIPModel("openai/clip-vit-base-patch32").to('cuda:1')
    optimizer = torch.optim.Adam(model.parameters())
    for e in range(epochs):
        cliptrain.train_epoch(model, optimizer, processor, dl)
    return model

def train_clip_gradcache(epochs, batch_size, batches_per_backward):
    dl = torch.utils.data.DataLoader(d, batch_size=batch_size, shuffle=False)
    model = MyCLIPModel("openai/clip-vit-base-patch32").to('cuda:1')
    optimizer = torch.optim.Adam(model.parameters())
    for e in range(epochs):
        cliptrain.ClipModelClone.grad_cache_train(model, optimizer, processor, dl, batches_per_backward=batches_per_backward)
    return model

def calc_model_param_difference(model1, model2):
    diff = 0
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        diff += torch.norm(p1.data - p2.data)
    return diff

from grad_cache.functional import cached, cat_input_tensor

def grad_cache_train(model, optimizer, processor, dataloader, batches_per_backward):
    cache_x = []
    cache_y = []
    closures_x = []
    closures_y = []

    for step, sub_batch in enumerate(dataloader):  
        inputs = processor(text=sub_batch['text'], return_tensors="pt", padding=True, truncation=True)
        inputs['input_ids'] = inputs['input_ids'].to(model.device)
        inputs['attention_mask'] = inputs['attention_mask'].to(model.device)
        inputs['pixel_values'] = sub_batch['image'].to(model.device)
        inputs['return_loss'] = True

        print('step', step)
        rx, cx = call_text_model(model, inputs)
        ry, cy = call_vision_model(model, inputs)
        
        cache_x.append(rx)
        cache_y.append(ry)
        closures_x.append(cx)
        closures_y.append(cy)
        
        # Once `batches_per_backward` sub-batches have been cached, compute the loss
        # over the concatenated representations, backprop into the cached
        # representations, and let each closure push those gradients back through
        # its sub-batch's forward pass before stepping the optimizer.
        if (step + 1) % batches_per_backward == 0:
            print('BACKWARD!')
            loss = grad_cat_loss(cache_x, cache_y, model.logit_scale)
            loss.backward()
            
            for f, r in zip(closures_x, cache_x):
                f(r)
            for f, r in zip(closures_y, cache_y):
                f(r)

            cache_x = []
            cache_y = []
            closures_x = []
            closures_y = []
        
            optimizer.step()
            optimizer.zero_grad()

@cat_input_tensor
def grad_cat_loss(text_embeds, image_embeds, logit_scale):
    sim = torch.matmul(text_embeds, image_embeds.t()) * logit_scale.exp()
    return clip_loss(sim)

@cached
def call_text_model(model, input):
    return model.forward_text(**input)

@cached
def call_vision_model(model, input):
    return model.forward_visual(**input)

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0

def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


zzk2021 commented Jan 26, 2023

Hi! I also train CLIP this way.
As for the reason for the inconsistency, I think it's that the BN layers are not synchronized; you should replace BN with GN.
Besides, I think you should sample without replacement, which means you need to design a sampler. This is important if you only use CLIP's loss.
If you don't do this, you can get the same ID within one iteration, which is the wrong optimization goal for CLIP.
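
For illustration, a minimal sketch of the BN-to-GN swap suggested above; it assumes a model that actually contains torch.nn.BatchNorm2d layers, and the group count is an arbitrary choice that must divide each layer's channel count:

import torch.nn as nn

def replace_bn_with_gn(module: nn.Module, num_groups: int = 32) -> nn.Module:
    # Recursively swap every BatchNorm2d for a GroupNorm over the same channels.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            replace_bn_with_gn(child, num_groups)
    return module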


zzk2021 commented Jan 26, 2023

I was overthinking it. In PyTorch, the default option is replacement=False.
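
For reference, a sketch of the explicit version anyway, reusing the dataset d from the issue; torch.utils.data.RandomSampler defaults to replacement=False, which is also what DataLoader(shuffle=True) uses under the hood:

from torch.utils.data import DataLoader, RandomSampler

# Explicit sampler without replacement -- the default behaviour anyway.
sampler = RandomSampler(d, replacement=False)
dl = DataLoader(d, batch_size=8, sampler=sampler)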
