Remove FP64 from TBE CPU tests (pytorch#2049)
Summary:
Pull Request resolved: pytorch#2049

FP64 has not been supported in the TBE CPU path since D48895311, so remove the FP64 (double) casts from the TBE CPU tests.

Reviewed By: spcyppt

Differential Revision: D49746255

fbshipit-source-id: 4976cda0233b1d276b5ea628eff4cb9075943aff
sryap authored and facebook-github-bot committed Sep 29, 2023
1 parent 09d4f16 commit 32c969f
Showing 1 changed file with 2 additions and 17 deletions.
diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
--- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py
+++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -1453,10 +1453,6 @@ def test_backward_dense( # noqa C901
         )
 
         per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
-        if use_cpu:
-            # NOTE: GPU version of DenseTableBatchedEmbeddingBagsCodegen doesn't support double.
-            cc = cc.double()
-            per_sample_weights = per_sample_weights.double()
         per_sample_weights.requires_grad = True
         indices.requires_grad = False
         offsets.requires_grad = False
@@ -1494,13 +1490,8 @@ def test_backward_dense( # noqa C901
         )
 
         per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
-        if use_cpu:
-            # NOTE: GPU version of DenseTableBatchedEmbeddingBagsCodegen doesn't support double.
-            cc = cc.double()
-            per_sample_weights = per_sample_weights.double()
-        else:
-            cc = cc.float()
-            per_sample_weights = per_sample_weights.float()
+        cc = cc.float()
+        per_sample_weights = per_sample_weights.float()
         per_sample_weights.requires_grad = True
         indices.requires_grad = False
         offsets.requires_grad = False
@@ -2531,10 +2522,6 @@ def execute_backward_adagrad_( # noqa C901
             output_dtype=output_dtype,
         )
         per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
-        if use_cpu:
-            # NOTE: GPU version of SplitTableBatchedEmbeddingBagsCodegen doesn't support double.
-            cc = cc.double()
-            per_sample_weights = per_sample_weights.double()
         per_sample_weights.requires_grad = True
         indices.requires_grad = False
         offsets.requires_grad = False
@@ -2552,8 +2539,6 @@ def execute_backward_adagrad_( # noqa C901
         )
 
         per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
-        if use_cpu:
-            per_sample_weights = per_sample_weights.double()
         per_sample_weights.requires_grad = True
         indices.requires_grad = False
         offsets.requires_grad = False
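The net effect of the change: on CPU the tests now exercise the same FP32 path as on GPU, rather than upcasting the embedding module and per-sample weights to double. Below is a minimal sketch of that pattern in plain PyTorch, with torch.nn.EmbeddingBag standing in for the TBE codegen modules (the FBGEMM modules themselves are not constructed here; the sizes and inputs are illustrative, not from the test file):

import torch

# Stand-in for Dense/SplitTableBatchedEmbeddingBagsCodegen: a pooled
# embedding lookup whose output depends on per-sample weights.
cc = torch.nn.EmbeddingBag(10, 4, mode="sum", include_last_offset=True)
per_sample_weights = torch.rand(6)

# Old CPU-only branch (removed by this commit):
#     if use_cpu:
#         cc = cc.double()
#         per_sample_weights = per_sample_weights.double()
# New behavior: FP32 on CPU and GPU alike.
cc = cc.float()
per_sample_weights = per_sample_weights.float()
per_sample_weights.requires_grad = True

indices = torch.tensor([1, 2, 3, 4, 5, 6])  # 6 lookups
offsets = torch.tensor([0, 3, 6])           # 2 bags; last offset included
out = cc(indices, offsets, per_sample_weights=per_sample_weights)
out.sum().backward()
assert per_sample_weights.grad.dtype == torch.float32

In this sketch the backward pass stays in FP32 end to end, which is the dtype the updated tests now check gradients against.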
