Fix OSS test (pytorch#2750)
Summary: Pull Request resolved: pytorch#2750

Reviewed By: q10

Differential Revision: D58705747

fbshipit-source-id: bb29a0915ab7d31c4422f95c20ef62ac67edf50f
spcyppt authored and facebook-github-bot committed Jun 18, 2024
1 parent 9834c54 commit 278a510
Showing 3 changed files with 10 additions and 7 deletions.
fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu (9 changes: 6 additions & 3 deletions)

@@ -1283,7 +1283,8 @@ std::vector<at::Tensor> quantize_fp8_per_col(
 std::vector<at::Tensor> quantize_fp8_per_tensor(
     at::Tensor input,
     std::optional<at::Tensor> bs, // batch size
-    std::optional<at::Tensor> scale_ub) { // scale upperbound
+    std::optional<at::Tensor> scale_ub,
+    bool stochastic_rounding) { // scale upperbound
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
@@ -1292,15 +1293,17 @@
 std::vector<at::Tensor> quantize_fp8_per_row(
     at::Tensor input,
     std::optional<at::Tensor> bs, // batch size
     std::optional<at::Tensor> scale_ub, // scale upperbound
-    std::optional<c10::ScalarType> output_dtype) { // quantization type
+    std::optional<c10::ScalarType> output_dtype,
+    bool stochastic_rounding) { // quantization type
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
 at::Tensor quantize_fp8_per_tensor_fixed_scale(
     at::Tensor input,
     at::Tensor scale,
-    std::optional<at::Tensor> bs) { // batch size
+    std::optional<at::Tensor> bs,
+    bool stochastic_rounding) { // batch size
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
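These fallback definitions are the ones compiled when CUDA is older than 12.0, so their signatures have to stay in sync with the real kernels or the torch op registration no longer matches — the kind of drift the OSS test caught. A minimal sketch of a call after the change, assuming the op is exposed as torch.ops.fbgemm.quantize_fp8_per_tensor and returns the quantized tensor plus its scale (both assumptions, not shown in this diff):

import torch

# Hypothetical usage; the "fbgemm" namespace and the (quantized, scale)
# return layout are assumptions. On a CUDA >= 12 build this dispatches to
# the real kernel; on an older build it raises the RuntimeError from the
# stub above.
x = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
    x,
    None,   # bs: optional batch-size tensor
    None,   # scale_ub: optional scale upper bound
    False,  # stochastic_rounding: the parameter this commit threads through
)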
fbgemm_gpu/test/quantize/mx/common.py (6 changes: 3 additions & 3 deletions)
@@ -10,7 +10,7 @@

 import struct
 from enum import Enum, IntEnum
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -78,7 +78,7 @@ def _get_max_norm(ebits: int, mbits: int) -> float:


 def _get_format_params(  # noqa
-    fmt: ElemFormat | str | None,
+    fmt: Union[ElemFormat, str, None],
 ) -> Tuple[int, int, int, float, float]:
     """Allowed formats:
     - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation
@@ -546,7 +546,7 @@ def _quantize_elemwise_core(

 def _quantize_elemwise(
     A: torch.Tensor,
-    elem_format: ElemFormat | None,
+    elem_format: Union[ElemFormat, None],
     round: str = "nearest",
     custom_cuda: bool = False,
     saturate_normals: bool = False,
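Both hunks here are the same fix: PEP 604 union syntax (ElemFormat | str | None) is only valid at runtime on Python 3.10+, and annotations are evaluated when the def executes, so merely importing common.py crashed on the older interpreters in the OSS test matrix; typing.Union spells the same type on every supported version. A self-contained sketch of the failure mode (the function is illustrative, not from this file):

from typing import Union

# On Python 3.9 the commented-out def below fails at import time with
# "TypeError: unsupported operand type(s) for |", because "str | None"
# is evaluated as an expression when the function is defined:
#
#     def parse_fmt(fmt: str | None) -> int: ...
#
# The typing.Union spelling evaluates fine on 3.8/3.9 as well:
def parse_fmt(fmt: Union[str, None]) -> int:
    return int(fmt) if fmt is not None else 0

print(parse_fmt("8"), parse_fmt(None))  # 8 0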
fbgemm_gpu/test/quantize/mx4_test.py (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,6 @@
 import hypothesis.strategies as st

 import torch
-from deeplearning.fbgemm.fbgemm_gpu.test.quantize.mx.common import check_diff_quantize

 from hypothesis import given, settings, Verbosity

@@ -25,6 +24,7 @@
     _shared_exponents,
     _undo_reshape_to_blocks,
     all_encodings,
+    check_diff_quantize,
 )

 if open_source:
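This file carries the actual OSS breakage: deeplearning.fbgemm.fbgemm_gpu.… is an fbcode-internal path, so an open-source checkout could not import check_diff_quantize at all; moving it into the existing import block pulls it from mx.common the same way the other helpers already come in. The if open_source: guard visible at the end of the hunk follows the import-gating pattern common across fbgemm_gpu tests; a sketch of that pattern, with the helper names assumed for illustration rather than copied from this diff:

# Sketch of the open_source gating pattern seen in fbgemm_gpu tests;
# "test_utils" and "gpu_unavailable" are illustrative assumptions.
try:
    # An open-source build of fbgemm_gpu exposes an "open_source" flag.
    from fbgemm_gpu import open_source  # noqa: F401
except ImportError:
    open_source = False

if open_source:
    from test_utils import gpu_unavailable  # assumed local OSS helper
else:
    from fbgemm_gpu.test.test_utils import gpu_unavailable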
