
Commit 67031f7
add cuda_graph for mts_gpu_benchmark
Summary:
Add cuda_graph enablement for AIT. Also leave the hook for AOTI (AOTInductor).

GPU trace after enabling cudagraph: https://fburl.com/perfdoctor/yybf0z60

Log output for verification:
I0710 17:58:44.425707 2013974 AITModelImpl.cpp:148] AITModelImpl: loading .so lib /tmp/benchmark_529602.1720659401/_run_on_acc_0/_run_on_acc_0-ait_engine.so
I0710 17:58:44.425731 2013974 AITModelImpl.cpp:149] AITModelImpl: num_runtimes: 1,use_cuda_graph: 1
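The `use_cuda_graph: 1` in the log above confirms the flag reached the runtime. For context on what it buys: CUDA Graphs record a sequence of kernel launches once and replay it as a single launch, removing per-kernel launch overhead on launch-bound models. Below is a minimal sketch of the standard PyTorch capture/replay pattern — illustrative only, not the AIT runtime's internal code; `run_engine` is a hypothetical stand-in for the compiled engine call.

import torch

def run_engine(x):
    # Hypothetical stand-in for the compiled AIT engine call.
    return x.relu() + 1.0

static_in = torch.randn(8, 16, device="cuda")

# Warm up on a side stream so capture starts from a quiet allocator state.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = run_engine(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture the whole launch sequence once...
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = run_engine(static_in)

# ...then replay it each iteration: one graph launch instead of one
# launch per kernel. New inputs are fed by copying into the static buffer.
static_in.copy_(torch.randn(8, 16, device="cuda"))
g.replay()  # static_out now holds the outputs for the copied-in inputs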

Reviewed By: guowentian

Differential Revision: D59617284
frank-wei authored and facebook-github-bot committed Jul 12, 2024
1 parent a1769f2 commit 67031f7
Showing 6 changed files with 13 additions and 1 deletion.
1 change: 1 addition & 0 deletions fx2ait/fx2ait/ait_splitter.py
@@ -218,6 +218,7 @@ def _lower_model_to_backend(
                 torch.float16,
                 torch.float,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interpreter_result,
         )
4 changes: 3 additions & 1 deletion fx2ait/fx2ait/csrc/AITModelImpl.cpp
@@ -145,7 +145,9 @@ AITModelImpl::AITModelImpl(
       floating_point_input_dtype_(input_dtype),
       floating_point_output_dtype_(output_dtype),
       use_cuda_graph_(use_cuda_graph) {
-  LOG(INFO) << "Loading .so lib " << model_path;
+  LOG(INFO) << "AITModelImpl: loading .so lib " << model_path;
+  LOG(INFO) << "AITModelImpl: num_runtimes: " << num_runtimes
+            << ",use_cuda_graph: " << use_cuda_graph;
   TORCH_CHECK(handle_, "could not dlopen ", model_path, ": ", dlerror());
   TORCH_CHECK(num_runtimes > 0, "num_runtimes must be positive");
2 changes: 2 additions & 0 deletions fx2ait/fx2ait/test/test_fx2ait.py
@@ -68,6 +68,7 @@ def _test_fx2ait_impl(self, test_serialization=False, test_cuda_graph=False):
                 torch.float16,
                 torch.float16,
                 1, # num_runtimes
+                False, # use_cuda_graph
             )
         )
         ait_mod.engine.use_cuda_graph = test_cuda_graph
@@ -140,6 +141,7 @@ def forward(self, a, b, c, d):
                 torch.float16,
                 torch.float16,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interp_result,
         )
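The test above shows the two knobs this commit wires up: the new trailing constructor argument (every call site now passes False) and the mutable engine attribute flipped after construction. A hedged usage sketch follows — the arguments ahead of the dtypes (engine path, input/output names), the AITModule import path, and the torch.classes.fx2ait.AITModel registration name are assumed from the surrounding call sites, not confirmed by this diff.

import torch
from fx2ait.ait_module import AITModule  # assumed import path

ait_mod = AITModule(
    torch.classes.fx2ait.AITModel(
        "/tmp/ait_engine.so",  # hypothetical engine path
        ["x"],                 # hypothetical input names
        ["y"],                 # hypothetical output names
        torch.float16,
        torch.float16,
        1,      # num_runtimes
        False,  # use_cuda_graph (the argument this commit adds)
    )
)

# Or flip it after loading, as the test does:
ait_mod.engine.use_cuda_graph = True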
1 change: 1 addition & 0 deletions fx2ait/fx2ait/tools/ait_minimizer.py
@@ -42,6 +42,7 @@ def lower_mod_default(
                 torch.float16,
                 torch.float16,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interpreter_result,
         )
3 changes: 3 additions & 0 deletions fx2ait/fx2ait/tools/common_aten2ait.py
@@ -163,6 +163,7 @@ def run_test(
                 torch.float16,
                 torch.float,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interp_result,
         )
@@ -256,6 +257,7 @@ def run_test_with_dynamic_shape(
                 torch.float16,
                 torch.float,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interp_result,
         )
@@ -375,6 +377,7 @@ def benchmark(f, args):
                 torch.float16,
                 torch.float,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interp_result,
         )
3 changes: 3 additions & 0 deletions fx2ait/fx2ait/tools/common_fx2ait.py
@@ -199,6 +199,7 @@ def run_test(
                 torch_dtype,
                 torch.float,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interp_result,
         )
@@ -329,6 +330,7 @@ def run_test_with_dynamic_shape(
                 torch.float16,
                 torch.float,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interp_result,
         )
@@ -467,6 +469,7 @@ def benchmark(f, args):
                 torch.float16,
                 torch.float,
                 1, # num_runtimes
+                False, # use_cuda_graph
             ),
             interp_result,
         )
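Since the target is mts_gpu_benchmark, the expected win is lower launch overhead per iteration. Below is a self-contained timing sketch with CUDA events — generic PyTorch, not the benchmark harness itself; `step` is a hypothetical launch-bound workload.

import torch

def step(x):
    # Hypothetical launch-bound workload: many tiny kernels.
    for _ in range(50):
        x = x * 1.0001 + 0.1
    return x

def ms_per_iter(fn, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(1024, device="cuda")
eager_ms = ms_per_iter(lambda: step(x))  # 50 launches per iteration

# Capture once, then time single-launch replays.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    step(x)  # warm-up before capture
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = step(x)
graph_ms = ms_per_iter(g.replay)

print(f"eager: {eager_ms:.3f} ms/iter  cuda_graph: {graph_ms:.3f} ms/iter")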
