Skip to content

Commit

Permalink
dynamo benchmarking
Browse files Browse the repository at this point in the history
ghstack-source-id: 902145d48118591188d0f0b72e1043152e8a6d2f
Pull Request resolved: #265
  • Loading branch information
PaliC committed Nov 11, 2022
1 parent 22ff1ca commit d6f3f18
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 24 deletions.
15 changes: 12 additions & 3 deletions multipy/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ set(INTERPRETER_TEST_SOURCES
set(INTERPRETER_TEST_SOURCES_GPU
${DEPLOY_DIR}/test_deploy_gpu.cpp
)

# TODO: Currently tests can only be done when ABI=1 as the testing infrustructure
# used by ASSERT_TRUE requires ABI=1 in Github actions, we should fix this!
set(INTERPRETER_TEST_SOURCES_COMPAT
${DEPLOY_DIR}/test_deploy_compat.cpp
)

add_executable(test_deploy ${INTERPRETER_TEST_SOURCES})
# target_compile_definitions(test_deploy PUBLIC TEST_CUSTOM_LIBRARY)
Expand All @@ -99,6 +99,15 @@ target_link_libraries(test_deploy
)
target_include_directories(test_deploy PRIVATE ${CMAKE_SOURCE_DIR}/../..)

add_executable(test_deploy_compat ${INTERPRETER_TEST_SOURCES_COMPAT})
# target_compile_definitions(test_deploy_compat PUBLIC TEST_CUSTOM_LIBRARY)
target_include_directories(test_deploy_compat PRIVATE ${PYTORCH_ROOT}/torch)
target_link_libraries(test_deploy_compat
PUBLIC "-Wl,--no-as-needed -rdynamic" gtest dl torch_deploy_interface c10 torch_cpu
)
target_include_directories(test_deploy_compat PRIVATE ${CMAKE_SOURCE_DIR}/../..)


if(BUILD_CUDA_TESTS)
LINK_DIRECTORIES("${PYTORCH_ROOT}/torch/lib")
add_executable(test_deploy_gpu ${INTERPRETER_TEST_SOURCES_GPU})
Expand Down
15 changes: 12 additions & 3 deletions multipy/runtime/example/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ struct Benchmark {
manager.debugLimitInterpreters(1);
} else if (strategy == "multi_python") {
manager.debugLimitInterpreters(n_threads_);
} else if (strategy == "dynamo+deploy"){
manager.debugLimitInterpreters(n_threads_);
}
}

