Skip to content

Commit

Permalink
Fixes to matrix-testing scripts (#2730)
Browse files Browse the repository at this point in the history
Summary:
- Fix post-install variant checks for PyTorch PIP package installs

Pull Request resolved: #2730

Reviewed By: spcyppt

Differential Revision: D58503555

Pulled By: q10

fbshipit-source-id: 89186c0b0668286ccabe2f59cc8498e081d597b2
  • Loading branch information
q10 committed Jun 13, 2024
1 parent 6f3b3f8 commit 1dbcad9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 20 deletions.
52 changes: 34 additions & 18 deletions .github/scripts/utils_pip.bash
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,39 @@ __prepare_pip_arguments () {
__export_pip_arguments "$([ "$package_variant_type_version" != "" ] && echo "true" || echo "false")"
}

__check_package_variant () {
# shellcheck disable=SC2155
local env_prefix=$(env_name_or_prefix "${env_name}")

# Check applies to installation of packages with variants, and only to non-CPU variants
if [ "$package_variant_type_version" != "" ] && [ "$package_variant_type" != "cpu" ]; then
# Ensure that the package build is of the correct variant
# This test usually applies to the nightly builds
# shellcheck disable=SC2086
if conda run ${env_prefix} pip list | grep "${package_name_raw}" | grep "${package_variant}"; then
local check_passed=1
elif conda run ${env_prefix} pip list | grep "${package_name}" | grep "${package_variant}"; then
local check_passed=1
else
local check_passed=0
fi

if [ $check_passed -eq 1 ]; then
echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] is the correct variant (${package_variant})"
return 0
else
echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] appears to be an incorrect variant as it is missing references to ${package_variant}!"
echo "[CHECK] This can happen if the variant of the package (e.g. GPU, nightly) for the MAJOR.MINOR version of CUDA or ROCm presently installed on the system is not available."
return 1
fi
fi
}

install_from_pytorch_pip () {
local env_name="$1"
local package_name_raw="$2"
local package_channel_version="$3"
local package_variant_type_version="$4"
env_name="$1"
package_name_raw="$2"
package_channel_version="$3"
package_variant_type_version="$4"
if [ "$package_channel_version" == "" ]; then
echo "Usage: ${FUNCNAME[0]} ENV_NAME PACKAGE_NAME PACKAGE_CHANNEL[/VERSION] [PACKAGE_VARIANT_TYPE[/VARIANT_VERSION]]"
echo "Example(s):"
Expand Down Expand Up @@ -203,22 +231,10 @@ install_from_pytorch_pip () {
# shellcheck disable=SC2086
(exec_with_retries 3 conda run ${env_prefix} pip install ${pip_package} --index-url ${pip_channel}) || return 1

# Check applies to installation of packages with variants, and only to non-CPU variants
if [ "$package_variant_type_version" != "" ] && [ "$package_variant_type" != "cpu" ]; then
# Ensure that the package build is of the correct variant
# This test usually applies to the nightly builds
# shellcheck disable=SC2086
if conda run ${env_prefix} pip list | grep "${package_name}" | grep "${package_variant}"; then
echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] is the correct variant (${package_variant})"
else
echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] appears to be an incorrect variant as it is missing references to ${package_variant}!"
echo "[CHECK] This can happen if the variant of the package (e.g. GPU, nightly) for the MAJOR.MINOR version of CUDA or ROCm presently installed on the system is not available."
return 1
fi
fi
# Ensure that the correct package variant has been installed
__check_package_variant || return 1
}


################################################################################
# PyTorch PIP Download Functions
################################################################################
Expand Down
4 changes: 2 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def gqa_reference(

class Int4GQATest(unittest.TestCase):
@unittest.skipIf(
not torch.version.cuda,
"Skip when CUDA is not available",
not torch.version.cuda or torch.cuda.get_device_capability()[0] < 8,
"Skip when CUDA is not available or CUDA compute capability is less than 8",
)
@settings(verbosity=VERBOSITY, max_examples=40, deadline=None)
# pyre-ignore
Expand Down

0 comments on commit 1dbcad9

Please sign in to comment.