diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 4b9c795..85714ca 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.1.3.post2
+current_version = 0.2.0.rc
 commit = False
 tag = False
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<build>\d*))?
diff --git a/core/BUILD b/core/BUILD
index fe1153d..e4400f8 100644
--- a/core/BUILD
+++ b/core/BUILD
@@ -51,6 +51,30 @@ tf_gen_op_wrapper_py(
     cc_linkopts = ['-lrt'],
 )
 
+cc_library(
+    name = "bn_table_ops_kernels",
+    srcs = [
+        "kernels/bn_table_ops_dummy.cc",
+        "ops/bn_table_ops.cc",
+    ],
+    hdrs = [
+        "//core/utility:semaphore",
+    ],
+    linkstatic = 1,
+    deps = [
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+    ],
+    alwayslink = 0,
+)
+
+tf_gen_op_wrapper_py(
+    name = "bn_table_ops",
+    deps = [":bn_table_ops_kernels"],
+    cc_linkopts = ['-lrt', '-lssl']
+)
+
 cc_library(
     name = "balance_dataset_ops_kernels",
     srcs = [
@@ -127,6 +151,7 @@ cc_binary(
         "kernels/dense_table_ops.cc",
         "kernels/data/balance_dataset_ops.cc",
         "kernels/data/balance_dataset_ops.h",
+        "kernels/bn_table_ops.cc",
         "public/version.h",
         "kernels/resource_var_wrapper.h",
         "//core/utility:semaphore",
diff --git a/core/kernels/bn_table_ops.cc b/core/kernels/bn_table_ops.cc
new file mode 100644
index 0000000..a83aca8
--- /dev/null
+++ b/core/kernels/bn_table_ops.cc
@@ -0,0 +1,309 @@
+// Copyright (c) 2020, Qihoo, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
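+
+// TensorFlow kernels for the batch-normalization statistics table:
+//   * BnStatisticsPush - append the local sum / squared-sum / count variables to the
+//                        local BnTable and, when "synchronized" is set, broadcast the
+//                        incremental statistics to the other shards.
+//   * BnStatisticsPull - fetch incremental statistics from the other shards, merge them
+//                        into the local table and write the resulting moments back into
+//                        the moving_mean / moving_variance variables.
+//   * UpdateMoments    - copy the moments of the local table into the
+//                        moving_mean / moving_variance variables.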
+
+#include "core/utility/semaphore.h"
+#include "core/ps/table/bn_table.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+
+#include "core/kernels/resource_var_wrapper.h"
+#include "core/ps_interface/ps_raw_interface.h"
+
+#include <string>
+#include <tuple>
+#include <vector>
+#include <brpc/controller.h>
+#include <butil/iobuf.h>
+
+#include "core/ps/ps_server_interface.h"
+#include "core/ps/ps_cluster.h"
+
+using namespace tensornet;
+
+namespace tensorflow {
+
+static void NoOpDeleter(void *) {}
+
+template <typename T>
+Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
+
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
+
+class BnStatisticsPushCall {
+public:
+    BnStatisticsPushCall(int table_handle, int shard_id)
+        : shard_id_(shard_id) {
+        req.set_req_shard_id(shard_id);
+        req.set_table_handle(table_handle);
+    }
+
+    ~BnStatisticsPushCall() {}
+
+    void AddRequestData(butil::IOBuf& k_buf) {
+        butil::IOBuf& buf = cntl.request_attachment();
+        buf.append(k_buf);
+    }
+
+    void Start(const tensornet::Callback& done) {
+        const PsServerInterface* si =
+            PsCluster::Instance()->GetServer(shard_id_);
+        si->BnStatisticsPushAsync(&cntl, &req, &resp, done);
+    }
+
+public:
+    brpc::Controller cntl;
+    BnStatisticsPushRequest req;
+    BnStatisticsPushResponse resp;
+
+private:
+    int shard_id_ = -1;
+};
+
+
+class BnStatisticsPushKernel : public AsyncOpKernel {
+public:
+    explicit BnStatisticsPushKernel(OpKernelConstruction* c)
+        : AsyncOpKernel(c) {
+        OP_REQUIRES_OK(c, c->GetAttr("table_handle", &table_handle_));
+        OP_REQUIRES_OK(c, c->GetAttr("N", &N_));
+        OP_REQUIRES_OK(c, c->GetAttr("synchronized", &synchronized_));
+    }
+
+    void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+        butil::IOBuf acc_buf;
+
+        std::vector<double*> allocated_pointers;
+
+        for (int i = 0; i < N_; i++) {
+            const ResourceHandle& handle = HandleFromInput(c, i);
+
+            Var* variable = nullptr;
+            const auto status = LookupResource(c, handle, &variable);
+
+            OP_REQUIRES_OK_ASYNC(c, status, done);
+            CHECK(variable);
+
+            Tensor* var_tensor = variable->tensor();
+
+            int num_elements = var_tensor->NumElements();
+            double* dynamic_double_data = new double[num_elements];
+            const float* float_data = var_tensor->flat<float>().data();
+            for (int j = 0; j < num_elements; ++j) {
+                dynamic_double_data[j] = static_cast<double>(float_data[j]);
+            }
+            acc_buf.append_user_data(dynamic_double_data, num_elements * sizeof(double), NoOpDeleter);
+            allocated_pointers.push_back(dynamic_double_data);
+        }
+
+        BnTable* table = BnTableRegistry::Instance()->Get(table_handle_);
+        table->Append(acc_buf, true);
+
+        for (auto ptr : allocated_pointers) {
+            delete[] ptr;
+        }
+        allocated_pointers.clear();
+
+        if (synchronized_) {
+            PsCluster* cluster = PsCluster::Instance();
+            OP_REQUIRES_ASYNC(c, true == cluster->IsInitialized(),
+                              errors::InvalidArgument("cluster instance not initialized"), done);
+
+            butil::IOBuf inc_buf;
+            table->GetIncStatistics(inc_buf);
+
+            std::vector<BnStatisticsPushCall*> calls;
+
+            for (size_t shard_id = 0; shard_id < cluster->RankNum(); shard_id++) {
+                if (shard_id != cluster->Rank()) {
+                    auto* call = new BnStatisticsPushCall(table_handle_, shard_id);
+                    call->AddRequestData(inc_buf);
+                    calls.emplace_back(call);
+                }
+            }
+
+            Semaphore semaphore(calls.size());
+
+            for (auto& call : calls) {
+                call->Start([this, call, &semaphore]() {
+                    semaphore.Notify();
+                    delete call;
+                });
+            }
+
+            semaphore.WaitForSemaphore();
+        }
+
+        done();
+
+        return;
+    }
+
+private:
+    int table_handle_;
+    int N_;
+    bool synchronized_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BnStatisticsPush").Device(DEVICE_CPU),
+                        BnStatisticsPushKernel);
+
+class UpdateMomentsKernel : public OpKernel {
+public:
+    explicit UpdateMomentsKernel(OpKernelConstruction* c)
+        : OpKernel(c) {
+        OP_REQUIRES_OK(c, c->GetAttr("table_handle", &table_handle_));
+        OP_REQUIRES_OK(c, c->GetAttr("N", &N_));
+    }
+
+    void Compute(OpKernelContext* c) override {
+        std::vector<Var*> bn_vars;
+
+        for (int i = 0; i < N_; i++) {
+            const ResourceHandle& handle = HandleFromInput(c, i);
+
+            Var* variable = nullptr;
+            const auto status = LookupResource(c, handle, &variable);
+
+            OP_REQUIRES_OK(c, status);
+            CHECK(variable);
+            bn_vars.emplace_back(variable);
+        }
+
+        BnTable* table = BnTableRegistry::Instance()->Get(table_handle_);
+
+        std::tuple<Eigen::ArrayXf, Eigen::ArrayXf> moments_tuple = table->GetMoments();
+
+        auto& global_mean_var = bn_vars[0];
+        float* global_mean_flat = global_mean_var->tensor()->flat<float>().data();
+        std::copy(std::get<0>(moments_tuple).data(), std::get<0>(moments_tuple).data() + std::get<0>(moments_tuple).size(), global_mean_flat);
+
+        auto& global_var_var = bn_vars[1];
+        float* global_var_flat = global_var_var->tensor()->flat<float>().data();
+        std::copy(std::get<1>(moments_tuple).data(), std::get<1>(moments_tuple).data() + std::get<1>(moments_tuple).size(), global_var_flat);
+
+        return;
+    }
+
+private:
+    int table_handle_;
+    int N_;
+};
+
+
+REGISTER_KERNEL_BUILDER(Name("UpdateMoments").Device(DEVICE_CPU),
+                        UpdateMomentsKernel);
+
+class BnStatisticsPullCall {
+public:
+    BnStatisticsPullCall(int table_handle, int shard_id)
+        : shard_id_(shard_id) {
+        req.set_req_shard_id(shard_id);
+        req.set_table_handle(table_handle);
+    }
+
+    ~BnStatisticsPullCall() {}
+
+    void Start(const tensornet::Callback& done) {
+        const PsServerInterface* si =
+            PsCluster::Instance()->GetServer(shard_id_);
+        si->BnStatisticsPullAsync(&cntl, &req, &resp, done);
+    }
+
+public:
+    brpc::Controller cntl;
+    BnStatisticsPullRequest req;
+    BnStatisticsPullResponse resp;
+
+private:
+    int shard_id_ = -1;
+};
+
+
+class BnStatisticsPullKernel : public AsyncOpKernel {
+public:
+    explicit BnStatisticsPullKernel(OpKernelConstruction* c)
+        : AsyncOpKernel(c) {
+        OP_REQUIRES_OK(c, c->GetAttr("table_handle", &table_handle_));
+        OP_REQUIRES_OK(c, c->GetAttr("N", &N_));
+    }
+
+    void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+        std::vector<Var*> bn_vars;
+
+        for (int i = 0; i < N_; i++) {
+            const ResourceHandle& handle = HandleFromInput(c, i);
+
+            Var* variable = nullptr;
+            const auto status = LookupResource(c, handle, &variable);
+
+            OP_REQUIRES_OK_ASYNC(c, status, done);
+            CHECK(variable);
+            bn_vars.emplace_back(variable);
+        }
+
+        PsCluster* cluster = PsCluster::Instance();
+        OP_REQUIRES_ASYNC(c, true == cluster->IsInitialized(),
+                          errors::InvalidArgument("cluster instance not initialized"), done);
+
+        BnTable* table = BnTableRegistry::Instance()->Get(table_handle_);
+        std::vector<BnStatisticsPullCall*> calls;
+
+        for (size_t shard_id = 0; shard_id < cluster->RankNum(); shard_id++) {
+            if (shard_id != cluster->Rank()) {
+                calls.emplace_back(
+                    new BnStatisticsPullCall(table_handle_, shard_id));
+            }
+        }
+
+        Semaphore semaphore(calls.size());
+
+        for (auto& call : calls) {
+            call->Start([this, call, &table, &semaphore]() {
+                table->Append(call->cntl.response_attachment(), false);
+                semaphore.Notify();
+                delete call;
+            });
+        }
+
+        semaphore.WaitForSemaphore();
+        std::tuple<Eigen::ArrayXf, Eigen::ArrayXf> moments_tuple = table->GetMoments();
+
+        auto& global_mean_var = bn_vars[0];
+        float* global_mean_flat =
+            global_mean_var->tensor()->flat<float>().data();
+        std::copy(std::get<0>(moments_tuple).data(), std::get<0>(moments_tuple).data() + std::get<0>(moments_tuple).size(), global_mean_flat);
+
+        auto& global_var_var = bn_vars[1];
+        float* global_var_flat = global_var_var->tensor()->flat<float>().data();
+        std::copy(std::get<1>(moments_tuple).data(), std::get<1>(moments_tuple).data() + std::get<1>(moments_tuple).size(), global_var_flat);
+
+        done();
+
+        return;
+    }
+
+private:
+    int table_handle_;
+    int N_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BnStatisticsPull").Device(DEVICE_CPU),
+                        BnStatisticsPullKernel);
+
+}  // namespace tensorflow
diff --git a/core/kernels/bn_table_ops_dummy.cc b/core/kernels/bn_table_ops_dummy.cc
new file mode 100644
index 0000000..da4589e
--- /dev/null
+++ b/core/kernels/bn_table_ops_dummy.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2020, Qihoo, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class UpdateMomentsKernel : public OpKernel {
+public:
+    explicit UpdateMomentsKernel(OpKernelConstruction* c)
+        : OpKernel(c) {
+        OP_REQUIRES_OK(c, c->GetAttr("table_handle", &table_handle_));
+    }
+
+    void Compute(OpKernelContext* c) override {
+        return;
+    }
+
+private:
+    int table_handle_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("UpdateMoments").Device(DEVICE_CPU),
+                        UpdateMomentsKernel);
+
+class BnStatisticsPushKernel : public AsyncOpKernel {
+public:
+    explicit BnStatisticsPushKernel(OpKernelConstruction* c)
+        : AsyncOpKernel(c) {
+        OP_REQUIRES_OK(c, c->GetAttr("table_handle", &table_handle_));
+    }
+
+    void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+        done();
+        return;
+    }
+
+private:
+    int table_handle_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BnStatisticsPush").Device(DEVICE_CPU),
+                        BnStatisticsPushKernel);
+
+} // namespace tensorflow
diff --git a/core/main/py_wrapper.cc b/core/main/py_wrapper.cc
index 7d6ba29..fe9730d 100644
--- a/core/main/py_wrapper.cc
+++ b/core/main/py_wrapper.cc
@@ -17,6 +17,7 @@
 #include "core/ps/optimizer/optimizer.h"
 #include "core/ps/table/dense_table.h"
 #include "core/ps/table/sparse_table.h"
+#include "core/ps/table/bn_table.h"
 #include "core/kernels/data/balance_dataset_ops.h"
 
 #include <pybind11/pybind11.h>
@@ -113,10 +114,14 @@ PYBIND11_MODULE(_pywrap_tn, m) {
             return py::reinterpret_steal<py::object>(obj);
         })
-        .def("create_sparse_table", [](py::object obj, std::string name, int dimension) {
+        .def("create_sparse_table", [](py::object obj, std::string name, int dimension, bool use_cvm) {
             OptimizerBase* opt =
                 static_cast<OptimizerBase*>(PyCapsule_GetPointer(obj.ptr(), nullptr));
 
+            opt->SetUseCvm(use_cvm);
+
+            std::cout << "Cvm plugin is: " << opt->ShouldUseCvm() << std::endl;
+
             PsCluster* cluster = PsCluster::Instance();
 
             SparseTable* table = CreateSparseTable(opt, name, dimension, cluster->RankNum(), cluster->Rank());
@@ -133,6 +138,21 @@ PYBIND11_MODULE(_pywrap_tn, m) {
             return table->GetHandle();
         })
+        .def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, uint64_t max_count) {
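+            // "bn_size" is the length of the normalized axis (the layer's feature dimension);
+            // "sync", "moment" and "max_count" mirror the synchronized / momentum / max_count
+            // arguments of the Python TNBatchNormalization layer.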
+            PsCluster* cluster = PsCluster::Instance();
+
+            BnTable* table = CreateBnTable(name, cluster->RankNum(), cluster->Rank(), bn_size, sync, moment, max_count);
+
+            return table->GetHandle();
+        })
+        .def("save_bn_table", [](uint32_t table_handle, std::string filepath) {
+            BnTable* table = BnTableRegistry::Instance()->Get(table_handle);
+            return table->Save(filepath);
+        })
+        .def("load_bn_table", [](uint32_t table_handle, std::string filepath) {
+            BnTable* table = BnTableRegistry::Instance()->Get(table_handle);
+            return table->Load(filepath);
+        })
         .def("save_sparse_table", [](uint32_t table_handle, std::string filepath, const std::string& mode="txt") {
             SparseTable* table = SparseTableRegistry::Instance()->Get(table_handle);
diff --git a/core/ops/bn_table_ops.cc b/core/ops/bn_table_ops.cc
new file mode 100644
index 0000000..34da9f4
--- /dev/null
+++ b/core/ops/bn_table_ops.cc
@@ -0,0 +1,43 @@
+// Copyright (c) 2020, Qihoo, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+using namespace tensorflow;
+
+REGISTER_OP("UpdateMoments")
+    .Doc(R"doc(update moving mean and variance from the bn table statistics)doc")
+    .Input("vars: N * resource")
+    .Attr("table_handle: int")
+    .Attr("N: int")
+    .SetShapeFn(shape_inference::NoOutputs);
+
+
+REGISTER_OP("BnStatisticsPush")
+    .Doc(R"doc(push local bn statistics (sum, squared sum, count) to the parameter server)doc")
+    .Input("vars: N * resource")
+    .Attr("table_handle: int")
+    .Attr("N: int")
+    .Attr("synchronized: bool")
+    .SetShapeFn(shape_inference::NoOutputs);
+
+
+REGISTER_OP("BnStatisticsPull")
+    .Doc(R"doc(pull bn statistics from the parameter server and update local moments)doc")
+    .Input("vars: N * resource")
+    .Attr("table_handle: int")
+    .Attr("N: int")
+    .SetShapeFn(shape_inference::NoOutputs);
diff --git a/core/ps/ps_local_server.cc b/core/ps/ps_local_server.cc
index 94c68ed..5302623 100644
--- a/core/ps/ps_local_server.cc
+++ b/core/ps/ps_local_server.cc
@@ -17,6 +17,7 @@
 #include "core/ps_interface/ps_server.pb.h"
 #include "core/ps/table/dense_table.h"
 #include "core/ps/table/sparse_table.h"
+#include "core/ps/table/bn_table.h"
 #include "core/kernels/data/balance_dataset_ops.h"
 #include "core/ps/optimizer/optimizer_kernel.h"
 
@@ -85,4 +86,28 @@ void PsLocalServer::DatasetPullAsync(brpc::Controller *cntl,
     done();
 }
 
+void PsLocalServer::BnStatisticsPushAsync(brpc::Controller *cntl,
+                                          const BnStatisticsPushRequest *request,
+                                          BnStatisticsPushResponse *response,
+                                          Callback done) const {
+    BnTable *table = BnTableRegistry::Instance()->Get(request->table_handle());
+    CHECK(nullptr != table);
+    butil::IOBuf& acc_data = cntl->request_attachment();
+    table->Append(acc_data, false);
+
+    done();
+}
+
+void PsLocalServer::BnStatisticsPullAsync(brpc::Controller *cntl,
+                                          const BnStatisticsPullRequest *request,
+                                          BnStatisticsPullResponse *response,
+                                          Callback done) const {
+    BnTable *table = BnTableRegistry::Instance()->Get(request->table_handle());
+    CHECK(nullptr != table);
+    response->set_table_handle(request->table_handle());
+    butil::IOBuf& bn_statistics_buf = cntl->response_attachment();
+    table->GetIncStatistics(bn_statistics_buf);
+
+    done();
+}
 } // namespace tensornet
diff --git a/core/ps/ps_local_server.h b/core/ps/ps_local_server.h
index 6dbd864..d07f3e4 100644
--- a/core/ps/ps_local_server.h
+++ b/core/ps/ps_local_server.h
@@ -44,6 +44,16 @@ class PsLocalServer : public PsServerInterface {
                                   const DatasetPullRequest *request,
                                   DatasetPullResponse *response,
                                   Callback done) const override;
+
+    virtual void BnStatisticsPushAsync(brpc::Controller *cntl,
+                                       const BnStatisticsPushRequest *request,
+                                       BnStatisticsPushResponse *response,
+                                       Callback done) const override;
+
+    virtual void BnStatisticsPullAsync(brpc::Controller *cntl,
+                                       const BnStatisticsPullRequest *request,
+                                       BnStatisticsPullResponse *response,
+                                       Callback done) const override;
 };
 
 } // namespace tensornet
diff --git a/core/ps/ps_remote_server.cc b/core/ps/ps_remote_server.cc
index 4b561d5..2b15b64 100644
--- a/core/ps/ps_remote_server.cc
+++ b/core/ps/ps_remote_server.cc
@@ -108,6 +108,8 @@ PsRemoteServer::PsRemoteServer(std::shared_ptr<brpc::Channel>& channel)
     sparse_push_dp_ = PsService::descriptor()->FindMethodByName("SparsePush");
     dense_push_pull_dp_ = PsService::descriptor()->FindMethodByName("DensePushPull");
     dataset_pull_dp_ = PsService::descriptor()->FindMethodByName("DatasetPull");
+    bn_statistics_push_dp_ = PsService::descriptor()->FindMethodByName("BnStatisticsPush");
+    bn_statistics_pull_dp_ = PsService::descriptor()->FindMethodByName("BnStatisticsPull");
 }
 
 PsRemoteServer::~PsRemoteServer() {}
 
@@ -144,4 +146,20 @@ void PsRemoteServer::DatasetPullAsync(brpc::Controller *cntl,
             channel_, cntl, request, response, std::move(done));
 }
 
+void PsRemoteServer::BnStatisticsPushAsync(brpc::Controller *cntl,
+                                           const BnStatisticsPushRequest *request,
+                                           BnStatisticsPushResponse *response,
+                                           Callback done) const {
+    new Call(bn_statistics_push_dp_,
+             channel_, cntl, request, response, std::move(done));
+}
+
+void PsRemoteServer::BnStatisticsPullAsync(brpc::Controller *cntl,
+                                           const BnStatisticsPullRequest *request,
+                                           BnStatisticsPullResponse *response,
+                                           Callback done) const {
+    new Call(bn_statistics_pull_dp_,
+             channel_, cntl, request, response, std::move(done));
+}
+
 } // namespace tensornet
diff --git a/core/ps/ps_remote_server.h b/core/ps/ps_remote_server.h
index 9f1c189..e493d43 100644
--- a/core/ps/ps_remote_server.h
+++ b/core/ps/ps_remote_server.h
@@ -51,6 +51,16 @@ class PsRemoteServer : public PsServerInterface {
                                   DatasetPullResponse *response,
                                   Callback done) const override;
 
+    virtual void BnStatisticsPushAsync(brpc::Controller *cntl,
+                                       const BnStatisticsPushRequest *request,
+                                       BnStatisticsPushResponse *response,
+                                       Callback done) const override;
+
+    virtual void BnStatisticsPullAsync(brpc::Controller *cntl,
+                                       const BnStatisticsPullRequest *request,
+                                       BnStatisticsPullResponse *response,
+                                       Callback done) const override;
+
 private:
     std::shared_ptr<brpc::Channel> channel_;
 
@@ -58,6 +68,8 @@
     const google::protobuf::MethodDescriptor* sparse_push_dp_ = nullptr;
     const google::protobuf::MethodDescriptor* dense_push_pull_dp_ = nullptr;
     const google::protobuf::MethodDescriptor* dataset_pull_dp_ = nullptr;
+    const google::protobuf::MethodDescriptor* bn_statistics_push_dp_ = nullptr;
+    const google::protobuf::MethodDescriptor* bn_statistics_pull_dp_ = nullptr;
 };
 
 } // namespace tensornet
diff --git a/core/ps/ps_server_interface.h b/core/ps/ps_server_interface.h
index d6ecdf6..88f4cd6 100644
--- a/core/ps/ps_server_interface.h
+++ b/core/ps/ps_server_interface.h
@@ -53,6 +53,16 @@ class PsServerInterface {
                                   DatasetPullResponse *response,
                                   Callback done) const = 0;
 
+    virtual void BnStatisticsPushAsync(brpc::Controller *cntl,
+                                       const BnStatisticsPushRequest *request,
+                                       BnStatisticsPushResponse *response,
+                                       Callback done) const = 0;
+
+    virtual void BnStatisticsPullAsync(brpc::Controller *cntl,
+                                       const BnStatisticsPullRequest *request,
+                                       BnStatisticsPullResponse *response,
+                                       Callback done) const = 0;
+
 private:
     typedef PsServerInterface ME;
 };
diff --git a/core/ps/ps_service_impl.cc b/core/ps/ps_service_impl.cc
index 140df7e..6648e12 100644
--- a/core/ps/ps_service_impl.cc
+++ b/core/ps/ps_service_impl.cc
@@ -76,4 +76,30 @@ void PsServiceImpl::DatasetPull(google::protobuf::RpcController* cntl_base,
                           [done]() { done->Run(); });
 }
 
+void PsServiceImpl::BnStatisticsPush(google::protobuf::RpcController* cntl_base,
+                                     const BnStatisticsPushRequest* request,
+                                     BnStatisticsPushResponse* response,
+                                     google::protobuf::Closure* done) {
+    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_base);
+
+    auto* cluster = PsCluster::Instance();
+    const auto* si = cluster->GetServer(cluster->Rank());
+
+    si->BnStatisticsPushAsync(cntl, request, response,
+                              [done]() { done->Run(); });
+}
+
+void PsServiceImpl::BnStatisticsPull(google::protobuf::RpcController* cntl_base,
+                                     const BnStatisticsPullRequest* request,
+                                     BnStatisticsPullResponse* response,
+                                     google::protobuf::Closure* done) {
+    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_base);
+
+    auto* cluster = PsCluster::Instance();
+    const auto* si = cluster->GetServer(cluster->Rank());
+
+    si->BnStatisticsPullAsync(cntl, request, response,
+                              [done]() { done->Run(); });
+}
+
 } // end of namespace tensornet
diff --git a/core/ps/ps_service_impl.h b/core/ps/ps_service_impl.h
index 6838863..bb7a224 100644
--- a/core/ps/ps_service_impl.h
+++ b/core/ps/ps_service_impl.h
@@ -43,6 +43,16 @@ class PsServiceImpl : public PsService {
                              const DatasetPullRequest* request,
                              DatasetPullResponse* response,
                              google::protobuf::Closure* done);
+
+    virtual void BnStatisticsPush(google::protobuf::RpcController* cntl_base,
+                                  const BnStatisticsPushRequest* request,
+                                  BnStatisticsPushResponse* response,
+                                  google::protobuf::Closure* done);
+
+    virtual void BnStatisticsPull(google::protobuf::RpcController* cntl_base,
+                                  const BnStatisticsPullRequest* request,
+                                  BnStatisticsPullResponse* response,
+                                  google::protobuf::Closure* done);
 };
 
 } // end of namespace tensornet
diff --git a/core/ps/table/bn_table.cc b/core/ps/table/bn_table.cc
new file mode 100644
index 0000000..b5d842b
--- /dev/null
+++ b/core/ps/table/bn_table.cc
@@ -0,0 +1,228 @@
+// Copyright (c) 2020, Qihoo, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
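+
+// BnTable accumulates batch-normalization statistics (sum, squared sum and element
+// count) using compensated (Kahan) summation. Local statistics are merged through
+// Append(), incremental statistics are exchanged between shards via the
+// BnStatisticsPush/Pull RPCs, and GetMoments() converts the accumulated sums into
+// the mean / variance consumed by the Python layer.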
+
+#include "core/ps/table/bn_table.h"
+
+#include <cmath>
+#include <iostream>
+
+#include <memory>
+#include <mutex>
+#include <string>
+#include <tuple>
+
+#include <boost/iostreams/stream.hpp>
+#include "core/ps_interface/ps_raw_interface.h"
+#include "core/utility/file_io.h"
+#include "core/ps/optimizer/data_struct.h"
+
+namespace tensornet {
+
+BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count)
+    : shard_num_(shard_num)
+    , self_shard_id_(self_shard_id)
+    , name_(name)
+    , synchronized_(synchronized)
+    , moment_(moment)
+    , max_count_(max_count)
+    , bn_size_(bn_size) {
+    total_sum_.setZero(bn_size);
+    total_sum_err_.setZero(bn_size);
+    total_squared_sum_.setZero(bn_size);
+    total_squared_sum_err_.setZero(bn_size);
+    total_count_.setZero(bn_size);
+    inc_sum_.setZero(bn_size);
+    inc_squared_sum_.setZero(bn_size);
+    inc_count_.setZero(bn_size);
+    mu_ = std::make_unique<std::mutex>();
+}
+
+void BnTable::SetHandle(uint32_t handle) {
+    CHECK(handle_ == 0) << "bn table handle has already been set:" << handle_;
+
+    handle_ = handle;
+}
+
+void BnTable::Append(butil::IOBuf& bn_statistics_buf, bool isLocal) {
+    const std::lock_guard<std::mutex> lock(*mu_);
+    Eigen::ArrayXd acc_sum = Eigen::ArrayXd::Zero(bn_size_);
+    Eigen::ArrayXd acc_squared_sum = Eigen::ArrayXd::Zero(bn_size_);
+    Eigen::ArrayXd acc_count = Eigen::ArrayXd::Zero(bn_size_);
+
+    bn_statistics_buf.cutn(acc_sum.data(), acc_sum.size() * sizeof(double));
+    bn_statistics_buf.cutn(acc_squared_sum.data(), acc_squared_sum.size() * sizeof(double));
+    bn_statistics_buf.cutn(acc_count.data(), acc_count.size() * sizeof(double));
+    CHECK_EQ(bn_statistics_buf.size(), 0);
+
+    if (isLocal) {
+        inc_sum_ += acc_sum;
+        inc_squared_sum_ += acc_squared_sum;
+        inc_count_ += acc_count;
+    }
+
+    uint64_t cur_count = static_cast<uint64_t>(total_count_.maxCoeff());
+
+    if (max_count_ > 0 && cur_count > max_count_) {
+        uint64_t acc_count_num = static_cast<uint64_t>(acc_count.maxCoeff());
+        double ratio = (double) acc_count_num / cur_count;
+        total_sum_ *= (1 - (1 - moment_) * ratio);
+        TotalSumAcc((1 - moment_) * ratio * acc_sum);
+        total_squared_sum_ *= (1 - (1 - moment_) * ratio);
+        TotalSquareSumAcc((1 - moment_) * ratio * acc_squared_sum);
+    } else {
+        TotalSumAcc(acc_sum);
+        TotalSquareSumAcc(acc_squared_sum);
+        total_count_ += acc_count;
+    }
+}
+
+void BnTable::TotalSquareSumAcc(Eigen::ArrayXd acc) {
+    Eigen::ArrayXd y = acc - total_squared_sum_err_;
+    Eigen::ArrayXd t = total_squared_sum_ + y;
+    total_squared_sum_err_ = (t - total_squared_sum_) - y;
+    total_squared_sum_ = t;
+}
+
+void BnTable::TotalSumAcc(Eigen::ArrayXd acc) {
+    Eigen::ArrayXd y = acc - total_sum_err_;
+    Eigen::ArrayXd t = total_sum_ + y;
+    total_sum_err_ = (t - total_sum_) - y;
+    total_sum_ = t;
+}
+
+
+std::tuple<Eigen::ArrayXf, Eigen::ArrayXf> BnTable::GetMoments() {
+    Eigen::ArrayXf global_mean = DivideNoNan(total_sum_, total_count_);
+    Eigen::ArrayXf global_squared_mean = DivideNoNan(total_squared_sum_, total_count_);
+    Eigen::ArrayXf global_var = (global_squared_mean - global_mean.square()).max(0.0);
+    return std::make_tuple(global_mean, global_var);
+}
+
+void BnTable::GetStatistics(const BnStatisticsPullRequest* req, butil::IOBuf& bn_statistics_buf, BnStatisticsPullResponse* resp) {
+    resp->set_table_handle(req->table_handle());
+    bn_statistics_buf.append(total_sum_.data(), total_sum_.size() * sizeof(double));
+    bn_statistics_buf.append(total_squared_sum_.data(), total_squared_sum_.size() * sizeof(double));
+    bn_statistics_buf.append(total_count_.data(), total_count_.size() * sizeof(double));
+}
+void BnTable::GetIncStatistics(butil::IOBuf& bn_statistics_buf) {
+    bn_statistics_buf.append(inc_sum_.data(), inc_sum_.size() * sizeof(double));
+    bn_statistics_buf.append(inc_squared_sum_.data(), inc_squared_sum_.size() * sizeof(double));
+    bn_statistics_buf.append(inc_count_.data(), inc_count_.size() * sizeof(double));
+    inc_sum_.setZero();
+    inc_squared_sum_.setZero();
+    inc_count_.setZero();
+}
+
+
+void BnTable::Refresh() {
+    total_sum_.setZero();
+    total_squared_sum_.setZero();
+    total_count_.setZero();
+
+    inc_sum_.setZero();
+    inc_squared_sum_.setZero();
+    inc_count_.setZero();
+}
+
+
+Eigen::ArrayXf BnTable::DivideNoNan(const Eigen::ArrayXd& numerator, const Eigen::ArrayXd& denominator) {
+    Eigen::ArrayXd result = numerator;
+    for (int i = 0; i < numerator.size(); ++i) {
+        if (!std::isnan(denominator(i)) && denominator(i) != 0.0) {
+            result(i) = numerator(i) / denominator(i);
+        } else {
+            result(i) = 0.0;
+        }
+    }
+    return result.cast<float>();
+}
+
+void BnTable::PrintDetail() {
+    std::cout << "Array elements for handle: " << handle_ << " Elements: ";
+    for (int i = 0; i < total_squared_sum_.size(); ++i) {
+        std::cout << total_squared_sum_(i) << " ";
+    }
+    std::cout << std::endl;
+}
+
+void BnTable::Load(const std::string& filepath) {
+    std::string file = filepath + "/bn_table/";
+    file += std::to_string(GetHandle());
+
+    FileReaderSource reader_source(file, FCT_ZLIB);
+    boost::iostreams::stream<FileReaderSource> in_stream(reader_source);
+    in_stream.iword(SERIALIZE_FMT_ID) = SF_BIN;
+
+    int bn_size = 0;
+    in_stream.read(reinterpret_cast<char*>(&bn_size), sizeof(bn_size));
+
+    for (int i = 0; i < bn_size; i++) {
+        in_stream.read(reinterpret_cast<char*>(&total_sum_[i]), sizeof(total_sum_[i]));
+        in_stream.read(reinterpret_cast<char*>(&total_squared_sum_[i]), sizeof(total_squared_sum_[i]));
+        in_stream.read(reinterpret_cast<char*>(&total_count_[i]), sizeof(total_count_[i]));
+    }
+}
+
+void BnTable::Save(const std::string& filepath) {
+    std::string file = filepath + "/bn_table/";
+    file += std::to_string(GetHandle());
+
+    FileWriterSink writer_sink(file, FCT_ZLIB);
+
+    boost::iostreams::stream<FileWriterSink> out_stream(writer_sink);
+    out_stream.iword(SERIALIZE_FMT_ID) = SF_BIN;
+
+    out_stream.write(reinterpret_cast<const char*>(&bn_size_), sizeof(bn_size_));
+
+    for (int i = 0; i < bn_size_; i++) {
+        out_stream.write(reinterpret_cast<const char*>(&total_sum_[i]), sizeof(total_sum_[i]));
+        out_stream.write(reinterpret_cast<const char*>(&total_squared_sum_[i]), sizeof(total_squared_sum_[i]));
+        out_stream.write(reinterpret_cast<const char*>(&total_count_[i]), sizeof(total_count_[i]));
+    }
+    out_stream.flush();
+}
+
+BnTableRegistry* BnTableRegistry::Instance() {
+    static BnTableRegistry instance;
+    return &instance;
+}
+
+BnTable* BnTableRegistry::Get(uint32_t table_handle) {
+    CHECK(table_handle < tables_.size())
+        << " table_handle:" << table_handle << " table size:" << tables_.size();
+    return tables_[table_handle];
+}
+
+uint32_t BnTableRegistry::Register(BnTable* table) {
+    const std::lock_guard<std::mutex> lock(mu_);
+
+    uint32_t table_handle = tables_.size();
+    tables_.emplace_back(table);
+    return table_handle;
+}
+
+BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count) {
+    BnTable* table = new BnTable(name, shard_num, self_shard_id, bn_size, sync, moment, max_count);
+
+    table->SetHandle(BnTableRegistry::Instance()->Register(table));
+
+    return table;
+}
+
+} // namespace tensornet
diff --git a/core/ps/table/bn_table.h b/core/ps/table/bn_table.h
new file mode 100644
index 0000000..7522145
--- /dev/null
+++ b/core/ps/table/bn_table.h
@@ -0,0 +1,110 @@
+// Copyright (c) 2020, Qihoo, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORNET_PS_TABLE_BN_TABLE_H_
+#define TENSORNET_PS_TABLE_BN_TABLE_H_
+
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include <butil/iobuf.h>
+#include <Eigen/Dense>
+
+#include "core/ps_interface/ps_server.pb.h"
+
+namespace tensornet {
+
+class BnTable {
+public:
+    BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size,
+            bool sync, float moment, uint64_t max_count);
+
+    ~BnTable() = default;
+
+    void Append(butil::IOBuf& out_buf, bool isLocal);
+
+    void GetStatistics(const BnStatisticsPullRequest* req, butil::IOBuf& out_buf, BnStatisticsPullResponse* resp);
+
+    void GetIncStatistics(butil::IOBuf& out_buf);
+
+    std::tuple<Eigen::ArrayXf, Eigen::ArrayXf> GetMoments();
+    std::tuple<Eigen::ArrayXf, Eigen::ArrayXf> GetIncMoments();
+
+    Eigen::ArrayXf DivideNoNan(const Eigen::ArrayXd& numerator, const Eigen::ArrayXd& denominator);
+
+    void TotalSumAcc(Eigen::ArrayXd acc_sum);
+    void TotalSquareSumAcc(Eigen::ArrayXd acc_square_sum);
+    void Save(const std::string& filepath);
+    void Load(const std::string& filepath);
+
+    void PrintDetail();
+
+    void SetHandle(uint32_t handle);
+
+    uint32_t GetHandle() const {
+        return handle_;
+    }
+
+    void Refresh();
+
+private:
+    int shard_num_ = 0;
+    int self_shard_id_ = 0;
+    uint32_t handle_ = 0;
+    std::string name_;
+    uint32_t bn_size_ = 0;
+    bool synchronized_ = false;
+    float moment_ = 0.0;
+    uint64_t max_count_ = 0;
+    Eigen::ArrayXd total_sum_;
+    Eigen::ArrayXd total_sum_err_;
+    Eigen::ArrayXd total_squared_sum_;
+    Eigen::ArrayXd total_squared_sum_err_;
+    Eigen::ArrayXd total_count_;
+    Eigen::ArrayXd inc_sum_;
+    Eigen::ArrayXd inc_squared_sum_;
+    Eigen::ArrayXd inc_count_;
+    std::unique_ptr<std::mutex> mu_;
+};
+
+class BnTableRegistry {
+public:
+    ~BnTableRegistry() {
+        for (auto table : tables_) {
+            delete table;
+        }
+    }
+
+    static BnTableRegistry* Instance();
+
+    BnTable* Get(uint32_t table_handle);
+
+    uint32_t Register(BnTable* table);
+
+private:
+    BnTableRegistry() { }
+
+private:
+    std::mutex mu_;
+    std::vector<BnTable*> tables_;
+};
+
+BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size,
+                       bool sync, float moment, uint64_t max_count);
+
+} // namespace tensornet
+
+#endif // TENSORNET_PS_TABLE_BN_TABLE_H_
diff --git a/core/ps_interface/ps_server.proto b/core/ps_interface/ps_server.proto
index d772fcf..99c0aa0 100644
--- a/core/ps_interface/ps_server.proto
+++ b/core/ps_interface/ps_server.proto
@@ -43,9 +43,32 @@ message DatasetPullResponse {
     bytes dataset_info = 3;
 };
 
+message BnStatisticsPushRequest {
+    uint32 req_shard_id = 1;
+    uint32 table_handle = 2;
+};
+
+message BnStatisticsPushResponse {
+    uint32 resp_shard_id = 1;
+    uint32 table_handle = 2;
+};
+
+message BnStatisticsPullRequest {
+    uint32 req_shard_id = 1;
+    uint32 table_handle = 2;
+};
+
+message BnStatisticsPullResponse {
+    uint32 resp_shard_id = 1;
+    uint32 table_handle = 2;
+};
+
+
 service PsService {
     rpc SparsePull(SparsePullRequest) returns (SparsePullResponse);
     rpc SparsePush(SparsePushRequest) returns (SparsePushResponse);
     rpc DensePushPull(DensePushPullRequest) returns (DensePushPullResponse);
     rpc DatasetPull(DatasetPullRequest) returns (DatasetPullResponse);
+    rpc BnStatisticsPush(BnStatisticsPushRequest) returns (BnStatisticsPushResponse);
+    rpc BnStatisticsPull(BnStatisticsPullRequest) returns (BnStatisticsPullResponse);
 };
diff --git a/tensornet/core/gen_bn_table_ops.py b/tensornet/core/gen_bn_table_ops.py
new file mode 100644
index 0000000..8903bbe
--- /dev/null
+++ b/tensornet/core/gen_bn_table_ops.py
@@ -0,0 +1,240 @@
+"""Python wrappers around TensorFlow ops.
+
+This file is MACHINE GENERATED! Do not edit.
+Original C++ source file: bn_table_ops.cc
+"""
+
+import collections
+
+from tensorflow.python import pywrap_tfe as pywrap_tfe
+from tensorflow.python.eager import context as _context
+from tensorflow.python.eager import core as _core
+from tensorflow.python.eager import execute as _execute
+from tensorflow.python.framework import dtypes as _dtypes
+
+from tensorflow.python.framework import op_def_registry as _op_def_registry
+from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework import op_def_library as _op_def_library
+from tensorflow.python.util.deprecation import deprecated_endpoints
+from tensorflow.python.util import dispatch as _dispatch
+from tensorflow.python.util.tf_export import tf_export
+
+
+@_dispatch.add_dispatch_list
+@tf_export('update_moments')
+def update_moments(vars, table_handle, name=None):
+  r"""Update the moving mean / variance variables from the bn table statistics.
+
+  Args:
+    vars: A list of at least 1 `Tensor` objects with type `resource`.
+    table_handle: An `int`.
+    name: A name for the operation (optional).
+
+  Returns:
+    The created Operation.
+  """
+  _ctx = _context._context or _context.context()
+  tld = _ctx._thread_local_data
+  if tld.is_eager:
+    try:
+      _result = pywrap_tfe.TFE_Py_FastPathExecute(
+        _ctx._context_handle, tld.device_name, "UpdateMoments", name,
+        tld.op_callbacks, vars, "table_handle", table_handle)
+      return _result
+    except _core._FallbackException:
+      try:
+        return update_moments_eager_fallback(
+            vars, table_handle=table_handle, name=name, ctx=_ctx)
+      except _core._SymbolicException:
+        pass  # Add nodes to the TensorFlow graph.
+      except (TypeError, ValueError):
+        result = _dispatch.dispatch(
+              update_moments, vars=vars, table_handle=table_handle,
+              name=name)
+        if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
+          return result
+        raise
+    except _core._NotOkStatusException as e:
+      _ops.raise_from_not_ok_status(e, name)
+  # Add nodes to the TensorFlow graph.
+  if not isinstance(vars, (list, tuple)):
+    raise TypeError(
+        "Expected list for 'vars' argument to "
+        "'update_moments' Op, not %r." % vars)
+  _attr_N = len(vars)
+  table_handle = _execute.make_int(table_handle, "table_handle")
+  try:
+    _, _, _op, _outputs = _op_def_library._apply_op_helper(
+        "UpdateMoments", vars=vars, table_handle=table_handle, name=name)
+  except (TypeError, ValueError):
+    result = _dispatch.dispatch(
+          update_moments, vars=vars, table_handle=table_handle, name=name)
+    if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
+      return result
+    raise
+  return _op
+UpdateMoments = tf_export("raw_ops.UpdateMoments")(_ops.to_raw_op(update_moments))
+
+
+def update_moments_eager_fallback(vars, table_handle, name, ctx):
+  if not isinstance(vars, (list, tuple)):
+    raise TypeError(
+        "Expected list for 'vars' argument to "
+        "'update_moments' Op, not %r."
+        % vars)
+  _attr_N = len(vars)
+  table_handle = _execute.make_int(table_handle, "table_handle")
+  vars = _ops.convert_n_to_tensor(vars, _dtypes.resource)
+  _inputs_flat = list(vars)
+  _attrs = ("table_handle", table_handle, "N", _attr_N)
+  _result = _execute.execute(b"UpdateMoments", 0, inputs=_inputs_flat,
+                             attrs=_attrs, ctx=ctx, name=name)
+  _result = None
+  return _result
+
+
+@_dispatch.add_dispatch_list
+@tf_export('bn_statistics_push')
+def bn_statistics_push(vars, table_handle, synchronized, name=None):
+  r"""Push local bn statistics (sum, squared sum, count) to the parameter server.
+
+  Args:
+    vars: A list of at least 1 `Tensor` objects with type `resource`.
+    table_handle: An `int`.
+    synchronized: A `bool`.
+    name: A name for the operation (optional).
+
+  Returns:
+    The created Operation.
+  """
+  _ctx = _context._context or _context.context()
+  tld = _ctx._thread_local_data
+  if tld.is_eager:
+    try:
+      _result = pywrap_tfe.TFE_Py_FastPathExecute(
+        _ctx._context_handle, tld.device_name, "BnStatisticsPush", name,
+        tld.op_callbacks, vars, "table_handle", table_handle, "synchronized", synchronized)
+      return _result
+    except _core._FallbackException:
+      try:
+        return bn_statistics_push_eager_fallback(
+            vars, table_handle=table_handle, synchronized=synchronized, name=name, ctx=_ctx)
+      except _core._SymbolicException:
+        pass  # Add nodes to the TensorFlow graph.
+      except (TypeError, ValueError):
+        result = _dispatch.dispatch(
+              bn_statistics_push, vars=vars, table_handle=table_handle,
+              synchronized=synchronized, name=name)
+        if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
+          return result
+        raise
+    except _core._NotOkStatusException as e:
+      _ops.raise_from_not_ok_status(e, name)
+  # Add nodes to the TensorFlow graph.
+  if not isinstance(vars, (list, tuple)):
+    raise TypeError(
+        "Expected list for 'vars' argument to "
+        "'bn_statistics_push' Op, not %r." % vars)
+  _attr_N = len(vars)
+  table_handle = _execute.make_int(table_handle, "table_handle")
+  synchronized = _execute.make_bool(synchronized, "synchronized")
+  try:
+    _, _, _op, _outputs = _op_def_library._apply_op_helper(
+        "BnStatisticsPush", vars=vars,
+        table_handle=table_handle, synchronized=synchronized, name=name)
+  except (TypeError, ValueError):
+    result = _dispatch.dispatch(
+          bn_statistics_push, vars=vars,
+          table_handle=table_handle, synchronized=synchronized, name=name)
+    if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
+      return result
+    raise
+  return _op
+BnStatisticsPush = tf_export("raw_ops.BnStatisticsPush")(_ops.to_raw_op(bn_statistics_push))
+
+
+def bn_statistics_push_eager_fallback(vars, table_handle, synchronized, name, ctx):
+  if not isinstance(vars, (list, tuple)):
+    raise TypeError(
+        "Expected list for 'vars' argument to "
+        "'bn_statistics_push' Op, not %r." % vars)
+  _attr_N = len(vars)
+  table_handle = _execute.make_int(table_handle, "table_handle")
+  synchronized = _execute.make_bool(synchronized, "synchronized")
+  vars = _ops.convert_n_to_tensor(vars, _dtypes.resource)
+  _inputs_flat = list(vars)
+  _attrs = ("table_handle", table_handle, "N", _attr_N, "synchronized", synchronized)
+  _result = _execute.execute(b"BnStatisticsPush", 0, inputs=_inputs_flat,
+                             attrs=_attrs, ctx=ctx, name=name)
+  _result = None
+  return _result
+
+@_dispatch.add_dispatch_list
+@tf_export('bn_statistics_pull')
+def bn_statistics_pull(vars, table_handle, name=None):
+  r"""Pull bn statistics from the parameter server and update the moving mean / variance.
+
+  Args:
+    vars: A list of at least 1 `Tensor` objects with type `resource`.
+    table_handle: An `int`.
+    name: A name for the operation (optional).
+
+  Returns:
+    The created Operation.
+  """
+  _ctx = _context._context or _context.context()
+  tld = _ctx._thread_local_data
+  if tld.is_eager:
+    try:
+      _result = pywrap_tfe.TFE_Py_FastPathExecute(
+        _ctx._context_handle, tld.device_name, "BnStatisticsPull", name,
+        tld.op_callbacks, vars, "table_handle", table_handle)
+      return _result
+    except _core._FallbackException:
+      try:
+        return bn_statistics_pull_eager_fallback(
+            vars, table_handle=table_handle, name=name, ctx=_ctx)
+      except _core._SymbolicException:
+        pass  # Add nodes to the TensorFlow graph.
+      except (TypeError, ValueError):
+        result = _dispatch.dispatch(
+              bn_statistics_pull, vars=vars, table_handle=table_handle, name=name)
+        if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
+          return result
+        raise
+    except _core._NotOkStatusException as e:
+      _ops.raise_from_not_ok_status(e, name)
+  # Add nodes to the TensorFlow graph.
+  if not isinstance(vars, (list, tuple)):
+    raise TypeError(
+        "Expected list for 'vars' argument to "
+        "'bn_statistics_pull' Op, not %r." % vars)
+  _attr_N = len(vars)
+  table_handle = _execute.make_int(table_handle, "table_handle")
+  try:
+    _, _, _op, _outputs = _op_def_library._apply_op_helper(
+        "BnStatisticsPull", vars=vars,
+        table_handle=table_handle, name=name)
+  except (TypeError, ValueError):
+    result = _dispatch.dispatch(
+          bn_statistics_pull, vars=vars,
+          table_handle=table_handle, name=name)
+    if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
+      return result
+    raise
+  return _op
+BnStatisticsPull = tf_export("raw_ops.BnStatisticsPull")(_ops.to_raw_op(bn_statistics_pull))
+
+def bn_statistics_pull_eager_fallback(vars, table_handle, name, ctx):
+  if not isinstance(vars, (list, tuple)):
+    raise TypeError(
+        "Expected list for 'vars' argument to "
+        "'bn_statistics_pull' Op, not %r."
+        % vars)
+  _attr_N = len(vars)
+  table_handle = _execute.make_int(table_handle, "table_handle")
+  vars = _ops.convert_n_to_tensor(vars, _dtypes.resource)
+  _inputs_flat = list(vars)
+  _attrs = ("table_handle", table_handle, "N", _attr_N)
+  _result = _execute.execute(b"BnStatisticsPull", 0, inputs=_inputs_flat,
+                             attrs=_attrs, ctx=ctx, name=name)
+  _result = None
+  return _result
diff --git a/tensornet/layers/__init__.py b/tensornet/layers/__init__.py
index 1663b49..bba21a5 100644
--- a/tensornet/layers/__init__.py
+++ b/tensornet/layers/__init__.py
@@ -14,3 +14,4 @@
 from .embedding_features import EmbeddingFeatures
 from .sequence_embedding_features import SequenceEmbeddingFeatures
+from .normalization_layer import TNBatchNormalization
diff --git a/tensornet/layers/normalization_layer.py b/tensornet/layers/normalization_layer.py
new file mode 100644
index 0000000..0749f85
--- /dev/null
+++ b/tensornet/layers/normalization_layer.py
@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+import tensorflow as tf
+import tensornet as tn
+
+from tensornet.core import gen_bn_table_ops
+from tensorflow.keras import initializers
+from tensorflow.keras import regularizers
+from tensorflow.keras import constraints
+from tensorflow.python.keras.utils import tf_utils
+from tensorflow.keras.layers import Layer
+from tensorflow.python.training import moving_averages
+from tensorflow.python.ops import variable_scope, array_ops
+
+
+class TNBatchNormalization(Layer):
+    """Batch normalization backed by a parameter-server bn table.
+
+    Reference: https://github.com/keras-team/keras/blob/v3.5.0/keras/src/layers/normalization/batch_normalization.py
+
+    Args:
+        center, scale, epsilon: same as in the original batch normalization layer.
+        momentum: same definition as in the original batch normalization, but applied to the
+            bn statistics rather than directly to moving_mean / moving_var.
+        synchronized: whether the bn statistics (sum, squared sum, count) are pushed to the
+            other tensornet ranks during training. If False, at the end of training rank 0
+            pulls the statistics from all other ranks and computes moving_mean and moving_var
+            once. If True, the incremental bn statistics are broadcast to all other ranks
+            every 'sync_freq' batches.
+        sync_freq: how often (in batches) bn statistics are sent to the other ranks. Only
+            meaningful when 'synchronized' is True.
+        max_count: threshold to avoid bn statistics overflow, counted in records, not batches.
+            This is an empirical parameter that needs to be adjusted to the size of the
+            training data.
+    """
+    def __init__(self, center=True, scale=True, epsilon=1e-5, momentum=0.99, name=None,
+                 synchronized=False, sync_freq=1, max_count=100000, **kwargs):
+        super(TNBatchNormalization, self).__init__(name=name, **kwargs)
+        self.center = center
+        self.scale = scale
+        self.epsilon = epsilon
+        self.momentum = momentum
+        self.moments_axes = []
+        self.apply_axis = []
+        self.gamma, self.beta = None, None
+        self.beta_initializer = initializers.get('zeros')
+        self.gamma_initializer = initializers.get('ones')
+        self.moving_mean_initializer = initializers.get('zeros')
+        self.moving_variance_initializer = initializers.get('ones')
+        self.local_count_initializer = initializers.get('zeros')
+        self.local_sum_initializer = initializers.get('zeros')
+        self.local_squared_num_initializer = initializers.get('zeros')
+        self.beta_regularizer = regularizers.get(None)
+        self.gamma_regularizer = regularizers.get(None)
+        self.beta_constraint = constraints.get(None)
+        self.gamma_constraint = constraints.get(None)
+        self.synchronized = synchronized
+        self.sync_freq = sync_freq
+        self.batch_counter = tf.Variable(0, name="batch_counter")
+        self.max_count = max_count
+
+    def build(self, input_shape):
+        input_rank = len(input_shape)
+        self.moments_axes = list(range(input_rank - 1))
+        self.apply_axis = input_shape[-1:]
+        self.params_reshape = [1 for _ in range(
+            1, input_rank - 1)] + [input_shape[-1]]
+
+        if self.scale:
+            self.gamma = self.add_weight(shape=self.apply_axis, name='gamma', initializer=self.gamma_initializer,
+                                         regularizer=self.gamma_regularizer, constraint=self.gamma_constraint)
+
+        if self.center:
+            self.beta = self.add_weight(shape=self.apply_axis, name='beta', initializer=self.beta_initializer,
+                                        regularizer=self.beta_regularizer, constraint=self.beta_constraint)
+
+        self.moving_mean = self.add_weight(
+            shape=self.apply_axis,
+            name="moving_mean",
+            initializer=self.moving_mean_initializer,
+            trainable=False)
+
+        self.moving_variance = self.add_weight(
+            shape=self.apply_axis,
+            name="moving_variance",
+            initializer=self.moving_variance_initializer,
+            trainable=False)
+
+        self.local_count = self.add_weight(
+            shape=self.apply_axis,
+            name="local_count",
+            initializer=self.local_count_initializer,
+            trainable=False)
+
+        self.local_sum = self.add_weight(
+            shape=self.apply_axis,
+            name="local_sum",
+            initializer=self.local_sum_initializer,
+            trainable=False)
+
+        self.local_squared_sum = self.add_weight(
+            shape=self.apply_axis,
+            name="local_squared_sum",
+            initializer=self.local_squared_num_initializer,
+            trainable=False)
+
+        self.bn_table_handle = tn.core.create_bn_table(self.name, self.apply_axis[0], self.synchronized,
+                                                       self.momentum, self.max_count)
+
+    def call(self, inputs, training=None):
+
+        @tf.function
+        def _increment_and_check_count():
+            self.batch_counter.assign_add(1)
+            if tf.equal(self.batch_counter, self.sync_freq):
+                self.bn_statistics_push(True)
+                self.batch_counter.assign(0)
+            else:
+                self.bn_statistics_push(False)
+
+        if training:
+            local_count_sample = tf.ones_like(inputs, name="count")
+            self.local_sum.assign(tf.reduce_sum(inputs, axis=self.moments_axes))
+            self.local_squared_sum.assign(tf.reduce_sum(tf.square(inputs), axis=self.moments_axes))
+            self.local_count.assign(tf.reduce_sum(local_count_sample, axis=self.moments_axes))
+            if self.synchronized:
+                _increment_and_check_count()
+            else:
+                self.bn_statistics_push(False)
+            self.update_moments()
+
+        mean = self.moving_mean
+        var = self.moving_variance
+
+        outputs = tf.nn.batch_normalization(x=inputs, mean=mean, variance=var, offset=self.beta,
+                                            scale=self.gamma, variance_epsilon=self.epsilon)
+
+        return outputs
+
+    def update_moments(self):
+        gen_bn_table_ops.update_moments([self.moving_mean.handle, self.moving_variance.handle],
+                                        table_handle=self.bn_table_handle)
+
+    def bn_statistics_push(self, synchronized):
+        gen_bn_table_ops.bn_statistics_push([self.local_sum.handle, self.local_squared_sum.handle, self.local_count.handle],
+                                            table_handle=self.bn_table_handle, synchronized=synchronized)
+
+    def bn_statistics_pull(self):
+        # if sync_freq is greater than 1, force a statistics sync once at the end of training
+        if not self.synchronized or self.sync_freq > 1:
+            self.batch_counter.assign(0)
+            gen_bn_table_ops.bn_statistics_pull([self.moving_mean.handle, self.moving_variance.handle],
+                                                table_handle=self.bn_table_handle)
+
+    def save_bn_table(self, filepath):
+        return tn.core.save_bn_table(self.bn_table_handle, filepath)
+
+    def load_bn_table(self, filepath):
+        return tn.core.load_bn_table(self.bn_table_handle, filepath)
diff --git a/tensornet/model/Model.py b/tensornet/model/Model.py
index 0e4725f..ee3462d 100644
--- a/tensornet/model/Model.py
+++ b/tensornet/model/Model.py
@@ -177,6 +177,10 @@ def save_weights(self, filepath, overwrite=True, save_format=None, dt="", root=T
             layer.save_sparse_table(cp_dir, mode)
         elif isinstance(layer, tn.layers.SequenceEmbeddingFeatures):
             layer.save_sparse_table(cp_dir, mode)
+        elif isinstance(layer, tn.layers.TNBatchNormalization):
+            if tn.core.self_shard_id() == 0:
+                layer.bn_statistics_pull()
+                layer.save_bn_table(cp_dir)
 
     if self.optimizer:
         self.optimizer.save_dense_table(cp_dir)
@@ -219,6 +223,8 @@ def load_weights(self, filepath, by_name=False, skip_mismatch=False, include_dt=
             layer.load_sparse_table(cp_dir, mode)
         elif isinstance(layer, tn.layers.SequenceEmbeddingFeatures):
             layer.load_sparse_table(cp_dir, mode)
+        elif isinstance(layer, tn.layers.TNBatchNormalization):
+            layer.load_bn_table(cp_dir)
 
     # dense weight
     if self.optimizer:
diff --git a/tensornet/version.py b/tensornet/version.py
index 495e451..5bcbc5f 100644
--- a/tensornet/version.py
+++ b/tensornet/version.py
@@ -1 +1 @@
-VERSION = "0.1.3.post2"
+VERSION = "0.2.0.rc"
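
For reference, a minimal usage sketch of the layer introduced by this patch. This is not part of the patch itself: the input shape, the Dense head and the checkpoint path are hypothetical, and it assumes the tensornet PS cluster has been initialized as in the existing examples; tn.model.Model is the model wrapper whose save_weights/load_weights are extended above.

import tensorflow as tf
import tensornet as tn

# Hypothetical dense feature input; the layer normalizes with the moving
# statistics kept in the parameter-server bn table.
inputs = tf.keras.Input(shape=(8,), dtype=tf.float32)
normed = tn.layers.TNBatchNormalization(synchronized=True, sync_freq=10,
                                        momentum=0.99, max_count=100000)(inputs)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(normed)

model = tn.model.Model(inputs=inputs, outputs=outputs)
# ... compile and fit as usual; on rank 0, save_weights() pulls the merged
# statistics and persists the bn table next to the other checkpoint files.
model.save_weights("./checkpoint/")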