Expand Down Expand Up @@ -295,6 +297,7 @@ int main(int argc, char* argv[]) {
cuda = std::string(argv[2]) == "cuda";
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool jit_enable = std::string(argv[3]) == "jit";
bool inductor_enable = std::string(argv[3]) == "inductor";
Report::report_header(std::cout);
torch::deploy::InterpreterManager manager(max_thread);

Expand All @@ -311,19 +314,25 @@ int main(int argc, char* argv[]) {
if (n_thread > max_thread) {
continue;
}
for (std::string strategy : {"one_python", "multi_python", "jit"}) {
for (std::string strategy : {"one_python", "multi_python", "jit", "dynamo+deploy"}) {
if (strategy == "jit") {
if (!jit_enable) {
continue;
}
if (!exists(model_file + "_jit")) {
continue;
}
}
if (strategy == "one_python") {
} else if (strategy == "one_python") {
Benchmark b(manager, 1, strategy, model_file);
Report r = b.run();
r.report(std::cout);
} else if (strategy == "dynamo+deploy") {
if (!exists(model_file + "_dynamo")) {
continue;
}
Benchmark b(manager, n_thread, strategy, model_file + "_dynamo");
Report r = b.run();
r.report(std::cout);
} else {
Benchmark b(manager, n_thread, strategy, model_file);
Report r = b.run();
Expand Down
5 changes: 5 additions & 0 deletions multipy/runtime/example/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.nn as nn
from torch import Tensor
import torch._dynamo as torchdynamo


class Simple(torch.nn.Module):
Expand Down Expand Up @@ -128,6 +129,10 @@ def forward(self, x):
def resnet18():
return ResNet(BasicBlock, [2, 2, 2, 2])

@torchdynamo.optimize("ts")
def resnet18_dynamo():
return resnet18();


class BatchedModel(nn.Module):
def forward(self, input1: Tensor, input2: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down
45 changes: 27 additions & 18 deletions multipy/runtime/example/generate_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
multi_return_metadata,
MultiReturn,
resnet18,
resnet18_dynamo,
Simple,
)
except ImportError:
Expand All @@ -30,6 +31,7 @@
multi_return_metadata,
MultiReturn,
resnet18,
resnet18_dynamo,
Simple,
)

Expand Down Expand Up @@ -60,34 +62,39 @@ def save(
name,
model,
model_jit=None,
model_dynamo=None,
eg=None,
featurestore_meta=None,
text_in_extra_file=None,
binary_in_extra_file=None,
):
with PackageExporter(str(p / name)) as e:
e.mock("iopath.**")
e.intern("**")
e.save_pickle("model", "model.pkl", model)
if eg:
e.save_pickle("model", "example.pkl", eg)
if featurestore_meta:
# TODO(whc) can this name come from buck somehow,
# so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
e.save_text("extra_files", "metadata.json", featurestore_meta)
if text_in_extra_file:
e.save_text("extra_files", "text", text_in_extra_file)
if binary_in_extra_file:
e.save_binary("extra_files", "binary", binary_in_extra_file)

def package_model(name, model):
with PackageExporter(str(p / name)) as e:
e.mock("iopath.**")
e.intern("**")
e.save_pickle("model", "model.pkl", model)
if eg:
e.save_pickle("model", "example.pkl", eg)
if featurestore_meta:
# TODO(whc) can this name come from buck somehow,
# so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
e.save_text("extra_files", "metadata.json", featurestore_meta)
if text_in_extra_file:
e.save_text("extra_files", "text", text_in_extra_file)
if binary_in_extra_file:
e.save_binary("extra_files", "binary", binary_in_extra_file)

package_model(name, model)
if model_dynamo:
package_model(name + "_dynamo", model_dynamo)
if model_jit:
model_jit.save(str(p / (name + "_jit")))



parser = argparse.ArgumentParser(description="Generate Examples")
parser.add_argument("--install_dir", help="Root directory for all output files")


if __name__ == "__main__":
args = parser.parse_args()
if args.install_dir is None:
Expand All @@ -98,9 +105,10 @@ def save(

resnet = resnet18()
resnet.eval()
resnet_dynamo = resnet18_dynamo()
resnet_eg = torch.rand(1, 3, 224, 224)
resnet_traced = torch.jit.trace(resnet, resnet_eg)
save("resnet", resnet, resnet_traced, (resnet_eg,))
save("resnet", resnet, resnet_traced, resnet_dynamo, (resnet_eg,))

simple = Simple(10, 20)
save(
Expand All @@ -117,6 +125,7 @@ def save(
"multi_return",
multi_return,
torch.jit.script(multi_return),
None,
(torch.rand(10, 20),),
multi_return_metadata,
)
Expand Down Expand Up @@ -149,4 +158,4 @@ def save(
e.add_dependency("tensorrt")
e.mock("iopath.**")
e.intern("**")
e.save_pickle("make_trt_module", "model.pkl", make_trt_module)
e.save_pickle("make_trt_module", "model.pkl", make_trt_module)
105 changes: 105 additions & 0 deletions multipy/runtime/test_deploy_compat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/Parallel.h>
#include <gtest/gtest.h>
#include <libgen.h>
#include <cstring>

#include <c10/util/irange.h>
#include <libgen.h>
#include <multipy/runtime/deploy.h>
#include <torch/script.h>
#include <torch/torch.h>

#include <future>
#include <iostream>
#include <string>

void compare_torchpy_jit(const char* model_filename, const char* jit_filename) {
// Test

torch::deploy::InterpreterManager m(2);
torch::deploy::Package p = m.loadPackage(model_filename);
auto model = p.loadPickle("model", "model.pkl");
at::IValue eg;
{
auto I = p.acquireSession();
eg = I.self.attr("load_pickle")({"model", "example.pkl"}).toIValue();
}

at::Tensor output = model(eg.toTupleRef().elements()).toTensor();

// Reference
auto ref_model = torch::jit::load(jit_filename);
at::Tensor ref_output =
ref_model.forward(eg.toTupleRef().elements()).toTensor();
ASSERT_TRUE(ref_output.allclose(output, 1e-03, 1e-05));
}

const char* resnet_path = "multipy/runtime/example/generated/resnet_dynamo";
const char* resnet_jit_path = "multipy/runtime/example/generated/resnet_jit";

const char* path(const char* envname, const char* path) {
const char* e = getenv(envname);
return e ? e : path;
}

TEST(TorchpyTest, ResNetWithDynamo) {
compare_torchpy_jit(
path("RESNET", resnet_path),
path("RESNET_JIT", resnet_jit_path));
}

TEST(TorchpyTest, ThreadedResnetModelWithDynamo) {
size_t nthreads = 3;
torch::deploy::InterpreterManager manager(nthreads);

torch::deploy::Package p = manager.loadPackage(path("RESNET", resnet_path));
auto model = p.loadPickle("model", "model.pkl");
auto ref_model = torch::jit::load(path("RESNET_JIT", resnet_jit_path));

auto input = torch::ones({10, 20});

std::vector<at::Tensor> outputs;

std::vector<std::future<at::Tensor>> futures;
for (const auto i : c10::irange(nthreads)) {
(void)i;
futures.push_back(std::async(std::launch::async, [&model]() {
auto input = torch::ones({10, 10, 10});
for (const auto j : c10::irange(100)) {
(void)j;
model({input.alias()}).toTensor();
}
auto result = model({input.alias()}).toTensor();
return result;
}));
}
for (const auto i : c10::irange(nthreads)) {
outputs.push_back(futures[i].get());
}

// Generate reference
auto ref_output = ref_model.forward({input.alias()}).toTensor();

// Compare all to reference
for (const auto i : c10::irange(nthreads)) {
ASSERT_TRUE(ref_output.equal(outputs[i]));
}
}

int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
char tempeh[256];
getcwd(tempeh, 256);
std::cout << "Current working directory: " << tempeh << std::endl;
int rc = RUN_ALL_TESTS();
char tmp[256];
getcwd(tmp, 256);
std::cout << "Current working directory: " << tmp << std::endl;
return rc;
}

0 comments on commit d6f3f18

Please sign in to comment.