From d6f3f181b2361e8c1297f99553ab085882812893 Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 11 Nov 2022 22:38:32 +0000 Subject: [PATCH] dynamo benchmarking ghstack-source-id: 902145d48118591188d0f0b72e1043152e8a6d2f Pull Request resolved: https://github.com/pytorch/multipy/pull/265 --- multipy/runtime/CMakeLists.txt | 15 ++- multipy/runtime/example/benchmark.cpp | 15 ++- multipy/runtime/example/examples.py | 5 + multipy/runtime/example/generate_examples.py | 45 ++++---- multipy/runtime/test_deploy_compat.cpp | 105 +++++++++++++++++++ 5 files changed, 161 insertions(+), 24 deletions(-) create mode 100644 multipy/runtime/test_deploy_compat.cpp diff --git a/multipy/runtime/CMakeLists.txt b/multipy/runtime/CMakeLists.txt index f4df3ef0..dedf6498 100644 --- a/multipy/runtime/CMakeLists.txt +++ b/multipy/runtime/CMakeLists.txt @@ -87,9 +87,9 @@ set(INTERPRETER_TEST_SOURCES set(INTERPRETER_TEST_SOURCES_GPU ${DEPLOY_DIR}/test_deploy_gpu.cpp ) - -# TODO: Currently tests can only be done when ABI=1 as the testing infrustructure -# used by ASSERT_TRUE requires ABI=1 in Github actions, we should fix this! +set(INTERPRETER_TEST_SOURCES_COMPAT + ${DEPLOY_DIR}/test_deploy_compat.cpp +) add_executable(test_deploy ${INTERPRETER_TEST_SOURCES}) # target_compile_definitions(test_deploy PUBLIC TEST_CUSTOM_LIBRARY) @@ -99,6 +99,15 @@ target_link_libraries(test_deploy ) target_include_directories(test_deploy PRIVATE ${CMAKE_SOURCE_DIR}/../..) +add_executable(test_deploy_compat ${INTERPRETER_TEST_SOURCES_COMPAT}) +# target_compile_definitions(test_deploy_compat PUBLIC TEST_CUSTOM_LIBRARY) +target_include_directories(test_deploy_compat PRIVATE ${PYTORCH_ROOT}/torch) +target_link_libraries(test_deploy_compat + PUBLIC "-Wl,--no-as-needed -rdynamic" gtest dl torch_deploy_interface c10 torch_cpu +) +target_include_directories(test_deploy_compat PRIVATE ${CMAKE_SOURCE_DIR}/../..) 
+
+
 if(BUILD_CUDA_TESTS)
   LINK_DIRECTORIES("${PYTORCH_ROOT}/torch/lib")
   add_executable(test_deploy_gpu ${INTERPRETER_TEST_SOURCES_GPU})
diff --git a/multipy/runtime/example/benchmark.cpp b/multipy/runtime/example/benchmark.cpp
index b8ec9755..5dae3a4d 100644
--- a/multipy/runtime/example/benchmark.cpp
+++ b/multipy/runtime/example/benchmark.cpp
@@ -178,6 +178,8 @@ struct Benchmark {
       manager.debugLimitInterpreters(1);
     } else if (strategy == "multi_python") {
       manager.debugLimitInterpreters(n_threads_);
+    } else if (strategy == "dynamo+deploy") {
+      manager.debugLimitInterpreters(n_threads_);
     }
   }
 
@@ -295,6 +297,7 @@ int main(int argc, char* argv[]) {
     cuda = std::string(argv[2]) == "cuda";
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   bool jit_enable = std::string(argv[3]) == "jit";
+  bool inductor_enable = std::string(argv[3]) == "inductor";
 
   Report::report_header(std::cout);
   torch::deploy::InterpreterManager manager(max_thread);
@@ -311,7 +314,7 @@ int main(int argc, char* argv[]) {
       if (n_thread > max_thread) {
         continue;
       }
-      for (std::string strategy : {"one_python", "multi_python", "jit"}) {
+      for (std::string strategy : {"one_python", "multi_python", "jit", "dynamo+deploy"}) {
         if (strategy == "jit") {
           if (!jit_enable) {
             continue;
@@ -319,11 +322,17 @@
           if (!exists(model_file + "_jit")) {
             continue;
           }
-      }
-      if (strategy == "one_python") {
+      } else if (strategy == "one_python") {
         Benchmark b(manager, 1, strategy, model_file);
         Report r = b.run();
         r.report(std::cout);
+      } else if (strategy == "dynamo+deploy") {
+        if (!exists(model_file + "_dynamo")) {
+          continue;
+        }
+        Benchmark b(manager, n_thread, strategy, model_file + "_dynamo");
+        Report r = b.run();
+        r.report(std::cout);
       } else {
         Benchmark b(manager, n_thread, strategy, model_file);
         Report r = b.run();
diff --git a/multipy/runtime/example/examples.py b/multipy/runtime/example/examples.py
index 46cc8d91..316da962 100644
--- a/multipy/runtime/example/examples.py
+++ b/multipy/runtime/example/examples.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
+import torch._dynamo as torchdynamo
 
 
 class Simple(torch.nn.Module):
@@ -128,6 +129,10 @@ def forward(self, x):
 def resnet18():
     return ResNet(BasicBlock, [2, 2, 2, 2])
 
+@torchdynamo.optimize("ts")
+def resnet18_dynamo():
+    return resnet18()
+
 
 class BatchedModel(nn.Module):
     def forward(self, input1: Tensor, input2: Tensor) -> Tuple[Tensor, Tensor]:
diff --git a/multipy/runtime/example/generate_examples.py b/multipy/runtime/example/generate_examples.py
index f947f3e1..05c27100 100644
--- a/multipy/runtime/example/generate_examples.py
+++ b/multipy/runtime/example/generate_examples.py
@@ -21,6 +21,7 @@
         multi_return_metadata,
         MultiReturn,
         resnet18,
+        resnet18_dynamo,
         Simple,
     )
 except ImportError:
@@ -30,6 +31,7 @@
         multi_return_metadata,
         MultiReturn,
         resnet18,
+        resnet18_dynamo,
         Simple,
     )
 
@@ -60,34 +62,39 @@ def save(
     name,
     model,
     model_jit=None,
+    model_dynamo=None,
     eg=None,
     featurestore_meta=None,
     text_in_extra_file=None,
     binary_in_extra_file=None,
 ):
-    with PackageExporter(str(p / name)) as e:
-        e.mock("iopath.**")
-        e.intern("**")
-        e.save_pickle("model", "model.pkl", model)
-        if eg:
-            e.save_pickle("model", "example.pkl", eg)
-        if featurestore_meta:
-            # TODO(whc) can this name come from buck somehow,
-            # so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
- e.save_text("extra_files", "metadata.json", featurestore_meta) - if text_in_extra_file: - e.save_text("extra_files", "text", text_in_extra_file) - if binary_in_extra_file: - e.save_binary("extra_files", "binary", binary_in_extra_file) - + def package_model(name, model): + with PackageExporter(str(p / name)) as e: + e.mock("iopath.**") + e.intern("**") + e.save_pickle("model", "model.pkl", model) + if eg: + e.save_pickle("model", "example.pkl", eg) + if featurestore_meta: + # TODO(whc) can this name come from buck somehow, + # so it's consistent with predictor_config_constants::METADATA_FILE_NAME()? + e.save_text("extra_files", "metadata.json", featurestore_meta) + if text_in_extra_file: + e.save_text("extra_files", "text", text_in_extra_file) + if binary_in_extra_file: + e.save_binary("extra_files", "binary", binary_in_extra_file) + + package_model(name, model) + if model_dynamo: + package_model(name + "_dynamo", model_dynamo) if model_jit: model_jit.save(str(p / (name + "_jit"))) + parser = argparse.ArgumentParser(description="Generate Examples") parser.add_argument("--install_dir", help="Root directory for all output files") - if __name__ == "__main__": args = parser.parse_args() if args.install_dir is None: @@ -98,9 +105,10 @@ def save( resnet = resnet18() resnet.eval() + resnet_dynamo = resnet18_dynamo() resnet_eg = torch.rand(1, 3, 224, 224) resnet_traced = torch.jit.trace(resnet, resnet_eg) - save("resnet", resnet, resnet_traced, (resnet_eg,)) + save("resnet", resnet, resnet_traced, resnet_dynamo, (resnet_eg,)) simple = Simple(10, 20) save( @@ -117,6 +125,7 @@ def save( "multi_return", multi_return, torch.jit.script(multi_return), + None, (torch.rand(10, 20),), multi_return_metadata, ) @@ -149,4 +158,4 @@ def save( e.add_dependency("tensorrt") e.mock("iopath.**") e.intern("**") - e.save_pickle("make_trt_module", "model.pkl", make_trt_module) + e.save_pickle("make_trt_module", "model.pkl", make_trt_module) \ No newline at end of file diff --git a/multipy/runtime/test_deploy_compat.cpp b/multipy/runtime/test_deploy_compat.cpp new file mode 100644 index 00000000..740577ea --- /dev/null +++ b/multipy/runtime/test_deploy_compat.cpp @@ -0,0 +1,105 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +void compare_torchpy_jit(const char* model_filename, const char* jit_filename) { + // Test + + torch::deploy::InterpreterManager m(2); + torch::deploy::Package p = m.loadPackage(model_filename); + auto model = p.loadPickle("model", "model.pkl"); + at::IValue eg; + { + auto I = p.acquireSession(); + eg = I.self.attr("load_pickle")({"model", "example.pkl"}).toIValue(); + } + + at::Tensor output = model(eg.toTupleRef().elements()).toTensor(); + + // Reference + auto ref_model = torch::jit::load(jit_filename); + at::Tensor ref_output = + ref_model.forward(eg.toTupleRef().elements()).toTensor(); + ASSERT_TRUE(ref_output.allclose(output, 1e-03, 1e-05)); +} + +const char* resnet_path = "multipy/runtime/example/generated/resnet_dynamo"; +const char* resnet_jit_path = "multipy/runtime/example/generated/resnet_jit"; + +const char* path(const char* envname, const char* path) { + const char* e = getenv(envname); + return e ? 
+}
+
+TEST(TorchpyTest, ResNetWithDynamo) {
+  compare_torchpy_jit(
+      path("RESNET", resnet_path),
+      path("RESNET_JIT", resnet_jit_path));
+}
+
+TEST(TorchpyTest, ThreadedResnetModelWithDynamo) {
+  size_t nthreads = 3;
+  torch::deploy::InterpreterManager manager(nthreads);
+
+  torch::deploy::Package p = manager.loadPackage(path("RESNET", resnet_path));
+  auto model = p.loadPickle("model", "model.pkl");
+  auto ref_model = torch::jit::load(path("RESNET_JIT", resnet_jit_path));
+
+  auto input = torch::ones({1, 3, 224, 224});
+
+  std::vector<at::Tensor> outputs;
+
+  std::vector<std::future<at::Tensor>> futures;
+  for (const auto i : c10::irange(nthreads)) {
+    (void)i;
+    futures.push_back(std::async(std::launch::async, [&model]() {
+      auto input = torch::ones({1, 3, 224, 224});
+      for (const auto j : c10::irange(100)) {
+        (void)j;
+        model({input.alias()}).toTensor();
+      }
+      auto result = model({input.alias()}).toTensor();
+      return result;
+    }));
+  }
+  for (const auto i : c10::irange(nthreads)) {
+    outputs.push_back(futures[i].get());
+  }
+
+  // Generate reference
+  auto ref_output = ref_model.forward({input.alias()}).toTensor();
+
+  // Compare all to reference
+  for (const auto i : c10::irange(nthreads)) {
+    ASSERT_TRUE(ref_output.equal(outputs[i]));
+  }
+}
+
+int main(int argc, char* argv[]) {
+  ::testing::InitGoogleTest(&argc, argv);
+  char tempeh[256];
+  getcwd(tempeh, sizeof(tempeh));
+  std::cout << "Current working directory: " << tempeh << std::endl;
+  int rc = RUN_ALL_TESTS();
+  char tmp[256];
+  getcwd(tmp, sizeof(tmp));
+  std::cout << "Current working directory: " << tmp << std::endl;
+  return rc;
+}
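
Note, not part of the patch: both new tests resolve their inputs through path(envname, default), so RESNET and RESNET_JIT can point them at packages generated elsewhere; otherwise the defaults under multipy/runtime/example/generated/ produced by generate_examples.py are used. The following is a minimal sketch of what producing the "_dynamo" package amounts to, mirroring examples.py and generate_examples.py above; the literal output path and the standalone import of resnet18 are illustrative assumptions (the real script derives the path from --install_dir).

    # Sketch only: package a dynamo-optimized resnet18 the way this patch does.
    import torch
    import torch._dynamo as torchdynamo
    from torch.package import PackageExporter

    from examples import resnet18  # assumed importable, as in generate_examples.py

    @torchdynamo.optimize("ts")  # TorchScript backend, as used in examples.py
    def resnet18_dynamo():
        return resnet18()

    model = resnet18_dynamo()
    eg = (torch.rand(1, 3, 224, 224),)  # same example input the script pickles

    # Mirrors save()/package_model() in generate_examples.py; path assumed.
    with PackageExporter("multipy/runtime/example/generated/resnet_dynamo") as e:
        e.mock("iopath.**")  # stub out iopath inside the package
        e.intern("**")       # package every other dependency
        e.save_pickle("model", "model.pkl", model)
        e.save_pickle("model", "example.pkl", eg)

With that package on disk, test_deploy_compat loads model.pkl through torch::deploy and checks it against the traced resnet_jit artifact within allclose tolerances.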