Do not wrap LoRA layers with FSDP #1538

Open

janEbert wants to merge 2 commits into main
Conversation

janEbert
Contributor

When wrapping the full Transformer Block, FSDP wraps both trainable and non-trainable parameters together. Because FSDP flattens each wrapped module's parameters into a single group, this results in much higher memory consumption, negating the memory savings from LoRA.

By instead wrapping only torch.nn.Linear modules, we still make use of FSDP but avoid wrapping the LoRA layers.

For clarity, this is the type of warning given by PyTorch for the old code:

/llm-finetuning-example/env/lib/python3.11/site-packages/torch/distributed/fsdp/_wrap_utils.py:174: UserWarning: transformer.h.0 has both parameters with requires_grad=True and False. We do not recommend wrapping such modules since the gradient memory usage will be higher than expected (1451376640 numel instead of 106496 numel before sharding via reduce-scatter). If possible, wrap the frozen parameters with FSDP separately.
The following parameters have requires_grad=True:
['transformer.h.0.attn.attn.lora_A', 'transformer.h.0.attn.attn.lora_B']
The following parameters have requires_grad=False:
['transformer.h.0.norm_1.weight', [...]]
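
For reference, here is a minimal sketch of the approach described above (hedged: this is not the literal diff of this PR, and the litgpt import path and the activation-checkpointing line are assumptions about how the finetuning script sets up FSDP): the auto-wrap policy is switched from the transformer Block class to torch.nn.Linear, so frozen base weights and trainable LoRA parameters never end up in the same FSDP unit.

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

from litgpt.model import Block  # transformer block class; import path assumed

# Before: each Block becomes one FSDP unit, mixing frozen weights and LoRA params,
# which triggers the UserWarning above and inflates gradient memory.
# strategy = FSDPStrategy(auto_wrap_policy={Block}, activation_checkpointing_policy={Block})

# After: only the linear layers are wrapped, so the small trainable LoRA
# parameters are not flattened together with the frozen base weights.
strategy = FSDPStrategy(
    auto_wrap_policy={torch.nn.Linear},
    activation_checkpointing_policy={Block},
)
fabric = Fabric(devices=4, strategy=strategy)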

@rasbt
Collaborator

rasbt commented Jul 2, 2024

Thanks for the update @janEbert! This looks good to me. Btw, have you by any chance done a before-and-after comparison of memory usage?

@janEbert
Contributor Author

janEbert commented Jul 2, 2024

I have not; it was only OOM vs. no OOM. 😅
I did try to solve the same problem in a different way, using FSDP's ignored_states argument and then manually initializing the LoRA parameters (necessary because of the meta-device initialization). However, I ran into similar OOMs with that method, although it did avoid the warning message.

I can supply some comparisons if that would help!
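
For context, a rough sketch of that alternative (hedged: this is not the code that was actually tried, and the lora_A/lora_B attribute names and the initialization are assumptions): the trainable LoRA parameters are passed to FSDP via ignored_states so FSDP never wraps them, but since FSDP then also does not materialize them from the meta device, they have to be created and initialized by hand.

import math

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_ignoring_lora(model: nn.Module, device: torch.device) -> FSDP:
    # Hand the trainable (LoRA) parameters to FSDP as ignored states so that
    # only the frozen base weights get sharded.
    lora_params = [p for p in model.parameters() if p.requires_grad]
    sharded = FSDP(model, ignored_states=lora_params, use_orig_params=True)

    # FSDP does not materialize ignored parameters, so the LoRA weights are
    # still on the meta device and must be created and initialized manually.
    for module in sharded.modules():
        if hasattr(module, "lora_A"):  # assumed LoRA layer attribute names
            module.lora_A = nn.Parameter(
                nn.init.kaiming_uniform_(
                    torch.empty_like(module.lora_A, device=device),
                    a=math.sqrt(5),
                )
            )
            module.lora_B = nn.Parameter(
                torch.zeros_like(module.lora_B, device=device)
            )
    return sharded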

@rasbt
Collaborator

rasbt commented Jul 2, 2024

I see, yeah, I think we should do some comparisons to make sure it works as intended. If you want to do them, that'd be nice! I suggest trying at least a small model (phi-2 or so) and a medium-sized model (e.g., Llama 3 8B).

@janEbert
Contributor Author

janEbert commented Jul 2, 2024

Will do that in the coming days!

@rasbt
Collaborator

rasbt commented Jul 2, 2024

That'd be awesome. And pls let me know in case you need any help!

@janEbert
Contributor Author

janEbert commented Jul 3, 2024

Forgot to print GPU ranks, but the point should be clear. :)

Before the change:

Memory used: 18.47 GB
Memory used: 18.47 GB
Memory used: 18.50 GB
Memory used: 18.47 GB

After the change:

Memory used: 10.35 GB
Memory used: 10.35 GB
Memory used: 10.35 GB
Memory used: 10.38 GB

Is this helpful enough or would you like to see more detailed stats?
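
(For reference, per-rank peak-memory numbers like the ones above can be collected roughly as follows; a hedged sketch, not necessarily the exact print used here, with the fabric argument assumed to be a Lightning Fabric instance.)

import torch


def print_peak_memory(fabric) -> None:
    # Peak CUDA memory allocated on this process, in GB, including the rank
    # label that was missing from the printout above.
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    print(f"[rank {fabric.global_rank}] Memory used: {peak_gb:.2f} GB")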

@williamFalcon
Contributor

cc @awaelchli

@rasbt
Collaborator

rasbt commented Jul 3, 2024

@janEbert Looks awesome, which model is that?

I am also rerunning some of the models in the config hub and will update the numbers accordingly!

@rasbt
Collaborator

rasbt commented Jul 3, 2024

I just ran a quick comparison on a 4xA10G machine to see if I could reproduce the config hub performance:

| Config              | Model     | Epochs | Max seq length | Micro batch size | Machine | Training runtime | Cost | Peak memory | Val loss | Val perplexity | MMLU  |
| ------------------- | --------- | ------ | -------------- | ---------------- | ------- | ---------------- | ---- | ----------- | -------- | -------------- | ----- |
| falcon-7b/lora.yaml | falcon-7b | 4      | 512            | 1                | 4xA10G  | 24.94 min        | $2.0 | 16.69 GB    | 0.945    | 2.573          | 26.4% |

For some reason, it's been really slow, but it could be a machine issue I have to look into. Here are some numbers I am getting for the code in this PR:

Epoch 1 | iter 1 step 0 | loss train: 2.299, val: n/a | iter time: 9980.09 ms
Epoch 1 | iter 2 step 1 | loss train: 9.848, val: n/a | iter time: 10256.36 ms (step)
Epoch 1 | iter 3 step 1 | loss train: 14.262, val: n/a | iter time: 10179.08 ms
Epoch 1 | iter 4 step 2 | loss train: 11.866, val: n/a | iter time: 10034.30 ms (step)
Epoch 1 | iter 5 step 2 | loss train: 14.844, val: n/a | iter time: 10490.95 ms
Epoch 1 | iter 6 step 3 | loss train: 17.512, val: n/a | iter time: 10738.12 ms (step)
Epoch 1 | iter 7 step 3 | loss train: 14.514, val: n/a | iter time: 10573.39 ms
Epoch 1 | iter 8 step 4 | loss train: 11.069, val: n/a | iter time: 10545.70 ms (step)
Epoch 1 | iter 9 step 4 | loss train: 11.084, val: n/a | iter time: 10077.49 ms
Epoch 1 | iter 10 step 5 | loss train: 11.105, val: n/a | iter time: 10593.24 ms (step)

If I compare it to the performance before, I notice that the iterations without an optimizer step are about 30% slower. E.g., from the main branch:

Epoch 1 | iter 1 step 0 | loss train: 2.299, val: n/a | iter time: 7121.10 ms
Epoch 1 | iter 2 step 1 | loss train: 2.163, val: n/a | iter time: 10766.50 ms (step)
Epoch 1 | iter 3 step 1 | loss train: 1.705, val: n/a | iter time: 7182.87 ms
Epoch 1 | iter 4 step 2 | loss train: 1.960, val: n/a | iter time: 10861.94 ms (step)
Epoch 1 | iter 5 step 2 | loss train: 1.891, val: n/a | iter time: 7214.43 ms
Epoch 1 | iter 6 step 3 | loss train: 1.468, val: n/a | iter time: 10887.67 ms (step)
Epoch 1 | iter 7 step 3 | loss train: 1.626, val: n/a | iter time: 7162.67 ms
Epoch 1 | iter 8 step 4 | loss train: 1.167, val: n/a | iter time: 10712.47 ms (step)
Epoch 1 | iter 9 step 4 | loss train: 1.764, val: n/a | iter time: 7206.13 ms
Epoch 1 | iter 10 step 5 | loss train: 2.023, val: n/a | iter time: 10798.46 ms (step)
Epoch 1 | iter 11 step 5 | loss train: 1.892, val: n/a | iter time: 7213.32 ms
Epoch 1 | iter 12 step 6 | loss train: 2.678, val: n/a | iter time: 10819.67 ms (step)
Epoch 1 | iter 13 step 6 | loss train: 2.245, val: n/a | iter time: 7164.10 ms

Just curious, have you observed something similar? (In that case, we could maybe also think about an "optimize runtime|memory" setting here.)

For comparison, the single-GPU speeds:

Epoch 1 | iter 1 step 0 | loss train: 1.281, val: n/a | iter time: 562.95 ms
Epoch 1 | iter 2 step 0 | loss train: 1.338, val: n/a | iter time: 173.90 ms
Epoch 1 | iter 3 step 0 | loss train: 1.738, val: n/a | iter time: 110.99 ms
Epoch 1 | iter 4 step 0 | loss train: 1.681, val: n/a | iter time: 269.70 ms
Epoch 1 | iter 5 step 0 | loss train: 1.830, val: n/a | iter time: 160.60 ms
Epoch 1 | iter 6 step 0 | loss train: 1.871, val: n/a | iter time: 99.58 ms
Epoch 1 | iter 7 step 0 | loss train: 1.870, val: n/a | iter time: 93.07 ms
Epoch 1 | iter 8 step 1 | loss train: 1.775, val: n/a | iter time: 319.57 ms (step)
Epoch 1 | iter 9 step 1 | loss train: 1.873, val: n/a | iter time: 94.81 ms
Epoch 1 | iter 10 step 1 | loss train: 1.797, val: n/a | iter time: 329.24 ms
Epoch 1 | iter 11 step 1 | loss train: 1.693, val: n/a | iter time: 258.53 ms

So I am thinking this could be due to a slow interconnect between the GPUs. I will look into it and do some more experiments.
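
A hypothetical sketch of such an "optimize runtime|memory" toggle (the function, argument names, and values are assumptions; nothing like this exists in the repo):

import torch
from lightning.fabric.strategies import FSDPStrategy

from litgpt.model import Block  # import path assumed


def make_fsdp_strategy(optimize: str = "memory") -> FSDPStrategy:
    if optimize == "memory":
        # Per-linear wrapping (this PR): frozen and LoRA params live in
        # separate FSDP units, minimizing gradient and all-gather memory.
        policy = {torch.nn.Linear}
    else:  # "runtime"
        # Per-block wrapping (main): fewer, larger FSDP units and therefore
        # fewer collectives per step, at the cost of more memory.
        policy = {Block}
    return FSDPStrategy(
        auto_wrap_policy=policy,
        activation_checkpointing_policy={Block},
    )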

@Andrei-Aksionov
Collaborator

Why does the training loss increase (for the code from this PR)?
From 2.299 up to 17.512.

@rasbt
Collaborator

rasbt commented Jul 3, 2024

Not sure. I observed it with Phi-2 too:

Main branch:

litgpt finetune_lora checkpoints/microsoft/phi-2/ --devices 4
Epoch 1 | iter 1 step 0 | loss train: 2.424, val: n/a | iter time: 5537.97 ms
Epoch 1 | iter 2 step 0 | loss train: 2.519, val: n/a | iter time: 5578.44 ms
Epoch 1 | iter 3 step 0 | loss train: 2.646, val: n/a | iter time: 5563.88 ms
Epoch 1 | iter 4 step 1 | loss train: 2.516, val: n/a | iter time: 6942.96 ms (step)
Epoch 1 | iter 5 step 1 | loss train: 2.467, val: n/a | iter time: 5483.51 ms

PR branch:

litgpt finetune_lora checkpoints/microsoft/phi-2/ --devices 4
Epoch 1 | iter 1 step 0 | loss train: 2.424, val: n/a | iter time: 7818.10 ms
Epoch 1 | iter 2 step 0 | loss train: 6.647, val: n/a | iter time: 8075.87 ms
Epoch 1 | iter 3 step 0 | loss train: 8.358, val: n/a | iter time: 7731.87 ms
Epoch 1 | iter 4 step 1 | loss train: 9.103, val: n/a | iter time: 7654.43 ms (step)
Epoch 1 | iter 5 step 1 | loss train: 11.207, val: n/a | iter time: 7824.40 ms

Something I need to investigate more in the next few days. I'll also try this on a different machine, since I think the A10G machine has a very slow GPU interconnect.

@rasbt
Collaborator

rasbt commented Jul 3, 2024

Why does the training loss increase (for the code from this PR)? From 2.299 up to 17.512.

I am curious whether the whole Block was maybe accidentally trainable before (instead of just the LoRA linear layers), which could explain the sharper loss decrease on main. But we should have tests for that, and I need to double-check it with a debugger. Just leaving this here as a note to myself so I can pick it up next week.
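
One quick way to check that (illustrative only, not an existing test in the repo; assumes FSDP exposes the original parameters, e.g. via use_orig_params=True): list which parameters still require gradients after the model has been set up with FSDP.

def print_trainable_parameters(model) -> None:
    # Only the lora_A/lora_B tensors should show up as trainable here.
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, tuple(param.shape))
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.4f}%)")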

@janEbert
Contributor Author

janEbert commented Jul 3, 2024

Funnily enough, in an entirely unrelated example, I've also noticed PyTorch Distributed becoming increasingly less reproducible under slightly changed settings the higher the PyTorch version. Could that maybe be the case here as well? Do you get a reproducible loss when running the same version of the code?

@janEbert Looks awesome, which model is that?

It's Mistral-7B-Instruct-v0.3 on a very small (4 samples) dummy JSON dataset, global batch size = 4, all other settings default.

@janEbert
Contributor Author

janEbert commented Jul 3, 2024

BTW, my iteration speed is also slightly slower. I'll check tomorrow whether the version with ignored_states performs better.

@rasbt
Collaborator

rasbt commented Jul 3, 2024

That's a good point, but I think there is a different issue here that I am not understanding yet 😅. When I reran the code, I observed basically the same higher loss, and it was independent of which model I tried.
