Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of EfficientNetv2 and MNASNet #198

Merged
merged 36 commits into from
Aug 24, 2022

Conversation

theabhirath
Copy link
Member

@theabhirath theabhirath commented Aug 11, 2022

This is an implementation of EfficientNetv2 and MNASNet. There's clearly some code duplication between the structures of the efficientnet function and the mobilenetvX functions, but figuring out how to unify that API and its design is perhaps best left for another PR.

TODO

  • Figure out a way to unify the EfficientNet and EfficientNetv2 lower level API into a single efficientnet function
  • More elaborate docs for mbconv and fused_mbconv

P.S. memory issues mean that the tests are more fragmented than they ought to be, not sure how we can go about addressing those in the short/medium term

@theabhirath theabhirath added the new-model Request or implementation of a new model label Aug 11, 2022
@theabhirath theabhirath added this to the 0.8 milestone Aug 11, 2022
@theabhirath theabhirath changed the base branch from master to cl/fix August 11, 2022 18:24
@theabhirath theabhirath changed the base branch from cl/fix to master August 11, 2022 18:24
@theabhirath theabhirath marked this pull request as draft August 11, 2022 18:26
@theabhirath theabhirath removed the request for review from darsnack August 11, 2022 18:27
# building inverted residual blocks
for (k, t, c, reduction, activation, stride) in configs
for (i, (k, t, c, reduction, activation, stride)) in enumerate(configs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I work on refactoring the EfficientNet, this part is something that's been annoying me so I decided to put it up. torchvision has a cool feature from other papers (see https://pytorch.org/blog/torchvision-mobilenet-v3-implementation/ for a proper explanation) where they use dilations and a reduced tail (i.e. last three blocks) for some engineering gains. My code works but looks terribly ugly - any idea how to make this look prettier?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull it out into a utility function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's definitely a possibility, but I don't like the fact that this changes the implementation details for MobileNetv3 so fundamentally that unifying the code for the three MobileNets now becomes a nightmare. torchvision uses class variables, so they are able to get away with writing this for example - any creative ways to do something similar in Julia?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that a utility function? :P Perhaps I'm missing something.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that a utility function? :P Perhaps I'm missing something.

It is, but torchvision writes the code for the three MobileNets differently, and I wanted to avoid that if possible in Julia. But what I more meant was that the tail dilation and reduced tail are built into the config dict because all of those variables are in the same function scope

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the extent of the differing logic here the conditional below? If so, perhaps it could be its own function, parameterized on i or some boolean to indicate the calculation is being done for a tail layer. This function has a branch for whether you want the dilations and/or reduced dimensions at the tail. Won't be the prettiest thing, but it doesn't have to be either.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This went away in the recent refactor, but keeping this conversation unresolved as a reminder that maybe some day we'll find a way to bake this into the configuration dict

@theabhirath theabhirath force-pushed the effnetv2 branch 2 times, most recently from bda527b to b57e6f6 Compare August 15, 2022 17:55
@theabhirath
Copy link
Member Author

theabhirath commented Aug 16, 2022

Side note - using basic_conv_bn for the Inception models seems to have fixed their gradient times somehow, which are now much more manageable. Maybe the extra biases were causing a slowdown

@IanButterworth
Copy link
Contributor

Switching to EfficientNet(:b0; ... I get

Screenshot 2023-12-08 at 3 02 42 PM

@darsnack
Copy link
Member

darsnack commented Dec 8, 2023

Okay I'll make my own script over the weekend and sanity check.

@IanButterworth
Copy link
Contributor

Thanks!

Comparing to this example https://www.kaggle.com/code/paultimothymooney/efficientnetv2-with-tensorflow?scriptVersionId=120655976&cellId=27

  • It uses SGD, not Adam.
  • Learning rate is 0.005 vs the default Adam 0.001
  • Primary momentum is the same
  • It uses label smoothing

And performance maxes out within a few epochs

@IanButterworth
Copy link
Contributor

@darsnack i was just wondering whether you have had a chance to test this or check my code for any obvious issues. Thanks

@IanButterworth
Copy link
Contributor

Using IanButterworth/EfficientNet-Training@ec80084 and switching just the model to ResNet gets a much more reasonable result

model = ResNet(18; pretrain=false, inchannels=3, nclasses=length(labels)) |> device
Screenshot 2023-12-13 at 12 46 27 PM

Other than there being some issue with the model, could it be that EfficientNet requires some subtle difference in the loss function or optimizer setups?

@IanButterworth
Copy link
Contributor

EfficientNet (not V2)

model = EfficientNet(:b0; pretrain=false, inchannels=3, nclasses=length(labels)) |> device
Screenshot 2023-12-13 at 12 59 21 PM

@theabhirath
Copy link
Member Author

I was trying a lot of these on CIFAR-10 during my GSoC and was facing issues, including with ResNet – the accuracies were nowhere near what I could get with PyTorch, for example. I remember trying to debug this but then had to give up since I got occupied with other commitments. IIRC at the time one theory was that there might be something wrong with the gradients, but we didn't nearly manage to get enough training examples on the GPU to confirm. I could try running these again if I got a GPU somehow 😅

The difference between EfficientNet and ResNet is weird, but not unexpected because they do use different kinds of layers. Maybe MobileNet vs EfficientNet tells us more? The underlying code over there is the exact same because of the way the models are constructed. Even for ResNet, though, the loss curve looks kinda bumpy...

@IanButterworth
Copy link
Contributor

Interesting.

I was considering setting up a job to do the same for all appropriate models in Metalhead on my machine, and tracking any useful metrics while doing so for a comparison table. I have a relatively decent setup so it might not take too long.

I can understand that full training isn't a part of CI because of resources, but I think it needs to be demonstrated that these models are trainable somewhere public.

@IanButterworth
Copy link
Contributor

I've started putting together a benchmarking script here. #264
It'd be great to get feedback on it.

@ToucheSir ToucheSir mentioned this pull request Dec 15, 2023
2 tasks
@ToucheSir
Copy link
Member

As @theabhirath mentioned, this has come up before but we never got to the bottom of why because of time constraints. If you have some bandwidth to do some diagnostic runs looking at whether gradients, activations etc are misbehaving and where they're misbehaving, we could narrow the problem down to something actionable. Whether that be a bug in Metalhead itself or something further upstream like Flux, Zygote or CUDA.

@IanButterworth
Copy link
Contributor

Ok. I think my minimum goal then is to characterise the issue (is it all models) and make it clear in the docs which ones cannot be expected to be trainable.

As someone coming to this package to use it to train on a custom dataset it's been quite a time sync just to get to that understanding.

@ToucheSir
Copy link
Member

ToucheSir commented Dec 15, 2023

If it's a either-or deal between that and finding out why a specific model (e.g. EfficientNet) has problems, I'd lean towards the latter. I suspect any issues may lie at least at the level of some of the shared layer helpers, so addressing those would help other models too. This kind of deep debugging is also something myself and probably @darsnack are less likely to have time to do, whereas running a convergence test for all models would be more straightforward. But if the goal is to do both, that sounds good to me.

@IanButterworth
Copy link
Contributor

I lean more towards what I'm currently trying because I'm not sure I have the knowledge/skill set to dive in and debug.

Tbh if I find a model that trains and is performant I may declare victory, but share findings.

Or maybe I strike lucky in a dive.

@darsnack
Copy link
Member

Sorry for the late reply. I did start writing a script, but I never gotten around to starting my testing. I kept meaning to reply as soon as I did that, but it's been a busy week.

Looks like we might have some duplicate work. Either way, I uploaded my code to a repo here. I've add all of you as collaborators. The biggest help right now would be is someone can add in all the remaining models + configs to the main.jl script. Output from the script is being logged here. It should be a public wandb workspace that we can all view.

@darsnack
Copy link
Member

Based on the current runs that finished, it looks like all the EfficientNetv2 models have some bug. Only the :small variant trains to completion. The rest all drop to NaN loss at some point, with the larger variants dropping to NaN in the first epoch.

ResNets have the quirk Abhirath noticed during the summer where the models starts off fine then start to overfit to the training data.

I also have AlexNet and EfficientNet (v1) queued up. Let's see how those do.

@darsnack
Copy link
Member

I will say though that the ResNet loss curves don't look as bad as I remember them. Perhaps in this case, a different learning rate would fix things.

@ToucheSir
Copy link
Member

My recollection was that the PyTorch resnets converged faster and overfit less easily even with less help from data augmentation. Is it straightforward to do a head-to-head comparison?

@ToucheSir
Copy link
Member

I'm able to replicate the poor training behaviour of EfficientNet-B0 on a local machine, which happens to have an AMD GPU. this suggests the problem may not lie with anything CUDA-specific.

@IanButterworth
Copy link
Contributor

Great, thanks. How does one go about debugging this kind of thing? Are there generic tools for visualizing the model that could help?

@darsnack
Copy link
Member

I'll modify the script to log gradient norms by layer and also do some local debugging just to sanity check the output isn't obviously wrong.

I'll also add MobileNet to the script. I think that might be a good reference model to compare against assuming it does converge. If it works, that would narrow the issue down to the specific "inverted residual block" that EfficientNet uses.

@ToucheSir
Copy link
Member

If either of you have a machine with a decently powerful CPU, I think a CPU-only run would be interesting to see if we can isolate GPU libraries as a possibility altogether.

@IanButterworth
Copy link
Contributor

I'm giving EfficientNetv2 a go with cpu and -t16. Looks like each epoch will take 3 hours
Screenshot 2023-12-16 at 8 58 09 PM

julia> versioninfo()
Julia Version 1.9.4
Commit 8e5136fa297 (2023-11-14 08:46 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × AMD Ryzen 9 5950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, znver3)
  Threads: 16 on 32 virtual cores

@ToucheSir
Copy link
Member

I was seeing less time than that for b0 on my decidedly less powerful machine, so maybe using a smaller size would help? We've established that all EfficientNets are affected after all.

@theabhirath
Copy link
Member Author

theabhirath commented Dec 17, 2023

Could we shift the discussion to a GitHub discussion? I'd hate to see this getting lost in a closed PR 😅

P.S. ResNet-18 test accuracies on WandB also look weird (there's no way they should be oscillating so much), so there's definitely something going on:

ResNet

@IanButterworth
Copy link
Contributor

IanButterworth commented Dec 17, 2023

EfficientNet B0 on cpu after 15 epochs
Screenshot 2023-12-17 at 8 51 11 AM

@IanButterworth
Copy link
Contributor

I've updated progress on my benchmark script here #264 (comment)

Would anyone be able to help me get the errors resolved? I guess most of them are input image size issues?

@theabhirath
Copy link
Member Author

theabhirath commented Dec 22, 2023

I've updated progress on my benchmark script here #264 (comment)

Would anyone be able to help me get the errors resolved? I guess most of them are input image size issues?

I've commented there! Also just to understand your previous EfficientNet training on CPU graph, it still seems to have those fluctuations in loss, right? Meaning that the issue might not be exactly GPU linked (or at least not only GPU linked)?

@darsnack
Copy link
Member

Updated results: https://wandb.ai/darsnack/metalhead-bench/

For me, I see MobileNetv1 and MobileNetv2 succeeding while MobileNetv3 fails. I tried to check what the differences are by inspection, and the one major different was the use of hardswish and hardsigmoid.

So, I ran another MobileNetv3 but with hardswish and hardsigmoid replaced by relu. Now, it trains to a reasonable accuracy relative to the others. I will re-run again on the CPU to double check.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
new-model Request or implementation of a new model
Projects
No open projects
Development

Successfully merging this pull request may close these issues.

4 participants