Skip to content

Commit

Permalink
fix jax cuda version in publish.yml again
Browse files Browse the repository at this point in the history
  • Loading branch information
nshepperd committed Feb 21, 2024
1 parent 92f9691 commit 9140339
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ jobs:

- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_MAJOR=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_JAX_VERSION=$(echo ${{ matrix.jax-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
Expand Down Expand Up @@ -96,7 +97,7 @@ jobs:
- name: Install Jax ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
pip install --upgrade "jax[cuda${MATRIX_CUDA_VERSION%.*}_local] == ${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "jax[cuda${MATRIX_CUDA_MAJOR}_local] == ${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
shell:
bash

Expand Down

0 comments on commit 9140339

Please sign in to comment.