Skip to content

Commit

Permalink
[Hook] Add 'before_create_session' interface to SessionRunHook.
Browse files Browse the repository at this point in the history
Signed-off-by: chenbangduo.cbd <[email protected]>
  • Loading branch information
JackMoriarty committed Apr 25, 2024
1 parent a4489e3 commit 7e78253
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensorflow/python/training/monitored_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,8 @@ def __init__(self, session_creator, hooks, stop_grace_period_secs):
def create_session(self):
"""Creates a coordinated session."""
# Keep the tf_sess for unit testing.
for hook in self._hooks:
hook.before_create_session()
self.tf_sess = self._session_creator.create_session()
# We don't want coordinator to suppress any exception.
self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
Expand Down Expand Up @@ -1027,6 +1029,7 @@ class MonitoredSession(_MonitoredSession):
in given order:
* calls `hook.begin()` for each given hook
* calls `hook.before_create_session()`
* finalizes the graph via `scaffold.finalize()`
* create session
* initializes the model via initialization ops provided by `Scaffold`
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/training/session_run_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ def begin(self):
"""
pass

def before_create_session(self):
"""Called before new TensorFlow session is created.
This has two essential differences with the situation in which `begin` is
called:
* Do not modify the graph in this method, ops should not be added to graph.
The modification of the graph should take place within the begin
interface.
* This method will also be called prior to the recovery of a wrapped
session, not just at the beginning of the overall session.
"""
pass

def after_create_session(self, session, coord): # pylint: disable=unused-argument
"""Called when new TensorFlow session is created.
Expand Down

0 comments on commit 7e78253

Please sign in to comment.