Skip to content

Commit

Permalink
[Op] Modify Iterator Impl
Browse files Browse the repository at this point in the history
Signed-off-by: JunqiHu <[email protected]>
  • Loading branch information
Mesilenceki committed Oct 26, 2023
1 parent daaf7a6 commit 0c76d38
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,7 @@ def main():
True, 'star')

set_build_var(environ_cp, 'TF_NEED_ELASTIC', 'ELASTIC TRAINING', 'with_elastic_support',
True, 'elastic')
False, 'elastic')

set_build_var(environ_cp, 'TF_ENABLE_PMEM', 'PMEM', 'with_pmem_support',
False, 'pmem')
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/contrib/elastic_grpc_server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ cc_library(
tf_cc_test(
name = "elastic_grpc_test",
size = "small",
srcs = ["elastic_grpc_server_lib_test.cc"],
srcs = select({"//tensorflow:with_elastic_support": ["elastic_grpc_server_lib_test.cc"],
"//conditions:default": []}),
deps = [
":elastic_grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ load(
"tf_additional_numa_deps",
"tf_additional_numa_lib_defines",
"tf_additional_star_lib_defines",
"tf_additional_elastic_server_lib_defines",
"tf_additional_api_compatible_defines",
"tf_additional_pmem_lib_defines",
"tf_additional_test_deps",
Expand Down Expand Up @@ -1441,6 +1442,7 @@ tf_cc_test(
cc_library(
name = "ops",
visibility = ["//visibility:public"],
defines = tf_additional_elastic_server_lib_defines(),
deps = [
":array_ops_op_lib",
":parquet_ops_op_lib",
Expand Down Expand Up @@ -2562,7 +2564,8 @@ LIB_INTERNAL_DEFINES = (
tf_additional_gdr_lib_defines() +
tf_additional_numa_lib_defines() +
tf_additional_star_lib_defines() +
tf_additional_pmem_lib_defines()
tf_additional_pmem_lib_defines() +
tf_additional_elastic_server_lib_defines()
)

cc_library(
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ load(
"transitive_hdrs",
)

load(
"//tensorflow/core/platform:default/build_config.bzl",
"tf_additional_elastic_server_lib_defines",
)

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
Expand Down Expand Up @@ -1119,6 +1124,7 @@ tf_kernel_library(
name = "iterator_ops",
srcs = ["iterator_ops.cc"],
hdrs = ["iterator_ops.h"],
defines = tf_additional_elastic_server_lib_defines(),
deps = [
":captured_function",
":dataset_utils",
Expand Down
17 changes: 10 additions & 7 deletions tensorflow/core/kernels/data/iterator_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("recoverable", &recoverable_));
}

// The resource is deleted from the resource manager only when it is private
Expand Down Expand Up @@ -309,11 +308,11 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
}

ResourceMgr* mgr = context->resource_manager();
if (recoverable_ == false) {
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), false));
} else {
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), true));
}
#ifdef TENSORFLOW_USE_ELASTIC_SERVER
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), true));
#else
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), false));
#endif

IteratorResource* resource;
OP_REQUIRES_OK(
Expand Down Expand Up @@ -788,7 +787,11 @@ class OneShotIteratorOp : public AsyncOpKernel {

Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
ContainerInfo* cinfo) {
TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def(), true));
#ifdef TENSORFLOW_USE_ELASTIC_SERVER
TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def(), true));
#else
TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def(), false));
#endif

FunctionLibraryRuntime* flr;
std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
Expand Down
1 change: 0 additions & 1 deletion tensorflow/core/kernels/data/iterator_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ class IteratorHandleOp : public OpKernel {
std::vector<PartialTensorShape> output_shapes_;
const int graph_def_version_;
string name_;
bool recoverable_;
};

// Like IteratorHandleOp, but creates handles which are never shared, and does
Expand Down
11 changes: 0 additions & 11 deletions tensorflow/core/ops/dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,24 +555,13 @@ REGISTER_OP("Iterator")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);

#ifndef TF_API_COMPATIBLE_1150
REGISTER_OP("IteratorV2")
.Output("handle: resource")
.Attr("shared_name: string")
.Attr("container: string")
.Attr("recoverable: bool = false")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
#else
REGISTER_OP("IteratorV2")
.Output("handle: resource")
.Attr("shared_name: string")
.Attr("container: string")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
#endif

REGISTER_OP("AnonymousIterator")
.Output("handle: resource")
Expand Down

0 comments on commit 0c76d38

Please sign in to comment.