diff --git a/fx2ait/fx2ait/csrc/AITModel.cpp b/fx2ait/fx2ait/csrc/AITModel.cpp index eed7855c4..510bbb534 100644 --- a/fx2ait/fx2ait/csrc/AITModel.cpp +++ b/fx2ait/fx2ait/csrc/AITModel.cpp @@ -62,8 +62,8 @@ static auto registerAITModel = std::string, std::vector, std::vector, - c10::optional, - c10::optional, + std::optional, + std::optional, int64_t>()) .def("forward", &AITModel::forward) .def("profile", &AITModel::profile) diff --git a/fx2ait/fx2ait/csrc/AITModel.h b/fx2ait/fx2ait/csrc/AITModel.h index 4780b7ed1..01efa1300 100644 --- a/fx2ait/fx2ait/csrc/AITModel.h +++ b/fx2ait/fx2ait/csrc/AITModel.h @@ -25,8 +25,8 @@ class AITModel : public torch::CustomClassHolder { const std::string& model_path, std::vector input_names, std::vector output_names, - c10::optional input_dtype, - c10::optional output_dtype, + std::optional input_dtype, + std::optional output_dtype, int64_t num_runtimes = 2, bool use_cuda_graph = false) : aitModelImpl_( diff --git a/fx2ait/fx2ait/csrc/AITModelImpl.cpp b/fx2ait/fx2ait/csrc/AITModelImpl.cpp index 739e1935e..74586cb6e 100644 --- a/fx2ait/fx2ait/csrc/AITModelImpl.cpp +++ b/fx2ait/fx2ait/csrc/AITModelImpl.cpp @@ -133,8 +133,8 @@ AITModelImpl::AITModelImpl( const std::string& model_path, std::vector input_names, std::vector output_names, - c10::optional input_dtype, - c10::optional output_dtype, + std::optional input_dtype, + std::optional output_dtype, int64_t num_runtimes, bool use_cuda_graph) : handle_(dlopen(model_path.c_str(), RTLD_NOW | RTLD_LOCAL)), diff --git a/fx2ait/fx2ait/csrc/AITModelImpl.h b/fx2ait/fx2ait/csrc/AITModelImpl.h index 063808b83..196c39ae9 100644 --- a/fx2ait/fx2ait/csrc/AITModelImpl.h +++ b/fx2ait/fx2ait/csrc/AITModelImpl.h @@ -43,8 +43,8 @@ class AITModelImpl { const std::string& model_path, std::vector input_names, std::vector output_names, - c10::optional input_dtype, - c10::optional output_dtype, + std::optional input_dtype, + std::optional output_dtype, int64_t num_runtimes = 2, bool use_cuda_graph = false); @@ -102,11 +102,11 @@ class AITModelImpl { return output_names_; } - const c10::optional floatingPointInputDtype() const { + const std::optional floatingPointInputDtype() const { return floating_point_input_dtype_; } - const c10::optional floatingPointOutputDtype() const { + const std::optional floatingPointOutputDtype() const { return floating_point_output_dtype_; } @@ -182,8 +182,8 @@ class AITModelImpl { const std::string library_path_; const std::vector input_names_; const std::vector output_names_; - const c10::optional floating_point_input_dtype_; - const c10::optional floating_point_output_dtype_; + const std::optional floating_point_input_dtype_; + const std::optional floating_point_output_dtype_; #ifdef FBCODE_AIT folly::F14FastMap input_name_to_index_; folly::F14FastMap output_name_to_index_;