
Updating to Keras 3.0 and migrating to PyTorch #418

Merged: 26 commits merged into main on May 28, 2024

Conversation

@IgorTatarnikov (Member) commented May 10, 2024

Before submitting a pull request (PR), please read the contributing guide.

Please fill out as much of this template as you can, but if you have any problems or questions, just leave a comment and we will help out :)

Description

What is this PR

  • Bug fix
  • Addition of a new feature
  • Other

Why is this PR needed?
TensorFlow has become increasingly difficult to support (e.g. lack of GPU support on native Windows). Switching to Keras 3.0 allows us to migrate to torch as the backend instead of TensorFlow. This will make future maintenance easier and allow us to support Python 3.11+.

What does this PR do?
  • Upgrades to Keras 3.0.
  • Sets torch as the default backend.
  • Removes functions related to TensorFlow (error suppression, etc.).

References

Closes #279
brainglobe/brainglobe.github.io#177
brainglobe/brainglobe.github.io#183
#266

How has this PR been tested?

All tests pass on CI, basic workflows have been tested by manual inspection.

Is this a breaking change?

No.

Does this PR require an update to the documentation?

Yes, see brainglobe/brainglobe.github.io#183

Checklist:

  • The code has been tested locally
  • The documentation has been updated to reflect any changes
  • The code has been formatted with pre-commit

sfmig and others added 11 commits February 7, 2024 13:50
* check if Keras present

* change TF to Keras in CI

* remove comment

* change dependencies in pyproject.toml for Keras 3.0
* remove pytest-lazy-fixture as dev dependency and skip test (with WG temp fix)

* change tensorflow dependency for cellfinder

* replace keras imports from tensorflow to just keras imports

* add keras import and reorder

* add keras and TF 2.16 to pyproject.toml

* comment out TF version check for now

* change checkpoint filename for compliance with keras 3. remove use_multiprocessing=False from fit() as it is no longer an input. test_train() passing

* add multiprocessing parameters to cube generator constructor and remove from fit() signature (keras3 change)

* apply temp garbage collector fix

* skip troublesome test

* skip running tests on CI on windows

* remove commented out TF check

* clean commented out code. Explicitly pass use_multiprocessing=False (as before)

* remove str conversion before model.save

* raise test_detection error for sonarcloud happy

* skip running tests on windows on CI

* remove filename comment and small edits
* change some old references to TF for the import check

* change TF cached model to Keras
* replace tensorflow Tensor with keras tensor

* add case for TF prep in prep_model_weights

* add different backends to pyproject.toml

* add backend configuration to cellfinder init file. tests passing with jax locally

* define extra dependencies for cellfinder with different backends. run tox with TF backend

* run tox using TF and JAX backend

* install TF in brainmapper environment before running tests in CI

* add backends check to cellfinder init file

* clean up comments

* fix tf-nightly import check

* specify TF backend in include guard check

* clarify comment

* remove 'backend' from dependencies specifications

* Apply suggestions from code review

Co-authored-by: Igor Tatarnikov <[email protected]>

---------

Co-authored-by: Igor Tatarnikov <[email protected]>
* use jax backend in brainmapper tests in CI

* skip TF backend on windows

* fix pip install cellfinder for brainmapper CI tests

* add keras env variable for brainmapper CLI tests

* fix prep_model_weights

* PyTorch runs utilizing multiple cores

* PyTorch fix with default models

* Tests run on every push for now

* Run test on torch backend only

* Fixed guard test to set torch as KERAS_BACKEND

* KERAS_BACKEND env variable set directly in test_include_guard.yaml

* Run test on python 3.11

* Remove tf-nightly from __init__ version check

* Added 3.11 to legacy tox config

* Changed legacy tox config for real this time

* Don't set the wrong max_processing value

* Torch is now set as the default backend

* Tests only run with torch, updated comments

* Unpinned torch version

* Add codecov token (#403)

* add codecov token

* generate xml coverage report

* add timeout to testing jobs

* Allow turning off classification or detection in GUI (#402)

* Allow turning off classification or detection in GUI.

* Fix test.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor to fix code analysis errors.

* Ensure array is always 2d.

* Apply suggestions from code review

Co-authored-by: Igor Tatarnikov <[email protected]>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Igor Tatarnikov <[email protected]>

* Support single z-stack tif file for input (#397)

* Support single z-stack tif file for input.

* Fix commit hook.

* Apply review suggestions.

* Remove modular asv benchmarks (#406)

* remove modular asv benchmarks

* recover old structure

* remove asv-specific lines from gitignore and manifest

* prune benchmarks

* Adapt CI so it covers both new and old Macs, and installs required additional dependencies on M1 (#408)

* naive attempt at adapting to silicon mac CI

* run include guard test on Silicon CI

* double-check hdf5 is needed

* Optimize cell detection (#398) (#407)

* Replace coord map values with numba list/tuple for optim.

* Switch to fortran layout for faster update of last dim.

* Cache kernel.

* jit ball filter.

* Put z as first axis to speed z rolling (row-major memory).

* Unroll recursion (no perf impact either way).

* Parallelize cell cluster splitting.

* Parallelize walking for full images.

* Cleanup docs and pep8 etc.

* Add pre-commit fixes.

* Fix parallel always being selected and numba function 1st class warning.

* Run hook.

* Older python needs Union instead of |.

* Accept review suggestion.



* Address review changes.

* num_threads must be an int.

---------

Co-authored-by: Matt Einhorn <[email protected]>
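One of the layout changes above, putting z first, can be illustrated with plain NumPy (the array shapes here are made up for illustration, not taken from cellfinder):

```python
import numpy as np

# With C (row-major) order the trailing axes are contiguous, so a
# (z, y, x) layout makes each z-plane one contiguous block of memory;
# rolling a filter along z then streams through RAM instead of striding.
vol = np.zeros((8, 128, 128), dtype=np.float32)  # (z, y, x)
plane = vol[3]                                   # a single z-plane
print(plane.flags["C_CONTIGUOUS"])  # True
print(plane.strides)                # (512, 4): 128 float32s per row
```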

* [pre-commit.ci] pre-commit autoupdate (#412)

updates:
- [github.com/pre-commit/pre-commit-hooks: v4.5.0 → v4.6.0](pre-commit/pre-commit-hooks@v4.5.0...v4.6.0)
- [github.com/astral-sh/ruff-pre-commit: v0.3.5 → v0.4.3](astral-sh/ruff-pre-commit@v0.3.5...v0.4.3)
- [github.com/psf/black: 24.3.0 → 24.4.2](psf/black@24.3.0...24.4.2)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: sfmig <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Simplify model download (#414)

* Simplify model download

* Update model cache

* Remove jax and tf tests

* Standardise the data types for inputs to all be float32

* Force torch to use CPU on arm based macOS during tests

* Added PYTORCH_MPS_HIGH_WATERMARK_RATIO env variable

* Set env variables in test setup

* Try to set the default device to cpu in the test itself

* Add device call to Conv3D to force cpu

* Revert changes, request one cpu left free

* Revert the num cores change, don't use arm-based mac runner

* Merged main, removed torch flags on cellfinder install for guards and brainmapper

* Lowercase Torch

* Change cache directory

---------

Co-authored-by: sfmig <[email protected]>
Co-authored-by: Kimberly Meechan <[email protected]>
Co-authored-by: Matt Einhorn <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alessandro Felder <[email protected]>
Co-authored-by: Adam Tyson <[email protected]>
# Conflicts:
#	.github/workflows/test_and_deploy.yml
#	.github/workflows/test_include_guard.yaml
#	cellfinder/core/main.py
#	cellfinder/core/tools/prep.py
#	cellfinder/core/train/train_yml.py
#	tests/core/conftest.py
@IgorTatarnikov (Member, Author) commented May 10, 2024

To do:

  • Documentation

Sanity checks for a "regular" cellfinder workflow (should be done via the napari GUI and via the brainmapper CLI where possible):

  • Detect using the default cellfinder model
  • Detect using a custom model
  • Curate a set of cells and use it to retrain the default model
  • Detect using the new updated model

@IgorTatarnikov (Member, Author)

Tests are currently not running on macos-latest since I was having issues getting torch to behave on CI, see here. I couldn't find an elegant way of forcing torch not to use the MPS device on CI specifically while still allowing it to be used normally.

Not sure how to proceed there. Tests pass locally when run on my personal machine (M2 MacBook Pro running macOS 14.4.1).

@adamltyson (Member)

Tests are currently not running mac-latest since I was having issues getting torch to behave on CI, see here. I couldn't find an elegant way of forcing torch to not use the mps device on CI specifically while still allowing it to be used normally.
Not sure how to proceed there. Tests pass locally when run on my personal machine (M2 MacBook Pro running macOS 14.4.1).

Naively, it seems like it should be possible (we're not the only people using torch & GitHub Actions!). Of course you've looked into it, so it's clearly not simple.

I would suggest:

  1. Give it a bit more of a go to see if you can fix it (there's a large online torch community)
  2. If 1. fails, write up an issue with what you've tried and don't let this PR get derailed by a relatively small problem

@IgorTatarnikov IgorTatarnikov marked this pull request as ready for review May 23, 2024 13:33
@IgorTatarnikov IgorTatarnikov requested a review from a team May 23, 2024 13:33
@adamltyson adamltyson requested review from adamltyson and removed request for a team May 23, 2024 13:33
cellfinder/__init__.py (review thread, outdated)
```diff
@@ -35,5 +35,5 @@ def test_train(tmpdir):
     sys.argv = train_args
     train_run()

-    model_file = os.path.join(tmpdir, "model.h5")
+    model_file = os.path.join(tmpdir, "model.keras")
```
Member:
is this the new default extension? This should be added to brainglobe/brainglobe.github.io#189 (I think there's a couple of places in the docs that reference the .h5 files directly).

@IgorTatarnikov (Member, Author):
If we're saving the whole model then we have to use the .keras extension. For weights, the files must now end with .weights.h5. I'll double check for references to .h5 in our documentation.
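The naming rules above can be sketched as a small helper (the helper name is hypothetical, not part of cellfinder or Keras):

```python
def checkpoint_filename(stem: str, weights_only: bool = False) -> str:
    """Keras 3 rejects other suffixes: whole models must end in `.keras`,
    weights-only checkpoints in `.weights.h5`."""
    return f"{stem}.weights.h5" if weights_only else f"{stem}.keras"

print(checkpoint_filename("model"))        # model.keras
print(checkpoint_filename("model", True))  # model.weights.h5
```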

@adamltyson (Member)

@IgorTatarnikov I compared this PR with version 1.2.0 and the classification looks considerably worse. This was using the napari plugin, all parameters default, using the pre-trained model.

The only other difference was the Python version (the old cellfinder env is on 3.10, the torch env on 3.12), but that shouldn't affect this.

Do you know what's going on?

TensorFlow
[screenshot of classification results]

Torch
[screenshot of classification results]

@adamltyson (Member)

Also, if I try to run pytest locally, I get a huge stacktrace that leads to a keras error: ModuleNotFoundError: No module named 'tensorflow'

@adamltyson (Member)

The classification problem seems to be resolved on that machine by deleting and re-downloading the model, not sure what's going on.

Also, if I try to run pytest locally, I get a huge stacktrace that leads to a keras error: ModuleNotFoundError: No module named 'tensorflow'

This issue remains; the backend needs to be set explicitly by editing the ~/.keras/keras.json file (I needed to do this on two machines). Can we override this?
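For reference, that file selects the backend via its `backend` key; a sketch of its shape (the other keys shown are typical Keras defaults, not taken from this thread):

```python
import json

# Shape of ~/.keras/keras.json. Keras consults this file at import time
# when the KERAS_BACKEND environment variable is not set.
keras_config = {
    "backend": "torch",
    "floatx": "float32",
    "image_data_format": "channels_last",
}
print(json.dumps(keras_config, indent=2))
```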

@adamltyson (Member)

On my mac (m1), I get this error:

E NotImplementedError: Exception encountered when calling MaxPooling3D.call().
E
E The operator 'aten::max_pool3d_with_indices' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on pytorch/pytorch#77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

I can fix it for detection by setting PYTORCH_ENABLE_MPS_FALLBACK=1, however this isn't ideal because a) others will just see a big red error, and b) presumably we want to use MPS?
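The temporary fix from the error message, sketched; the variable must be set before torch first touches the MPS device:

```python
import os

# Fall back to CPU for ops MPS does not implement (e.g.
# aten::max_pool3d_with_indices); slower than native MPS, but it runs.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
```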

Also, even then, upon training, I get this error:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
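One way to avoid this (matching the later commit that standardises inputs to float32) is to down-cast arrays before they reach the model; a minimal NumPy sketch:

```python
import numpy as np

# NumPy creates float64 arrays by default, but MPS supports at most
# float32, so cast inputs before handing them to the model.
plane = np.zeros((64, 64))                    # float64 by default
plane = plane.astype(np.float32, copy=False)
print(plane.dtype)  # float32
```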

@adamltyson (Member)

After some more testing, everything seems to work fine on my mac, the only issue is automatically setting the backend to torch.

@adamltyson (Member) left a review:
This is great, thanks @IgorTatarnikov @sfmig!

I've tested it on macOS, Ubuntu & Windows now. It works fine across all OSs, but it does seem to be very slow on Windows. I think this is out of scope for this review, so I've raised #426 so we can look into it later on.

I think the only remaining thing is setting the backend automatically. If this can be sorted, we can create an rc and get some more people to test upgrading.

@IgorTatarnikov (Member, Author)

I'm not sure why the automatic backend setting isn't working as it stands. I'll remove the conditional and always set the KERAS_BACKEND variable to torch when the package is imported. This removes some flexibility, but we weren't officially planning to support custom backends, so anyone who needs one may have to edit the package code locally, or fork it.
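The change described might look roughly like this at the top of `cellfinder/__init__.py` (a sketch, not the merged code):

```python
import os

# Force the torch backend. KERAS_BACKEND takes precedence over
# ~/.keras/keras.json, but only if it is set before `import keras` runs.
os.environ["KERAS_BACKEND"] = "torch"
```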

@adamltyson (Member)

Is it worth getting rid of the other code in the init file now? There's not a lot of point setting the backend to torch, then checking whether it's TF or JAX.

@adamltyson (Member)

I think this is ready to merge then @IgorTatarnikov unless there's anything else you want to do?

@IgorTatarnikov (Member, Author)

I think it's ready! Once this is merged we can merge the corresponding PR in brainglobe-workflows to fix the brainmapper tests.

@IgorTatarnikov IgorTatarnikov merged commit cbdecaf into main May 28, 2024
16 of 18 checks passed
Successfully merging this pull request may close these issues:

  • Consider moving away from TensorFlow