I have been using `TPUStrategy` with the `tensorflow` backend to train on TPUs with Keras 2, and I am now upgrading to Keras 3. Could a comprehensive example be provided that uses the `jax` backend for multi-host, multi-TPU training? Ideally something that scales to `v4-256` and above.
I am primarily interested in data parallelism. I have been using the runtime `tpu-vm-tf-2.15.0-pod-pjrt`, and I have no idea what the equivalent runtime is for the `jax` backend. Also, how does the user VM, which may be distinct from the TPU hosts, connect?
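For context, here is my rough guess at what the JAX-backend equivalent might look like, based on the Keras 3 distribution API. This is an untested sketch, and the runtime/topology details are exactly what I am unsure about:

```python
import os

# The Keras backend must be selected before keras is imported.
os.environ["KERAS_BACKEND"] = "jax"


def setup_data_parallel():
    """Sketch: multi-host data-parallel setup with the jax backend.

    Presumably this has to run on every TPU host in the pod slice
    (e.g. launched on all workers via gcloud's --worker=all).
    """
    import jax
    import keras

    # On Cloud TPU, jax.distributed.initialize() auto-detects the
    # coordinator address, process count, and process index.
    jax.distributed.initialize()

    # Replicate the model and shard the data across all devices
    # on all hosts (data parallelism).
    distribution = keras.distribution.DataParallel(devices=jax.devices())
    keras.distribution.set_distribution(distribution)
    return distribution
```

If this is roughly right, it would be great to have confirmation plus the matching TPU runtime name.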
Currently that is all handled by code such as:
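(The snippet did not survive here; what I mean is the standard Keras 2 / TensorFlow TPU initialization sequence, along these lines, where the resolver argument depends on how the TPU is addressed:)

```python
def create_tpu_strategy(tpu_name=""):
    """Sketch of the standard Keras 2 / TensorFlow TPU setup.

    tpu_name is "" when running directly on a TPU VM; from a separate
    user VM it would be the TPU's name or gRPC address (an assumption,
    since it depends on the deployment).
    """
    import tensorflow as tf

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)
```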