diff --git a/xdl/CMakeLists.txt b/xdl/CMakeLists.txt
index 639a207c..7d4df997 100644
--- a/xdl/CMakeLists.txt
+++ b/xdl/CMakeLists.txt
@@ -1,10 +1,12 @@
project(xdl)
-cmake_minimum_required(VERSION 2.8)
+cmake_minimum_required(VERSION 3.0.0 FATAL_ERROR)
+SET(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
include(${PROJECT_SOURCE_DIR}/cmake/Utils.cmake)
xdl_option(test "Build all tests." ON)
xdl_option(USE_GPU "use gpu" off)
+xdl_option(USE_PS_PLUS "use ps-plus" ON)
xdl_option(coverage "Generate coverage analysis" off)
enable_testing()
@@ -46,12 +48,21 @@ include_directories(
include(cmake/Dependencies.cmake)
+set(TBB_ROOT "${PROJECT_SOURCE_DIR}/third_party/tbb/")
+include(${TBB_ROOT}/cmake/TBBBuild.cmake)
+set(CMAKE_CXX_FLAGS "-std=c++1y ${CMAKE_CXX_FLAGS}")
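+# Build the bundled TBB at configure time; tbb_build populates TBB_DIR so the
+# find_package(TBB) call below can locate the generated TBB config files.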
+tbb_build(TBB_ROOT ${TBB_ROOT} CONFIG_DIR TBB_DIR MAKE_ARGS)
+find_package(TBB REQUIRED)
+include_directories("${TBB_ROOT}/include")
+
IF (USE_GPU)
+ set(CUDA_PATH "/usr/local/cuda-9.0/")
set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_PATH})
find_package(CUDA REQUIRED)
message("-- CUDA_PATH = ${CUDA_PATH} ")
include(${PROJECT_SOURCE_DIR}/cmake/Cuda.cmake)
include_directories(
+ ${CUDA_PATH}/targets/x86_64-linux/include/
${CUDA_PATH}/include/
)
link_directories(
@@ -62,7 +73,7 @@ ENDIF ()
# PS_PLUS
set(SEASTAR_LIBRARYS -Wl,--whole-archive seastar_service ps_network_static seastar -Wl,--no-whole-archive -L/usr/local/lib64/boost -lboost_timer -lboost_chrono -laio -lboost_program_options -lboost_system -lboost_filesystem -lm -lboost_thread -lcryptopp -lrt -lgnutls -lgnutlsxx -llz4 -ldl -lgcc_s -lunwind -lhwloc -lnuma -lpciaccess -lxml2 -lz -lcares-seastar libstdc++.a)
-set(PS_LIBRARYS -Wl,--whole-archive libzookeeper.a libhashtable.a ps_common ps_client ps_server ps_model_server ps_plugin_hdfs libevent_core.a glog -Wl,--no-whole-archive ${SEASTAR_LIBRARYS})
+set(PS_LIBRARYS -Wl,--whole-archive libzookeeper.a libhashtable.a ps_common ps_client ps_server ps_scheduler ps_model_server ps_plugin_hdfs libevent_core.a -Wl,--no-whole-archive ${SEASTAR_LIBRARYS})
include_directories(${PROJECT_SOURCE_DIR}/ps-plus/)
include_directories(${PROJECT_SOURCE_DIR}/third_party/zookeeper-client/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/zookeeper-client/generated)
@@ -96,11 +107,12 @@ IF (MXNET_BACKEND)
ENDIF()
IF (USE_GPU)
- set(XDL_CORE_DEPEND_LIB libprotobuf ${PS_LIBRARYS} ${BACKEND_LIB} python2.7 cudart dl)
+ set(XDL_CORE_DEPEND_LIB libprotobuf ${PS_LIBRARYS} ${BACKEND_LIB} python2.7 cudart dl ${TBB_IMPORTED_TARGETS})
ELSE ()
- set(XDL_CORE_DEPEND_LIB libprotobuf ${PS_LIBRARYS} ${BACKEND_LIB} python2.7 dl)
+ set(XDL_CORE_DEPEND_LIB libprotobuf ${PS_LIBRARYS} ${BACKEND_LIB} python2.7 dl ${TBB_IMPORTED_TARGETS})
ENDIF ()
+set(XDL_IO_DEPEND_LIB dl rdkafka++ jsoncpp)
set(XDL_CORE_LIB -Wl,--whole-archive xdl_core -Wl,--no-whole-archive ${XDL_CORE_DEPEND_LIB})
set(XDL_IO_LIB -Wl,--whole-archive xdl_io -Wl,--no-whole-archive)
diff --git a/xdl/cmake/FindNumPy.cmake b/xdl/cmake/FindNumPy.cmake
new file mode 100644
index 00000000..229ff63c
--- /dev/null
+++ b/xdl/cmake/FindNumPy.cmake
@@ -0,0 +1,42 @@
+# Find the Python NumPy package
+# PYTHON_NUMPY_INCLUDE_DIR
+# PYTHON_NUMPY_FOUND
+# will be set by this script
+
+cmake_minimum_required(VERSION 2.6)
+
+if(NOT PYTHON_EXECUTABLE)
+ if(NumPy_FIND_QUIETLY)
+ find_package(PythonInterp QUIET)
+ else()
+ find_package(PythonInterp)
+ set(__numpy_out 1)
+ endif()
+endif()
+
+if (PYTHON_EXECUTABLE)
+ # Find out the include path
+ execute_process(
+ COMMAND "${PYTHON_EXECUTABLE}" -c
+ "from __future__ import print_function\ntry: import numpy; print(numpy.get_include(), end='')\nexcept:pass\n"
+ OUTPUT_VARIABLE __numpy_path)
+ # And the version
+ execute_process(
+ COMMAND "${PYTHON_EXECUTABLE}" -c
+ "from __future__ import print_function\ntry: import numpy; print(numpy.__version__, end='')\nexcept:pass\n"
+ OUTPUT_VARIABLE __numpy_version)
+elseif(__numpy_out)
+ message(STATUS "Python executable not found.")
+endif(PYTHON_EXECUTABLE)
+
+find_path(PYTHON_NUMPY_INCLUDE_DIR numpy/arrayobject.h
+ HINTS "${__numpy_path}" "${PYTHON_INCLUDE_PATH}" NO_DEFAULT_PATH)
+
+if(PYTHON_NUMPY_INCLUDE_DIR)
+ set(PYTHON_NUMPY_FOUND 1 CACHE INTERNAL "Python numpy found")
+endif(PYTHON_NUMPY_INCLUDE_DIR)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NumPy REQUIRED_VARS PYTHON_NUMPY_INCLUDE_DIR
+ VERSION_VAR __numpy_version)
+
diff --git a/xdl/cmake/Utils.cmake b/xdl/cmake/Utils.cmake
index 2f410602..d04165a6 100644
--- a/xdl/cmake/Utils.cmake
+++ b/xdl/cmake/Utils.cmake
@@ -414,6 +414,26 @@ function(xdl_add_test dir extension)
endforeach()
endfunction()
+function(xdl_add_test_exclude dir extension exclude)
+ file(GLOB TEST_SOURCE "${dir}/*${extension}")
+ foreach(file ${TEST_SOURCE})
+ get_filename_component(file_name ${file} NAME)
+ get_filename_component(file_exe ${file} NAME_WE)
+ if ("${file_name}" STREQUAL "${exclude}" OR "${file_exe}" STREQUAL "${exclude}")
+ message(STATUS "Skip " ${exclude})
+ continue()
+ endif()
+ #message(${extension})
+ #message(${file})
+ #message(${file_exe})
+ #message(${exclude})
+ add_executable(${file_exe} ${file})
+ target_link_libraries(${file_exe} ${ARGN} gcov)
+ set_target_properties(${file_exe} PROPERTIES COMPILE_FLAGS "-g -O0 --coverage")
+ add_test(${file_exe} ${file_exe} COMMAND ${file_exe})
+ endforeach()
+endfunction()
+
function(xdl_add_cuda_test dir extension)
file(GLOB TEST_SOURCE "${dir}/*${extension}")
foreach(file ${TEST_SOURCE})
diff --git a/xdl/distributed/install_xdl_submit.sh b/xdl/distributed/install_xdl_submit.sh
index b8c03ab1..739fa44e 100644
--- a/xdl/distributed/install_xdl_submit.sh
+++ b/xdl/distributed/install_xdl_submit.sh
@@ -22,4 +22,6 @@ fi
cd $(dirname ${BASH_SOURCE[0]})
cp ./xdl_submit/xdl_submit.py /usr/bin/xdl_submit.py
+chmod 777 /usr/bin/xdl_submit.py
cp ./xdl_yarn_scheduler/bin/xdl-yarn-scheduler-1.0.0-SNAPSHOT-jar-with-dependencies.jar /usr/bin/xdl-yarn-scheduler-1.0.0-SNAPSHOT-jar-with-dependencies.jar
+chmod 777 /usr/bin/xdl-yarn-scheduler-1.0.0-SNAPSHOT-jar-with-dependencies.jar
diff --git a/xdl/distributed/xdl_yarn_scheduler/bin/xdl-yarn-scheduler-1.0.0-SNAPSHOT-jar-with-dependencies.jar b/xdl/distributed/xdl_yarn_scheduler/bin/xdl-yarn-scheduler-1.0.0-SNAPSHOT-jar-with-dependencies.jar
index 4ae183e0..b106e363 100644
Binary files a/xdl/distributed/xdl_yarn_scheduler/bin/xdl-yarn-scheduler-1.0.0-SNAPSHOT-jar-with-dependencies.jar and b/xdl/distributed/xdl_yarn_scheduler/bin/xdl-yarn-scheduler-1.0.0-SNAPSHOT-jar-with-dependencies.jar differ
diff --git a/xdl/distributed/xdl_yarn_scheduler/pom.xml b/xdl/distributed/xdl_yarn_scheduler/pom.xml
index 0b007d8b..43ce157a 100644
--- a/xdl/distributed/xdl_yarn_scheduler/pom.xml
+++ b/xdl/distributed/xdl_yarn_scheduler/pom.xml
@@ -76,7 +76,7 @@
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
- <version>1.2.28</version>
+ <version>1.2.58</version>
diff --git a/xdl/examples/mnist/mnist.py b/xdl/examples/mnist/mnist.py
index c14f4a34..bef5484f 100644
--- a/xdl/examples/mnist/mnist.py
+++ b/xdl/examples/mnist/mnist.py
@@ -80,7 +80,7 @@ def model(images, labels):
@xdl.tf_wrapper(is_training=False)
def eval_model(images, labels):
- with tf.variable_scope("train", reuse=True):
+ with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
eval_y = fc(images, [784, 10], [10])
labels_test = tf.cast(labels, tf.int64)
correct_prediction = tf.equal(tf.argmax(eval_y, 1), labels_test)
diff --git a/xdl/ps-plus/CMakeLists.txt b/xdl/ps-plus/CMakeLists.txt
index cb896f6b..4761d25a 100755
--- a/xdl/ps-plus/CMakeLists.txt
+++ b/xdl/ps-plus/CMakeLists.txt
@@ -1,28 +1,90 @@
-cmake_minimum_required(VERSION 2.8)
+cmake_minimum_required(VERSION 3.0.0 FATAL_ERROR)
project(ps-plus)
if (DEBUG)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -fPIC -std=c++1y -D_GLIBCXX_USE_CXX11_ABI=0")
else ()
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -fPIC -std=c++1y -g -DNDEBUG -D_GLIBCXX_USE_CXX11_ABI=0")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -fPIC -std=c++1y -g -DNDEBUG -D_GLIBCXX_USE_CXX11_ABI=0")
endif ()
if (APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
elseif (UNIX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
-endif()
+endif ()
+
+include_directories(.)
+
+set(CMAKE_SKIP_BUILD_RPATH FALSE)
+set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
+set(CMAKE_INSTALL_RPATH "/usr/local/lib:/usr/local/lib64/boost:/usr/local/gcc-5.3.0/lib64")
+set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
+
+# build third_party libraries using cmake
+function (third_party_library_builder_cmake arg)
+ list(LENGTH ARGV argv_len)
+ set(i 0)
+ while( i LESS ${argv_len})
+ list(GET ARGV ${i} argv_value)
+ message(STATUS "start build third_party library:${argv_value}")
+ execute_process(COMMAND bash -c "cd ${PROJECT_SOURCE_DIR}/third_party/${argv_value}; mkdir -p build; cd build; cmake .. -DCMAKE_INSTALL_PREFIX=. -DCMAKE_CXX_FLAGS=-D_GLIBCXX_USE_CXX11_ABI=0; make; make install; cd ${PROJECT_SOURCE_DIR}")
+ IF(EXISTS "${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/include")
+ include_directories("${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/include")
+ ENDIF()
+ IF(EXISTS "${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/lib")
+ link_directories("${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/lib")
+ ENDIF()
+ IF(EXISTS "${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/lib64")
+ link_directories("${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/lib64")
+ ENDIF()
+ math(EXPR i "${i} + 1")
+ endwhile()
+endfunction ()
+
+# build third_party libraries using autotools
+function (third_party_library_builder_autotools arg)
+ list(LENGTH ARGV argv_len)
+ set(i 0)
+ while (i LESS ${argv_len})
+ list(GET ARGV ${i} argv_value)
+ message(STATUS "Start to build third_party library: ${argv_value}")
+ execute_process(COMMAND bash -c "
+ cd ${PROJECT_SOURCE_DIR}/third_party/${argv_value};
+ mkdir -p build;
+ ./configure --prefix=$(pwd)/build CXXFLAGS='-D_GLIBCXX_USE_CXX11_ABI=0' LDFLAGS='-D_GLIBCXX_USE_CXX11_ABI=0';
+ make;
+ make install;
+ cd ${PROJECT_SOURCE_DIR}")
+ if (EXISTS "${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/include")
+ include_directories("${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/include")
+ endif ()
+ if (EXISTS "${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/lib")
+ link_directories("${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/lib")
+ endif ()
+ if (EXISTS "${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/lib64")
+ link_directories("${PROJECT_SOURCE_DIR}/third_party/${argv_value}/build/lib64")
+ endif ()
+ math(EXPR i "${i} + 1")
+ endwhile ()
+endfunction ()
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
+function (compile_protobuf dir)
+ file(GLOB pbs ${PROJECT_SOURCE_DIR}/${dir}/*.proto)
+ execute_process(COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} ${pbs} --cpp_out=${PROJECT_SOURCE_DIR} --proto_path=${PROJECT_SOURCE_DIR})
+endfunction()
-include_directories(.)
+find_package(NumPy REQUIRED)
+find_package(PythonLibs REQUIRED)
link_directories(${PROJECT_BINARY_DIR}/third_party/googletest/)
link_directories(${PROJECT_BINARY_DIR}/third_party/zookeeper-client/)
-link_directories(${PROJECT_BINARY_DIR}/third_party/glog/)
+include_directories("${PYTHON_NUMPY_INCLUDE_DIR}")
+include_directories("${PYTHON_INCLUDE_DIRS}")
-include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/)
include_directories(${PROJECT_SOURCE_DIR}/third_party/hdfs/)
+include_directories(${PROJECT_SOURCE_DIR}/third_party/seastar/ ${PROJECT_SOURCE_DIR}/third_party/seastar/fmt ${PROJECT_SOURCE_DIR}/third_party/seastar/c-ares)
+link_directories("${PROJECT_SOURCE_DIR}/third_party/seastar/lib")
aux_source_directory(ps-plus/service/seastar/lib SEASTAR_LIB)
add_library(seastar_service STATIC ${SEASTAR_LIB})
@@ -35,7 +97,7 @@ if (USE_HDFS OR NOT DEFINED USE_HDFS)
SET(PLUGINS ${PLUGINS} ps_plugin_hdfs)
endif()
-set(LIBRARYS -Wl,--whole-archive ps_common ps_server ps_model_server ps_client ps_scheduler libhashtable.a libzookeeper.a libevent_core.a ${PLUGINS} glog -Wl,--no-whole-archive ${PLUGINS_DEPENDENCY} ${SEASTAR_LIBRARYS})
+set(LIBRARYS -Wl,--whole-archive ps_common ps_server ps_model_server ps_client ps_scheduler libhashtable.a libzookeeper.a libevent_core.a ${PLUGINS} -Wl,--no-whole-archive ${PLUGINS_DEPENDENCY} ${SEASTAR_LIBRARYS} ${PYTHON_LIBRARIES})
aux_source_directory(ps-plus/common COMMON)
aux_source_directory(ps-plus/common/initializer COMMON_INITIALIZER)
@@ -52,7 +114,6 @@ aux_source_directory(ps-plus/scheduler/test SCHEDULER_TEST)
aux_source_directory(ps-plus/main MAIN)
aux_source_directory(ps-plus/model_server MODEL_SERVER)
aux_source_directory(ps-plus/model_server/test MODEL_SERVER_TEST)
-aux_source_directory(ps-plus/tool CLIENT_TOOL)
aux_source_directory(ps-plus/profiler PROFILER)
aux_source_directory(ps-plus/common/test COMMON_TEST)
aux_source_directory(ps-plus/common/initializer/test COMMON_INITIALIZER_TEST)
@@ -72,7 +133,7 @@ add_library(ps_client STATIC ${CLIENT} ${CLIENT_PARTITIONER})
add_library(ps_plugin_hdfs STATIC ${PLUGINS_HDFS})
add_executable(ps ${MAIN})
-add_executable(tool ${CLIENT_TOOL})
+add_executable(tool ps-plus/tool/client_tool.cpp)
# tests
add_executable(ps_common_test ${COMMON_TEST} ${COMMON_INITIALIZER_TEST} ${SRC_TEST_UTIL})
@@ -85,15 +146,15 @@ add_executable(ps_scheduler_test ${SCHEDULER_TEST})
# profiler
add_executable(ps_profiler ${PROFILER})
-target_link_libraries(ps ${LIBRARYS} libjemalloc.a)
-target_link_libraries(tool ${LIBRARYS} libjemalloc.a)
-target_link_libraries(ps_common_test ${LIBRARYS} gtest gtest_main libjemalloc.a)
-target_link_libraries(ps_message_test ${LIBRARYS} gtest gtest_main libjemalloc.a)
-target_link_libraries(ps_model_server_test ${LIBRARYS} gtest gtest_main libjemalloc.a)
-target_link_libraries(ps_server_test ${LIBRARYS} gtest gtest_main libjemalloc.a)
-target_link_libraries(ps_client_test ${LIBRARYS} gtest gtest_main libjemalloc.a)
-target_link_libraries(ps_scheduler_test ${LIBRARYS} gtest gtest_main libjemalloc.a)
-target_link_libraries(ps_profiler ${LIBRARYS} libjemalloc.a)
+target_link_libraries(ps ${LIBRARYS} libjemalloc.a ${TBB_IMPORTED_TARGETS})
+target_link_libraries(tool ${LIBRARYS} libjemalloc.a ${TBB_IMPORTED_TARGETS})
+target_link_libraries(ps_common_test ${LIBRARYS} gtest gtest_main libjemalloc.a ${TBB_IMPORTED_TARGETS})
+target_link_libraries(ps_message_test ${LIBRARYS} gtest gtest_main libjemalloc.a ${TBB_IMPORTED_TARGETS})
+target_link_libraries(ps_model_server_test ${LIBRARYS} gtest gtest_main libjemalloc.a ${TBB_IMPORTED_TARGETS})
+target_link_libraries(ps_server_test ${LIBRARYS} gtest gtest_main libjemalloc.a ${TBB_IMPORTED_TARGETS})
+target_link_libraries(ps_client_test ${LIBRARYS} gtest gtest_main libjemalloc.a ${TBB_IMPORTED_TARGETS})
+target_link_libraries(ps_scheduler_test ${LIBRARYS} gtest gtest_main libjemalloc.a ${TBB_IMPORTED_TARGETS})
+target_link_libraries(ps_profiler ${LIBRARYS} libjemalloc.a ${TBB_IMPORTED_TARGETS})
enable_testing()
add_test(NAME ps_common_test COMMAND ps_common_test)
diff --git a/xdl/ps-plus/README.md b/xdl/ps-plus/README.md
index 8461cf40..14861b37 100644
--- a/xdl/ps-plus/README.md
+++ b/xdl/ps-plus/README.md
@@ -1,3 +1,18 @@
+/* Copyright (C) 2016-2018 Alibaba Group Holding Limited
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
# Using ps-plus
ps-plus is the underlying parameter server (PS) component; it can be integrated on its own into other business systems. Basic usage is as follows:
diff --git a/xdl/ps-plus/ps-plus/client/base_client.h b/xdl/ps-plus/ps-plus/client/base_client.h
index edb96a98..92ab3cd6 100644
--- a/xdl/ps-plus/ps-plus/client/base_client.h
+++ b/xdl/ps-plus/ps-plus/client/base_client.h
@@ -22,9 +22,11 @@ limitations under the License.
#include "ps-plus/common/status.h"
#include "ps-plus/common/tensor.h"
#include "ps-plus/message/variable_info.h"
+#include "ps-plus/message/worker_state.h"
#include "ps-plus/client/udf.h"
#include "ps-plus/client/partitioner.h"
+#include "ps-plus/client/merged_partitioner.h"
namespace ps {
namespace client {
@@ -37,17 +39,35 @@ class BaseClient {
virtual Status Init() = 0;
virtual void Save(const std::string& name, const Callback& cb) = 0;
virtual void Restore(const std::string& name, const Callback& cb) = 0;
- virtual void TriggerStreamingModelDense(const Callback& cb) = 0;
- virtual void TriggerStreamingModelSparse(const Callback& cb) = 0;
- virtual void TriggerStreamingModelHash(const Callback& cb) = 0;
-
+ virtual void TriggerStreamingModelDense(const std::string& stream_ver, const Callback& cb) = 0;
+ virtual void TriggerStreamingModelSparse(const std::string& stream_ver, const Callback& cb) = 0;
+ virtual void TriggerStreamingModelHash(const std::string& stream_ver, const Callback& cb) = 0;
+ virtual Status InitGlobalQueue(
+ const std::string& name,
+ const std::vector<std::string>& paths,
+ size_t epochs,
+ bool epoch_isolate = false) = 0;
+ virtual Status GetNextFile(
+ const std::string& name,
+ size_t worker_id,
+ std::string* path,
+ size_t* begin,
+ size_t* epoch) = 0;
+ virtual Status ReportWorkerState(
+ const std::string& name,
+ size_t worker_id,
+ const std::vector<WorkerState>& worker_states) = 0;
+ virtual Status RestoreWorkerState(
+ const std::string& name,
+ size_t worker_id) = 0;
virtual Status RegisterVariable(const std::string& name, const VariableInfo& info) = 0;
-
virtual void AsynchronizeEnter(int id, int staleness, int worker_count, const Callback& cb) = 0;
virtual void SynchronizeEnter(int id, int worker_count, const Callback& cb) = 0;
virtual void SynchronizeLeave(int id, const Callback& cb) = 0;
virtual void WorkerReportFinish(int id, const Callback& cb) = 0;
+ virtual void GetWorkerFinishCount(int64_t* count, const Callback& cb) = 0;
virtual void WorkerBarrier(int id, int worker_count, const Callback& cb) = 0;
+ virtual void WorkerBarrierV2(int barrier_id, int task_id, int task_num, int token, const Callback& cb) = 0;
virtual void ModelServerForward(int type, const Tensor& ids, Tensor* rst, const Callback& cb) = 0;
virtual void ModelServerBackward(int type, const Tensor& ids, const Tensor& grads, const Callback& cb) = 0;
@@ -58,7 +78,7 @@ class BaseClient {
const Tensor& init,
const Callback& cb) = 0;
virtual void HashInitializer(const std::string& variable_name,
- Initializer* init,
+ Initializer* init,
const Callback& cb) = 0;
virtual void IsInitialized(const std::string& variable_name,
bool* inited,
@@ -80,23 +100,54 @@ class BaseClient {
const std::vector<Data*>& data,
const Callback& cb) = 0;
virtual void HashPull(const std::string& variable_name,
- const Tensor& ids,
- double add_probability,
+ const Tensor& ids,
+ const float& save_ratio,
Tensor* result,
const Callback& cb) = 0;
+ virtual void MergedHashPull(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ std::vector<Tensor>* result,
+ const Callback& cb) = 0;
virtual void HashPush(const std::string& variable_name,
- const Tensor& ids,
- const std::string& updater,
+ const Tensor& ids,
+ const float& save_ratio,
+ const bool& insertable,
+ const std::string& updater,
const std::vector<Data*>& data,
const Callback& cb) = 0;
+ virtual void MergedHashPush(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ const std::string& updater,
+ const std::vector<Data*>& data,
+ const Callback& cb) = 0;
+ virtual void MergedHashStatis(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ const std::vector<Tensor>& clicks,
+ const Tensor& global_step,
+ const Tensor& statis_decay,
+ const Tensor& statis_decay_period,
+ const std::string& statis_type,
+ std::vector<Tensor>* result,
+ const Callback& cb) = 0;
+
+ virtual void Process(const UdfChain& udf,
+ const std::string& var_name,
+ const std::vector<Data*>& datas,
+ const std::vector<Partitioner*>& splitter,
+ const std::vector<Partitioner*>& combiner,
+ std::vector<std::unique_ptr<Data>>* results,
+ const Callback& cb) = 0;
virtual void Process(const UdfChain& udf,
- const std::string& var_name,
- const std::vector<Data*>& datas,
- const std::vector<Partitioner*>& splitter,
- const std::vector<Partitioner*>& combiner,
- std::vector<std::unique_ptr<Data>>* results,
- const Callback& cb) = 0;
+ const std::vector<std::string>& var_names,
+ const std::vector<Data*>& datas,
+ const std::vector<MergedPartitioner*>& splitter,
+ const std::vector<MergedPartitioner*>& combiner,
+ std::vector<std::vector<std::unique_ptr<Data>>>* results,
+ const Callback& cb) = 0;
template <typename... Targs>
std::vector<Data*> Args(Targs&&... args) {
diff --git a/xdl/ps-plus/ps-plus/client/client.cc b/xdl/ps-plus/ps-plus/client/client.cc
index 736add32..bdcff71a 100644
--- a/xdl/ps-plus/ps-plus/client/client.cc
+++ b/xdl/ps-plus/ps-plus/client/client.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include "ps-plus/client/partitioner/logic.h"
#include "ps-plus/client/partitioner/sparse.h"
#include "ps-plus/client/partitioner/broadcast.h"
+#include "ps-plus/client/partitioner/merged_broadcast.h"
#include "ps-plus/client/partitioner/index.h"
#include "ps-plus/client/partitioner/hash.h"
+#include "ps-plus/client/partitioner/merged_hash.h"
#include
#include
@@ -91,18 +93,26 @@ void Client::IdentityInitializer(const std::string& variable_name,
}
void Client::HashInitializer(const std::string& variable_name,
- Initializer* init,
+ Initializer* init,
const Client::Callback& cb) {
- std::vector<Data*> inputs = Args(0, 0, std::unique_ptr<Initializer>(init));
+ VariableInfo info;
+ CHECK_ASYNC(GetVariableInfo(variable_name, &info));
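+ // Flatten the variable's args into a "k1=v1&k2=v2" string so they can be
+ // shipped to the HashVariableInitializer UDF alongside the initializer.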
+ std::string extra_info;
+ for (auto& arg : info.args) {
+ extra_info += arg.first + "=" + arg.second + "&";
+ }
+ if (!extra_info.empty()) { extra_info.pop_back(); }
+ std::vector<Data*> inputs = Args(0, 0, extra_info, std::unique_ptr<Initializer>(init));
std::vector<std::unique_ptr<Data>>* outputs =
new std::vector<std::unique_ptr<Data>>;
std::vector<Partitioner*> splitter = {
new partitioner::HashDataType,
new partitioner::HashShape,
+ new partitioner::Broadcast,
new partitioner::Broadcast
};
std::vector<Partitioner*> combiner = {};
- UdfData udf("HashVariableInitializer", UdfData(0), UdfData(1), UdfData(2));
+ UdfData udf("HashVariableInitializer", UdfData(0), UdfData(1), UdfData(2), UdfData(3));
Callback realcb = [cb, outputs](const Status& st) {
std::unique_ptr<std::vector<std::unique_ptr<Data>>> deleter(outputs);
cb(st);
@@ -150,7 +160,7 @@ void Client::DensePull(const std::string& variable_name,
std::vector<Partitioner*> combiner = { new partitioner::Dense };
UdfData udf("BuildDenseSlice", UdfData(0));
UdfData udf_chain("TransSlice", udf);
- Callback realcb = [this, cb, result, outputs, &variable_name](const Status& st) {
+ Callback realcb = [this, cb, result, outputs](const Status& st) {
std::unique_ptr<std::vector<std::unique_ptr<Data>>> deleter(outputs);
if (!st.IsOk()) {
cb(st);
@@ -217,7 +227,8 @@ void Client::DensePush(const std::string& variable_name,
inputs.insert(inputs.end(), data.begin(), data.end());
for (size_t i = start_index; i < inputs.size(); i++) {
- if (dynamic_cast<WrapperData<Tensor>*>(inputs[i]) != nullptr) {
+ if (dynamic_cast<WrapperData<Tensor>*>(inputs[i]) != nullptr
+ || dynamic_cast<WrapperData<std::vector<Tensor>>*>(inputs[i]) != nullptr) {
splitter.push_back(new partitioner::Dense);
} else {
splitter.push_back(new partitioner::Broadcast);
@@ -323,7 +334,8 @@ void Client::SparsePush(const std::string& variable_name,
inputs.insert(inputs.end(), data.begin(), data.end());
for (size_t i = start_index; i < inputs.size(); i++) {
- if (dynamic_cast<WrapperData<Tensor>*>(inputs[i]) != nullptr) {
+ if (dynamic_cast<WrapperData<Tensor>*>(inputs[i]) != nullptr
+ || dynamic_cast<WrapperData<std::vector<Tensor>>*>(inputs[i]) != nullptr) {
splitter.push_back(new partitioner::SparseData);
} else {
splitter.push_back(new partitioner::Broadcast);
@@ -342,22 +354,27 @@ void Client::SparsePush(const std::string& variable_name,
}
void Client::HashPull(const std::string& variable_name,
- const Tensor& ids,
- double add_probability,
- Tensor* result,
+ const Tensor& ids,
+ const float& save_ratio,
+ Tensor* result,
const Client::Callback& cb) {
- std::vector<Data*> inputs = Args(ids, false, add_probability);
+ std::vector<Tensor> ids_vec = {ids};
+ std::vector<std::string> name_vec = {variable_name};
+ std::vector<float> save_ratio_vec = {save_ratio};
+ std::vector<Data*> inputs = Args(ids_vec, name_vec, save_ratio_vec, false, true);
std::vector<std::unique_ptr<Data>>* outputs =
new std::vector<std::unique_ptr<Data>>;
std::vector<Partitioner*> splitter = {
new partitioner::HashId,
new partitioner::Broadcast,
- new partitioner::Broadcast
+ new partitioner::Broadcast,
+ new partitioner::Broadcast,
+ new partitioner::Broadcast
};
std::vector<Partitioner*> combiner = {
new partitioner::HashData
};
- UdfData udf("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2));
+ UdfData udf("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2), UdfData(3), UdfData(4));
UdfData udf_chain("TransSlice", udf);
Callback realcb = [cb, result, outputs](const Status& st) {
std::unique_ptr<std::vector<std::unique_ptr<Data>>> deleter(outputs);
@@ -393,23 +410,81 @@ void Client::HashPull(const std::string& variable_name,
}
}
+void Client::MergedHashPull(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ std::vector<Tensor>* result,
+ const Client::Callback& cb) {
+ std::vector<Data*> inputs = Args(ids, var_names, save_ratios, false, true);
+ std::vector<std::vector<std::unique_ptr<Data>>>* outputs =
+ new std::vector<std::vector<std::unique_ptr<Data>>>;
+ std::vector<MergedPartitioner*> splitter = {
+ new partitioner::MergedHashId,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast
+ };
+ std::vector<MergedPartitioner*> combiner = {
+ new partitioner::MergedHashData
+ };
+ UdfData udf("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2), UdfData(3), UdfData(4));
+ Callback realcb = [cb, result, outputs, var_names](const Status& st) {
+ std::unique_ptr<std::vector<std::vector<std::unique_ptr<Data>>>> deleter(outputs);
+ if (!st.IsOk()) {
+ cb(st);
+ return;
+ }
+
+ if (outputs->size() != 1) {
+ cb(Status::ArgumentError("Output Size Should be 1 on MergedHashPull"));
+ return;
+ }
+
+ std::vector<std::unique_ptr<Data>>& output_vec = (*outputs)[0];
+ if (output_vec.size() != var_names.size()) {
+ cb(Status::ArgumentError("Output[0] Size Should be the Same with Variable Number"));
+ return;
+ }
+ for (auto& output : output_vec) {
+ WrapperData<Tensor>* output_ptr = dynamic_cast<WrapperData<Tensor>*>(output.get());
+ if (output_ptr == nullptr) {
+ cb(Status::ArgumentError("Output[0] should be tensor vector"));
+ return;
+ }
+ (*result).push_back(output_ptr->Internal());
+ }
+ cb(Status::Ok());
+ };
+
+ Process(udf, var_names, inputs, splitter,
+ combiner, outputs, realcb);
+}
+
void Client::HashPush(const std::string& variable_name,
- const Tensor& ids,
- const std::string& updater,
+ const Tensor& ids,
+ const float& save_ratio,
+ const bool& insertable,
+ const std::string& updater,
const std::vector& data,
const Client::Callback& cb) {
- std::vector<Data*> inputs = Args(ids, true, 0.0);
- size_t start_index = 3;
+ std::vector<Tensor> ids_vec = {ids};
+ std::vector<std::string> name_vec = {variable_name};
+ std::vector<float> save_ratio_vec = {save_ratio};
+ std::vector<Data*> inputs = Args(ids_vec, name_vec, save_ratio_vec, true, insertable);
+ size_t start_index = 5;
std::vector<std::unique_ptr<Data>>* outputs =
new std::vector<std::unique_ptr<Data>>;
std::vector<Partitioner*> splitter = {
new partitioner::HashId,
- new partitioner::Broadcast,
- new partitioner::Broadcast
+ new partitioner::Broadcast,
+ new partitioner::Broadcast,
+ new partitioner::Broadcast,
+ new partitioner::Broadcast
+ };
std::vector<Partitioner*> combiner = {};
std::vector<UdfData> next_udf_inputs = {
- UdfData("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2))
+ UdfData("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2), UdfData(3), UdfData(4))
};
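// Sync mode (all updaters except the assign/moving-average ones): append the
// sync token and worker count, and wrap the slice in an AggregateSlice UDF first.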
if (sync_mode_ &&
@@ -419,20 +494,21 @@ void Client::HashPush(const std::string& variable_name,
updater != "MovingAverageUpdater") {
inputs.push_back(Args(token_)[0]);
inputs.push_back(Args(worker_count_)[0]);
- next_udf_inputs.push_back(UdfData(3));
- next_udf_inputs.push_back(UdfData(4));
- next_udf_inputs.push_back(UdfData(5));
+ next_udf_inputs.push_back(UdfData(5));
+ next_udf_inputs.push_back(UdfData(6));
+ next_udf_inputs.push_back(UdfData(7));
splitter.push_back(new partitioner::Broadcast);
splitter.push_back(new partitioner::Broadcast);
splitter.push_back(new partitioner::HashData);
UdfData aggregate("AggregateSlice", next_udf_inputs);
next_udf_inputs = {aggregate};
- start_index = 6;
+ start_index = 8;
}
inputs.insert(inputs.end(), data.begin(), data.end());
for (size_t i = start_index; i < inputs.size(); i++) {
- if (dynamic_cast<WrapperData<Tensor>*>(inputs[i]) != nullptr) {
+ if (dynamic_cast<WrapperData<Tensor>*>(inputs[i]) != nullptr
+ || dynamic_cast<WrapperData<std::vector<Tensor>>*>(inputs[i]) != nullptr) {
splitter.push_back(new partitioner::HashData);
} else {
splitter.push_back(new partitioner::Broadcast);
@@ -450,5 +526,128 @@ void Client::HashPush(const std::string& variable_name,
combiner, outputs, realcb);
}
+void Client::MergedHashPush(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ const std::string& updater,
+ const std::vector<Data*>& data,
+ const Client::Callback& cb) {
+ std::vector<Data*> inputs = Args(ids, var_names, save_ratios, true, false);
+ size_t start_index = 5;
+ std::vector<std::vector<std::unique_ptr<Data>>>* outputs =
+ new std::vector<std::vector<std::unique_ptr<Data>>>;
+ std::vector<MergedPartitioner*> splitter = {
+ new partitioner::MergedHashId,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast
+ };
+ std::vector<MergedPartitioner*> combiner = {};
+ std::vector<UdfData> next_udf_inputs = {
+ UdfData("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2), UdfData(3), UdfData(4))
+ };
+
+ if (sync_mode_ &&
+ updater != "AssignUpdater" &&
+ updater != "AssignAddUpdater" &&
+ updater != "AssignSubUpdater" &&
+ updater != "MovingAverageUpdater") {
+ inputs.push_back(Args(token_)[0]);
+ inputs.push_back(Args(worker_count_)[0]);
+ next_udf_inputs.push_back(UdfData(5));
+ next_udf_inputs.push_back(UdfData(6));
+ next_udf_inputs.push_back(UdfData(7));
+ splitter.push_back(new partitioner::MergedBroadcast);
+ splitter.push_back(new partitioner::MergedBroadcast);
+ splitter.push_back(new partitioner::MergedHashData);
+ UdfData aggregate("AggregateSlice", next_udf_inputs);
+ next_udf_inputs = {aggregate};
+ start_index = 8;
+ }
+
+ inputs.insert(inputs.end(), data.begin(), data.end());
+ for (size_t i = start_index; i < inputs.size(); i++) {
+ if (dynamic_cast<WrapperData<Tensor>*>(inputs[i]) != nullptr
+ || dynamic_cast<WrapperData<std::vector<Tensor>>*>(inputs[i]) != nullptr) {
+ splitter.push_back(new partitioner::MergedHashData);
+ } else {
+ splitter.push_back(new partitioner::MergedBroadcast);
+ }
+ next_udf_inputs.push_back(UdfData(i));
+ }
+
+ UdfData udf(updater, next_udf_inputs);
+ Callback realcb = [cb, outputs](const Status& st) {
+ std::unique_ptr<std::vector<std::vector<std::unique_ptr<Data>>>> deleter(outputs);
+ cb(st);
+ };
+
+ Process(udf, var_names, inputs, splitter,
+ combiner, outputs, realcb);
+}
+
+void Client::MergedHashStatis(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ const std::vector<Tensor>& clicks,
+ const Tensor& global_step,
+ const Tensor& statis_decay,
+ const Tensor& statis_decay_period,
+ const std::string& statis_type,
+ std::vector<Tensor>* result,
+ const Client::Callback& cb) {
+ std::vector<Data*> inputs = Args(ids, var_names, save_ratios, clicks, global_step, statis_decay, statis_decay_period, statis_type, false, true);
+ std::vector<std::vector<std::unique_ptr<Data>>>* outputs =
+ new std::vector<std::vector<std::unique_ptr<Data>>>;
+ std::vector<MergedPartitioner*> splitter = {
+ new partitioner::MergedHashId,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedHashData,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast,
+ new partitioner::MergedBroadcast
+ };
+ std::vector<MergedPartitioner*> combiner = {
+ new partitioner::MergedHashData
+ };
+ UdfData udf("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2), UdfData(8), UdfData(9));
+ UdfData udf_chain("StatisSlice", udf, UdfData(3), UdfData(4), UdfData(5), UdfData(6), UdfData(7));
+ Callback realcb = [cb, result, outputs, var_names](const Status& st) {
+ std::unique_ptr<std::vector<std::vector<std::unique_ptr<Data>>>> deleter(outputs);
+ if (!st.IsOk()) {
+ cb(st);
+ return;
+ }
+
+ if (outputs->size() != 1) {
+ cb(Status::ArgumentError("Output Size Should be 1 on MergedHashStatis"));
+ return;
+ }
+
+ std::vector<std::unique_ptr<Data>>& output_vec = (*outputs)[0];
+ if (output_vec.size() != var_names.size()) {
+ cb(Status::ArgumentError("Output[0] Size Should be the Same with Variable Number"));
+ return;
+ }
+ for (auto& output : output_vec) {
+ WrapperData<Tensor>* output_ptr = dynamic_cast<WrapperData<Tensor>*>(output.get());
+ if (output_ptr == nullptr) {
+ cb(Status::ArgumentError("Output[0] should be tensor vector"));
+ return;
+ }
+ result->push_back(output_ptr->Internal());
+ }
+ cb(Status::Ok());
+ };
+
+ Process(udf_chain, var_names, inputs, splitter,
+ combiner, outputs, realcb);
+}
+
} //namespace client
} //namespace ps
diff --git a/xdl/ps-plus/ps-plus/client/client.h b/xdl/ps-plus/ps-plus/client/client.h
index 1579106d..57660e7f 100644
--- a/xdl/ps-plus/ps-plus/client/client.h
+++ b/xdl/ps-plus/ps-plus/client/client.h
@@ -19,6 +19,7 @@ limitations under the License.
#include
#include
+#include "ps-plus/common/logging.h"
#include "ps-plus/client/raw_client.h"
#include "ps-plus/client/base_client.h"
#include "ps-plus/common/tensor.h"
@@ -44,11 +45,22 @@ class Client: public BaseClient {
const std::vector<Data*>& datas,
const std::vector<Partitioner*>& splitter,
const std::vector<Partitioner*>& combiner,
- std::vector<std::unique_ptr<Data>>* results,
+ std::vector<std::unique_ptr<Data> >* results,
const Callback& cb) override {
return raw_->Process(udf, var_name, datas, splitter, combiner, results, cb);
}
+ void Process(
+ const UdfChain& udf,
+ const std::vector<std::string>& var_names,
+ const std::vector<Data*>& datas,
+ const std::vector<MergedPartitioner*>& splitter,
+ const std::vector<MergedPartitioner*>& combiner,
+ std::vector<std::vector<std::unique_ptr<Data> > >* results,
+ const Callback& cb) override {
+ return raw_->Process(udf, var_names, datas, splitter, combiner, results, cb);
+ }
+
void Save(const std::string& name, const Callback& cb) override {
return raw_->Save(name, cb);
}
@@ -57,16 +69,46 @@ class Client: public BaseClient {
return raw_->Restore(name, cb);
}
- void TriggerStreamingModelDense(const Callback& cb) override {
- return raw_->TriggerStreamingModelDense(cb);
+ void TriggerStreamingModelDense(const std::string& stream_ver, const Callback& cb) override {
+ return raw_->TriggerStreamingModelDense(stream_ver, cb);
+ }
+
+ Status InitGlobalQueue(
+ const std::string& name,
+ const std::vector<std::string>& paths,
+ size_t epochs,
+ bool epoch_isolate = false) override {
+ return raw_->InitGlobalQueue(name, paths, epochs, epoch_isolate);
+ }
+
+ Status GetNextFile(
+ const std::string& name,
+ size_t worker_id,
+ std::string* path,
+ size_t* begin,
+ size_t* epoch) override {
+ return raw_->GetNextFile(name, worker_id, path, begin, epoch);
+ }
+
+ Status ReportWorkerState(
+ const std::string& name,
+ size_t worker_id,
+ const std::vector<WorkerState>& worker_states) override {
+ return raw_->ReportWorkerState(name, worker_id, worker_states);
}
- void TriggerStreamingModelSparse(const Callback& cb) override {
- return raw_->TriggerStreamingModelSparse(cb);
+ Status RestoreWorkerState(
+ const std::string& name,
+ size_t worker_id) override {
+ return raw_->RestoreWorkerState(name, worker_id);
}
- void TriggerStreamingModelHash(const Callback& cb) override {
- return raw_->TriggerStreamingModelHash(cb);
+ void TriggerStreamingModelSparse(const std::string& stream_ver, const Callback& cb) override {
+ return raw_->TriggerStreamingModelSparse(stream_ver, cb);
+ }
+
+ void TriggerStreamingModelHash(const std::string& stream_ver, const Callback& cb) override {
+ return raw_->TriggerStreamingModelHash(stream_ver, cb);
}
Status RegisterVariable(const std::string& name, const VariableInfo& info) override {
@@ -91,9 +133,17 @@ class Client: public BaseClient {
raw_->WorkerReportFinish(id, cb);
}
+ void GetWorkerFinishCount(int64_t* count, const Callback& cb) {
+ raw_->GetWorkerFinishCount(count, cb);
+ }
+
void WorkerBarrier(int id, int worker_count, const Callback& cb) override {
raw_->WorkerBarrier(id, worker_count, cb);
- }
+ }
+
+ void WorkerBarrierV2(int barrier_id, int task_id, int task_num, int token, const Callback& cb) override {
+ raw_->WorkerBarrierV2(barrier_id, task_id, task_num, token, cb);
+ }
Status UpdateVariableVisitInfo(const std::string& name, int64_t id_num) {
return raw_->UpdateVariableVisitInfo(name, id_num);
@@ -119,7 +169,7 @@ class Client: public BaseClient {
const Tensor& init,
const Callback& cb) override;
void HashInitializer(const std::string& variable_name,
- Initializer* init,
+ Initializer* init,
const Callback& cb) override;
void IsInitialized(const std::string& variable_name,
bool* inited,
@@ -141,15 +191,43 @@ class Client: public BaseClient {
const std::vector<Data*>& data,
const Callback& cb) override;
void HashPull(const std::string& variable_name,
- const Tensor& ids,
- double add_probability,
+ const Tensor& ids,
+ const float& save_ratio,
Tensor* result,
const Callback& cb) override;
+ void MergedHashPull(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ std::vector<Tensor>* result,
+ const Callback& cb) override;
void HashPush(const std::string& variable_name,
- const Tensor& ids,
- const std::string& updater,
+ const Tensor& ids,
+ const float& save_ratio,
+ const bool& insertable,
+ const std::string& updater,
const std::vector<Data*>& data,
const Callback& cb) override;
+ void MergedHashPush(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ const std::string& updater,
+ const std::vector<Data*>& data,
+ const Callback& cb) override;
+ void MergedHashStatis(const std::vector<std::string>& var_names,
+ const std::vector<Tensor>& ids,
+ const std::vector<float>& save_ratios,
+ const std::vector<Tensor>& clicks,
+ const Tensor& global_step,
+ const Tensor& statis_decay,
+ const Tensor& statis_decay_period,
+ const std::string& statis_type,
+ std::vector<Tensor>* result,
+ const Callback& cb) override;
+
+ private:
+ Status GetVariableInfo(const std::string& name, VariableInfo* info) {
+ return raw_->GetVariableInfo(name, info);
+ }
private:
std::unique_ptr raw_;
diff --git a/xdl/ps-plus/ps-plus/client/client_wrapper.h b/xdl/ps-plus/ps-plus/client/client_wrapper.h
index f8b50f74..166c1d80 100644
--- a/xdl/ps-plus/ps-plus/client/client_wrapper.h
+++ b/xdl/ps-plus/ps-plus/client/client_wrapper.h
@@ -20,7 +20,9 @@ limitations under the License.
#include "ps-plus/common/status.h"
#include "ps-plus/client/udf.h"
#include "ps-plus/client/partitioner.h"
+#include "ps-plus/client/merged_partitioner.h"
#include "ps-plus/common/tensor.h"
+#include "ps-plus/message/worker_state.h"
#include
#include
@@ -43,16 +45,22 @@ class ClientWrapper {
virtual void RegisterUdf(size_t server_id, const UdfChain& def, const Callback& cb) = 0;
virtual void Save(const std::string& version, const Callback& cb) = 0;
virtual void Restore(const std::string& version, const Callback& cb) = 0;
+ virtual Status InitGlobalQueue(const std::string& name, const std::vector<std::string>& paths, size_t epochs, bool epoch_isolate = false) = 0;
+ virtual Status GetNextFile(const std::string& name, size_t worker_id, std::string* path, size_t* begin, size_t* epoch) = 0;
+ virtual Status ReportWorkerState(const std::string& name, size_t worker_id, const std::vector<WorkerState>& worker_states) = 0;
+ virtual Status RestoreWorkerState(const std::string& name, size_t worker_id) = 0;
virtual void ModelServerForward(int server_type, int server_id, const Tensor& ids, std::unique_ptr<Tensor>* rst, const Callback& cb) = 0;
virtual void ModelServerBackward(int server_type, int server_id, const Tensor& ids, const Tensor& grads, const Callback& cb) = 0;
- virtual void TriggerStreamingModelDense(const Callback& cb) = 0;
- virtual void TriggerStreamingModelSparse(const Callback& cb) = 0;
- virtual void TriggerStreamingModelHash(const Callback& cb) = 0;
+ virtual void TriggerStreamingModelDense(const std::string& stream_ver, const Callback& cb) = 0;
+ virtual void TriggerStreamingModelSparse(const std::string& stream_ver, const Callback& cb) = 0;
+ virtual void TriggerStreamingModelHash(const std::string& stream_ver, const Callback& cb) = 0;
virtual void AsynchronizeEnter(int id, int staleness, int worker_count, const Callback& cb) = 0;
virtual void SynchronizeEnter(int id, int worker_count, int64_t* token, const Callback& cb) = 0;
virtual void SynchronizeLeave(int id, int64_t token, const Callback& cb) = 0;
virtual void WorkerReportFinish(int id, const Callback& cb) = 0;
+ virtual void GetWorkerFinishCount(int64_t* count, const Callback& cb) = 0;
virtual void WorkerBarrier(int id, int worker_count, const Callback& cb) = 0;
+ virtual void WorkerBarrierV2(int barrier_id, int task_id, int task_num, int token, const Callback& cb) = 0;
virtual int ServerSize(int id) = 0;
virtual int ServerTypeSize() = 0;
};
diff --git a/xdl/ps-plus/ps-plus/client/client_wrapper_impl.cc b/xdl/ps-plus/ps-plus/client/client_wrapper_impl.cc
index 51a8ba06..115fed50 100644
--- a/xdl/ps-plus/ps-plus/client/client_wrapper_impl.cc
+++ b/xdl/ps-plus/ps-plus/client/client_wrapper_impl.cc
@@ -21,6 +21,7 @@ limitations under the License.
using ps::service::seastar::CallBackClosure;
using ps::service::seastar::SeastarClientLib;
+using ps::service::seastar::EventClientLib;
using ps::service::seastar::SeastarStatus;
namespace ps {
@@ -227,6 +228,125 @@ void ClientWrapperImpl::Restore(const std::string& version, const Callback& cb)
client_lib_->Request(0, func_ids::kSchedulerRestore, request_datas, cb_closure);
}
+Status ClientWrapperImpl::InitGlobalQueue(
+ const std::string& name,
+ const std::vector<std::string>& paths,
+ size_t epochs,
+ bool epoch_isolate) {
+ std::vector<Data*> request_datas = {
+ new WrapperData<Version>(scheduler_version_),
+ new WrapperData<std::string>(name),
+ new WrapperData<std::vector<std::string> >(paths),
+ new WrapperData<size_t>(epochs),
+ new WrapperData<bool>(epoch_isolate)
+ };
+
+ std::promise<Status> p;
+ CallBackClosure* cb_closure =
+ new CallBackClosure([&p](const SeastarStatus& sst,
+ const std::vector<Data*>& response) {
+ Status st = GetNetworkStatus(sst, response);
+ p.set_value(st);
+ });
+
+ client_lib_->Request(0, func_ids::kSchedulerInitGlobalFileQueue,
+ request_datas, cb_closure);
+ return p.get_future().get();
+}
+
+Status ClientWrapperImpl::GetNextFile(
+ const std::string& name,
+ size_t worker_id,
+ std::string* path,
+ size_t* begin,
+ size_t* epoch) {
+ std::vector<Data*> request_datas = {
+ new WrapperData<Version>(scheduler_version_),
+ new WrapperData<std::string>(name),
+ new WrapperData<size_t>(worker_id)
+ };
+
+ std::promise<Status> p;
+ CallBackClosure* cb_closure =
+ new CallBackClosure([&p, path, begin, epoch, this](
+ const SeastarStatus& sst,
+ const std::vector<Data*>& response) {
+ Status st = GetNetworkStatus(sst, response);
+ if (!st.IsOk()) {
+ p.set_value(st);
+ return;
+ }
+
+ if (response.size() != 4) {
+ p.set_value(Status::Unknown("response data not match"));
+ return;
+ }
+
+ WrapperData<std::string>* path_data = dynamic_cast<WrapperData<std::string>* >(response[1]);
+ WrapperData<size_t>* begin_data = dynamic_cast<WrapperData<size_t>* >(response[2]);
+ WrapperData<size_t>* epoch_data = dynamic_cast<WrapperData<size_t>* >(response[3]);
+ if (path_data == nullptr || begin_data == nullptr || epoch_data == nullptr) {
+ p.set_value(Status::Unknown("response data type not match"));
+ return;
+ }
+
+ *path = path_data->Internal();
+ *begin = begin_data->Internal();
+ *epoch = epoch_data->Internal();
+ p.set_value(Status::Ok());
+ });
+
+ client_lib_->Request(0, func_ids::kSchedulerGetNextFile,
+ request_datas, cb_closure);
+ return p.get_future().get();
+}
+
+Status ClientWrapperImpl::ReportWorkerState(
+ const std::string& name,
+ size_t worker_id,
+ const std::vector& worker_states) {
+ std::vector<Data*> request_datas = {
+ new WrapperData<Version>(scheduler_version_),
+ new WrapperData<std::string>(name),
+ new WrapperData<size_t>(worker_id),
+ new WrapperData<std::vector<WorkerState> >(worker_states)
+ };
+
+ std::promise<Status> p;
+ CallBackClosure* cb_closure =
+ new CallBackClosure([&p](const SeastarStatus& sst,
+ const std::vector<Data*>& response) {
+ Status st = GetNetworkStatus(sst, response);
+ p.set_value(st);
+ });
+
+ client_lib_->Request(0, func_ids::kSchedulerReportWorkerState,
+ request_datas, cb_closure);
+ return p.get_future().get();
+}
+
+Status ClientWrapperImpl::RestoreWorkerState(
+ const std::string& name,
+ size_t worker_id) {
+ std::vector<Data*> request_datas = {
+ new WrapperData<Version>(scheduler_version_),
+ new WrapperData<std::string>(name),
+ new WrapperData<size_t>(worker_id)
+ };
+
+ std::promise<Status> p;
+ CallBackClosure* cb_closure =
+ new CallBackClosure([&p](const SeastarStatus& sst,
+ const std::vector<Data*>& response) {
+ Status st = GetNetworkStatus(sst, response);
+ p.set_value(st);
+ });
+
+ client_lib_->Request(0, func_ids::kSchedulerRestoreWorkerState,
+ request_datas, cb_closure);
+ return p.get_future().get();
+}
+
void ClientWrapperImpl::ModelServerForward(int server_type, int server_id, const Tensor& ids, std::unique_ptr<Tensor>* rst, const Callback& cb) {
std::vector<Data*> request_datas = {
new WrapperData<Version>(scheduler_version_),
@@ -268,9 +388,10 @@ void ClientWrapperImpl::ModelServerBackward(int server_type, int server_id, cons
request_datas, cb_closure);
}
-void ClientWrapperImpl::TriggerStreamingModelDense(const Callback& cb) {
+void ClientWrapperImpl::TriggerStreamingModelDense(const std::string& stream_ver, const Callback& cb) {
std::vector<Data*> request_datas = {
- new WrapperData<Version>(scheduler_version_)
+ new WrapperData<Version>(scheduler_version_),
+ new WrapperData<std::string>(stream_ver)
};
CallBackClosure* cb_closure = new CallBackClosure([cb](const SeastarStatus& sst, const std::vector<Data*>& response) {
@@ -280,9 +401,10 @@ void ClientWrapperImpl::TriggerStreamingModelDense(const Callback& cb) {
client_lib_->Request(0, func_ids::kSchedulerTriggerStreamingDense, request_datas, cb_closure);
}
-void ClientWrapperImpl::TriggerStreamingModelSparse(const Callback& cb) {
+void ClientWrapperImpl::TriggerStreamingModelSparse(const std::string& stream_ver, const Callback& cb) {
std::vector<Data*> request_datas = {
- new WrapperData<Version>(scheduler_version_)
+ new WrapperData<Version>(scheduler_version_),
+ new WrapperData<std::string>(stream_ver)
};
CallBackClosure* cb_closure = new CallBackClosure([cb](const SeastarStatus& sst, const std::vector<Data*>& response) {
@@ -292,9 +414,10 @@ void ClientWrapperImpl::TriggerStreamingModelSparse(const Callback& cb) {
client_lib_->Request(0, func_ids::kSchedulerTriggerStreamingSparse, request_datas, cb_closure);
}
-void ClientWrapperImpl::TriggerStreamingModelHash(const Callback& cb) {
+void ClientWrapperImpl::TriggerStreamingModelHash(const std::string& stream_ver, const Callback& cb) {
std::vector<Data*> request_datas = {
- new WrapperData<Version>(scheduler_version_)
+ new WrapperData<Version>(scheduler_version_),
+ new WrapperData<std::string>(stream_ver)
};
CallBackClosure* cb_closure = new CallBackClosure([cb](const SeastarStatus& sst, const std::vector<Data*>& response) {
@@ -376,6 +499,28 @@ void ClientWrapperImpl::WorkerReportFinish(int id, const Callback& cb) {
request_datas, cb_closure);
}
+void ClientWrapperImpl::GetWorkerFinishCount(int64_t* count, const Callback& cb) {
+ std::vector<Data*> request_datas = {
+ new WrapperData<Version>(scheduler_version_)
+ };
+ CallBackClosure* cb_closure =
+ new CallBackClosure([count, cb](const SeastarStatus& sst,
+ const std::vector<Data*>& response) {
+ Status st = GetNetworkStatus(sst, response);
+ if (!st.IsOk()) {
+ cb(st);
+ return;
+ }
+ if (count) {
+ WrapperData<int64_t>* res = dynamic_cast<WrapperData<int64_t>*>(response[1]);
+ *count = res->Internal();
+ }
+ cb(Status::Ok());
+ });
+ client_lib_->Request(0, func_ids::kSchedulerGetWorkerFinishCount,
+ request_datas, cb_closure);
+}
+
void ClientWrapperImpl::WorkerBarrier(int id, int worker_count, const Callback& cb) {
std::vector<Data*> request_datas = {
new WrapperData<Version>(scheduler_version_),
@@ -391,10 +536,32 @@ void ClientWrapperImpl::WorkerBarrier(int id, int worker_count, const Callback&
request_datas, cb_closure);
}
+void ClientWrapperImpl::WorkerBarrierV2(
+ int barrier_id,
+ int task_id,
+ int task_num,
+ int token,
+ const Callback& cb) {
+ std::vector<Data*> request_datas = {
+ new WrapperData<Version>(scheduler_version_),
+ new WrapperData<int>(barrier_id),
+ new WrapperData<int>(task_id),
+ new WrapperData<int>(task_num),
+ new WrapperData<int>(token)
+ };
+ CallBackClosure* cb_closure =
+ new CallBackClosure([cb](const SeastarStatus& sst,
+ const std::vector<Data*>& response) {
+ cb(GetNetworkStatus(sst, response));
+ });
+ client_lib_->Request(0, func_ids::kSchedulerWorkerBarrierV2,
+ request_datas, cb_closure);
+}
+
Status ClientWrapperImpl::CreateServerLib() {
if (client_lib_singleton_ == nullptr) {
std::vector> server_addrs = {};
- client_lib_singleton_ = new ClientLib(server_addrs, 100, 1);
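+ // Was hard-coded to 1; now sized to the host's hardware concurrency.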
+ client_lib_singleton_ = new ClientLib(server_addrs, 100, std::thread::hardware_concurrency());
client_lib_ = client_lib_singleton_;
client_lib_->Start();
} else {
diff --git a/xdl/ps-plus/ps-plus/client/client_wrapper_impl.h b/xdl/ps-plus/ps-plus/client/client_wrapper_impl.h
index 7ee8cecd..fbe9251b 100644
--- a/xdl/ps-plus/ps-plus/client/client_wrapper_impl.h
+++ b/xdl/ps-plus/ps-plus/client/client_wrapper_impl.h
@@ -37,19 +37,26 @@ class ClientWrapperImpl : public ClientWrapper {
void RegisterUdf(size_t server_id, const UdfChain& def, const Callback& cb) override;
void Save(const std::string& version, const Callback& cb) override;
void Restore(const std::string& version, const Callback& cb) override;
+ Status InitGlobalQueue(const std::string& name, const std::vector<std::string>& paths, size_t epochs, bool epoch_isolate = false) override;
+ Status GetNextFile(const std::string& name, size_t worker_id, std::string* path, size_t* begin, size_t* epoch) override;
+ Status ReportWorkerState(const std::string& name, size_t worker_id, const std::vector<WorkerState>& worker_states) override;
+ Status RestoreWorkerState(const std::string& name, size_t worker_id) override;
void ModelServerForward(int server_type, int server_id, const Tensor& ids, std::unique_ptr<Tensor>* rst, const Callback& cb) override;
void ModelServerBackward(int server_type, int server_id, const Tensor& ids, const Tensor& grads, const Callback& cb) override;
- void TriggerStreamingModelDense(const Callback& cb) override;
- void TriggerStreamingModelSparse(const Callback& cb) override;
- void TriggerStreamingModelHash(const Callback& cb) override;
+ void TriggerStreamingModelDense(const std::string& stream_ver, const Callback& cb) override;
+ void TriggerStreamingModelSparse(const std::string& stream_ver, const Callback& cb) override;
+ void TriggerStreamingModelHash(const std::string& stream_ver, const Callback& cb) override;
void AsynchronizeEnter(int id, int staleness, int worker_count, const Callback& cb) override;
void SynchronizeEnter(int id, int worker_count, int64_t* token, const Callback& cb) override;
void SynchronizeLeave(int id, int64_t token, const Callback& cb) override;
void WorkerReportFinish(int id, const Callback& cb) override;
+ void GetWorkerFinishCount(int64_t* count, const Callback& cb);
void WorkerBarrier(int id, int worker_count, const Callback& cb) override;
+ void WorkerBarrierV2(int barrier_id, int task_id, int task_num, int token, const Callback& cb) override;
int ServerSize(int id) override;
int ServerTypeSize() override;
+ //using ClientLib = ps::service::seastar::SeastarClientLib;
using ClientLib = ps::service::seastar::EventClientLib;
private:
Status CreateServerLib();
diff --git a/xdl/ps-plus/ps-plus/client/local_client.cc b/xdl/ps-plus/ps-plus/client/local_client.cc
index 28b8d16e..a1778c51 100644
--- a/xdl/ps-plus/ps-plus/client/local_client.cc
+++ b/xdl/ps-plus/ps-plus/client/local_client.cc
@@ -87,20 +87,26 @@ void LocalClient::IdentityInitializer(const std::string& variable_name,
}
void LocalClient::HashInitializer(const std::string& variable_name,
- Initializer* init,
+ Initializer* init,
const LocalClient::Callback& cb) {
VariableInfo info;
CHECK_ASYNC(local_server_->GetVariableInfo(variable_name, &info));
std::vector<size_t> dims(info.shape.begin(), info.shape.end());
size_t k = info.shape[0];
dims[0] = k + 10 * sqrt(k) + 10;
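+ // Same "k1=v1&k2=v2" flattening of the variable args as in Client::HashInitializer.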
+ std::string extra_info;
+ for (auto& arg : info.args) {
+ extra_info += arg.first + "=" + arg.second + "&";
+ }
+ if (!extra_info.empty()) { extra_info.pop_back(); }
std::vector<Data*> inputs = Args(
info.datatype,
TensorShape(dims),
+ extra_info,
std::unique_ptr<Initializer>(init));
std::vector<std::unique_ptr<Data>>* outputs =
new std::vector<std::unique_ptr<Data>>;
- UdfData udf("HashVariableInitializer", UdfData(0), UdfData(1), UdfData(2));
+ UdfData udf("HashVariableInitializer", UdfData(0), UdfData(1), UdfData(2), UdfData(3));
Callback realcb = [cb, outputs](const Status& st) {
std::unique_ptr<std::vector<std::unique_ptr<Data>>> deleter(outputs);
cb(st);
@@ -143,8 +149,8 @@ void LocalClient::DensePull(const std::string& variable_name,
std::vector<std::unique_ptr<Data>>* outputs =
new std::vector<std::unique_ptr<Data>>;
UdfData udf("BuildDenseSlice", UdfData(0));
- UdfData udf_chain("SliceToTensor", UdfData("TransSlice", udf));
- Callback realcb = [this, cb, result, outputs, &variable_name](const Status& st) {
+ UdfData udf_chain("SliceToTensor", udf);
+ Callback realcb = [this, cb, result, outputs](const Status& st) {
std::unique_ptr<std::vector<std::unique_ptr<Data>>> deleter(outputs);
if (!st.IsOk()) {
cb(st);
@@ -156,13 +162,19 @@ void LocalClient::DensePull(const std::string& variable_name,
return;
}
- WrapperData<Tensor>* output_ptr = dynamic_cast<WrapperData<Tensor>*>((*outputs)[0].get());
+ WrapperData<std::vector<Tensor>>* output_ptr =
+ dynamic_cast<WrapperData<std::vector<Tensor>>*>((*outputs)[0].get());
if (output_ptr == nullptr) {
- cb(Status::ArgumentError("Output[0] should be tensor"));
+ cb(Status::ArgumentError("Output[0] should be tensor vector"));
return;
}
- *result = output_ptr->Internal();
+ if (output_ptr->Internal().size() != 1) {
+ cb(Status::ArgumentError("Output[0] size should be 1"));
+ return;
+ }
+
+ *result = output_ptr->Internal()[0];
cb(Status::Ok());
};
@@ -199,8 +211,8 @@ void LocalClient::SparsePull(const std::string& variable_name,
std::vector<std::unique_ptr<Data>>* outputs =
new std::vector<std::unique_ptr<Data>>;
UdfData udf("BuildSparseSlice", UdfData(0), UdfData(1));
- UdfData udf_chain("SliceToTensor", UdfData("TransSlice", udf));
- Callback realcb = [cb, result, outputs](const Status& st) {
+ UdfData udf_chain("SliceToTensor", udf);
+ Callback realcb = [this, cb, result, outputs](const Status& st) {
std::unique_ptr<std::vector<std::unique_ptr<Data>>> deleter(outputs);
if (!st.IsOk()) {
cb(st);
@@ -212,14 +224,19 @@ void LocalClient::SparsePull(const std::string& variable_name,
return;
}
- WrapperData<Tensor>* output_ptr =
- dynamic_cast<WrapperData<Tensor>*>((*outputs)[0].get());
+ WrapperData<std::vector<Tensor>>* output_ptr =
+ dynamic_cast<WrapperData<std::vector<Tensor>>*>((*outputs)[0].get());
if (output_ptr == nullptr) {
- cb(Status::ArgumentError("Output[0] should be tensor"));
+ cb(Status::ArgumentError("Output[0] should be tensor vector"));
return;
}
- *result = output_ptr->Internal();
+ if (output_ptr->Internal().size() != 1) {
+ cb(Status::ArgumentError("Output[0] size should be 1"));
+ return;
+ }
+
+ *result = output_ptr->Internal()[0];
cb(Status::Ok());
};
@@ -252,16 +269,19 @@ void LocalClient::SparsePush(const std::string& variable_name,
}
void LocalClient::HashPull(const std::string& variable_name,
- const Tensor& ids,
- double filter_ratio,
+ const Tensor& ids,
+ const float& save_ratio,
Tensor* result,
const LocalClient::Callback& cb) {
- std::vector<Data*> inputs = Args(ids, false, 1.0);
+ std::vector<Tensor> ids_vec = {ids};
+ std::vector<std::string> name_vec = {variable_name};
+ std::vector<float> save_ratio_vec = {save_ratio};
+ std::vector<Data*> inputs = Args(ids_vec, name_vec, save_ratio_vec, false, true);
std::vector<std::unique_ptr<Data>>* outputs =
new std::vector<std::unique_ptr<Data>>;
- UdfData udf("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2));
- UdfData udf_chain("SliceToTensor", UdfData("TransSlice", udf));
- Callback realcb = [cb, result, outputs](const Status& st) {
+ UdfData udf("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2), UdfData(3), UdfData(4));
+ UdfData udf_chain("SliceToTensor", udf);
+ Callback realcb = [this, cb, result, outputs](const Status& st) {
std::unique_ptr<std::vector<std::unique_ptr<Data>>> deleter(outputs);
if (!st.IsOk()) {
cb(st);
@@ -273,33 +293,84 @@ void LocalClient::HashPull(const std::string& variable_name,
return;
}
- WrapperData<Tensor>* output_ptr =
- dynamic_cast<WrapperData<Tensor>*>((*outputs)[0].get());
+ WrapperData<std::vector<Tensor>>* output_ptr =
+ dynamic_cast<WrapperData<std::vector<Tensor>>*>((*outputs)[0].get());
if (output_ptr == nullptr) {
- cb(Status::ArgumentError("Output[0] should be tensor"));
+ cb(Status::ArgumentError("Output[0] should be tensor vector"));
return;
}
- *result = output_ptr->Internal();
+ if (output_ptr->Internal().size() != 1) {
+ cb(Status::ArgumentError("Output[0] size should be 1"));
+ return;
+ }
+
+ *result = output_ptr->Internal()[0];
cb(Status::Ok());
};
Process(udf_chain, variable_name, inputs, outputs, realcb);
}
+void LocalClient::MergedHashPull(const std::vector<std::string>& var_names,
+ const std::vector