Adding new inputs to pre-trained layers #421

Open
michelwi opened this issue Jan 11, 2021 · 2 comments

@michelwi
Collaborator

Problem Statement

I would like to extend an already existing layer with an additional input. In my example, I have trained an attention-based encoder-decoder model, and now I would like to add the output of an external LM to the inputs of the softmax layer:

The RETURNN config I used to train the checkpoint was:

```python
'output_prob': { 'class': 'softmax',
                 'dropout': 0.0,
                 'from': ['readout'], # input size: 500
                 'loss': 'ce',
                 'target': 'classes'}, # output size: 10025
```

Now I continue training with this extended layer while importing the existing checkpoint via the preload_from_files mechanism:

```python
'output_prob': { 'bias_init': 0,
                 'class': 'softmax',
                 'custom_param_importer': 'subset',
                 'dropout': 0.0,
                 'forward_weights_init': 0,
                 'from': ['lm_output_prob', 'readout'], # input size: 10025 + 500
                 'loss': 'ce',
                 'target': 'classes'}, # output size: 10025
```
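Why this setup is attractive: with 'forward_weights_init': 0, the rows of W that belong to the new lm_output_prob input start at zero. Assuming W and b for the readout part are copied over from the checkpoint, the extended layer initially computes exactly the same logits as the trained one. A toy-sized numpy sketch of that property (V and R stand in for 10025 and 500; the inputs are random placeholders, not RETURNN code):

```python
import numpy as np

rng = np.random.default_rng(0)
V, R, T = 6, 3, 4  # toy stand-ins for vocab size (10025), readout size (500), frames

# Trained softmax layer: 'from': ['readout'], so W has shape (R, V).
W_old = rng.normal(size=(R, V))
b = rng.normal(size=V)

readout = rng.normal(size=(T, R))
lm_probs = rng.dirichlet(np.ones(V), size=T)  # stand-in for lm_output_prob

# Extended layer: 'from': ['lm_output_prob', 'readout'] concatenates the inputs
# in that order, so W has shape (V + R, V); the LM rows come first and start at 0.
W_new = np.concatenate([np.zeros((V, V)), W_old], axis=0)
x_new = np.concatenate([lm_probs, readout], axis=-1)

logits_old = readout @ W_old + b
logits_new = x_new @ W_new + b
print(np.allclose(logits_old, logits_new))  # True
```

The LM contribution only becomes non-zero once training updates those rows.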

First Attempt

I am on commit face0c3, which I extended with the changes introduced in #412.

The relevant variables in the function transform_param_axes_split_info_to_new_shape in returnn/tf/util/basic.py are:

```python
axes_split_info = [[10025, 500], [10025]]
new_shape = (500, 10025)
dim_diff = {10025: 10025}
new_parts = [10025, None]
```
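Reading these numbers against the trace that follows: the parts of each axis in axes_split_info add up to the shape of W in the extended network, while new_shape is the shape of W stored in the checkpoint. A quick check using only the values quoted here:

```python
axes_split_info = [[10025, 500], [10025]]  # split info of W in the extended layer
new_shape = (500, 10025)  # shape of W in the checkpoint (trained with 'from': ['readout'])

# Sum the parts of each axis to get the full shape described by the split info.
split_shape = tuple(sum(parts) for parts in axes_split_info)
print(split_shape)  # (10525, 10025), the shape of 'output/rec/output_prob/W:0' in the trace

# Per-axis size change between the checkpoint and the extended variable.
print([s - n for s, n in zip(split_shape, new_shape)])  # [10025, 0]
```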
Here is the stack trace I get:
Unhandled exception <class 'AssertionError'> in thread <_MainThread(MainThread, started 140417611712256)>, proc 7117.

Thread current, main, <_MainThread(MainThread, started 140417611712256)>:
(Excluded thread.)

That were all threads.
EXCEPTION
Traceback (most recent call last):
  File "/work/asr4/michel/sandbow/returnn_meyer/rnn.py", line 11, in <module>
    line: main()
    locals:
      main = <local> <function main at 0x7fb57754c820>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/__main__.py", line 645, in main
    line: execute_main_task()
    locals:
      execute_main_task = <global> <function execute_main_task at 0x7fb57754c700>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/__main__.py", line 451, in execute_main_task
    line: engine.init_train_from_config(config, train_data, dev_data, eval_data)
    locals:
      engine = <global> <returnn.tf.engine.Engine object at 0x7fb577414430>
      engine.init_train_from_config = <global> <bound method Engine.init_train_from_config of <returnn.tf.engine.Engine object at 0x7fb577414430>>
      config = <global> <returnn.config.Config object at 0x7fb585c43a30>
      train_data = <global> <LibriSpeechCorpus 'train' epoch=1>
      dev_data = <global> <LibriSpeechCorpus 'dev' epoch=1>
      eval_data = <global> None
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/engine.py", line 1026, in init_train_from_config
    line: self.init_network_from_config(config)
    locals:
      self = <local> <returnn.tf.engine.Engine object at 0x7fb577414430>
      self.init_network_from_config = <local> <bound method Engine.init_network_from_config of <returnn.tf.engine.Engine object at 0x7fb577414430>>
      config = <local> <returnn.config.Config object at 0x7fb585c43a30>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/engine.py", line 1127, in init_network_from_config
    line: loader.load_now(session=self.tf_session)
    locals:
      loader = <local> CustomCheckpointLoader(filename='/work/asr3/michel/meyer/work/crnn/training/CRNNTrainingJob.737fYxfbkoCz/output/models/epoch.250', params_prefix='', load_if_prefix='', ignore_missing=True, network=<TFNetwork 'root' train=<tf.Tensor 'globals/train_flag:0' shape=() dtype=bool>>)
      loader.load_now = <local> <bound method CustomCheckpointLoader.load_now of CustomCheckpointLoader(filename='/work/asr3/michel/meyer/work/crnn/training/CRNNTrainingJob.737fYxfbkoCz/output/models/epoch.250', params_prefix='', load_if_prefix='', ignore_missing=True, network=<TFNetwork 'root' train=<tf.Tensor 'globals/train_f...
      session = <not found>
      self = <local> <returnn.tf.engine.Engine object at 0x7fb577414430>
      self.tf_session = <local> <tensorflow.python.client.session.Session object at 0x7fb57756fa30>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 3151, in load_now
    line: value.assign_var(var=var, session=session)
    locals:
      value = <local> <returnn.tf.network.CustomCheckpointLoader.VariableValue object at 0x7fb48556d880>
      value.assign_var = <local> <bound method CustomCheckpointLoader.VariableValue.assign_var of <returnn.tf.network.CustomCheckpointLoader.VariableValue object at 0x7fb48556d880>>
      var = <local> <tf.Variable 'output/rec/output_prob/W:0' shape=(10525, 10025) dtype=float32_ref>
      session = <local> <tensorflow.python.client.session.Session object at 0x7fb57756fa30>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 2832, in assign_var
    line: self.custom_param_importer.assign_var(var=var, session=session)
    locals:
      self = <local> <returnn.tf.network.CustomCheckpointLoader.VariableValue object at 0x7fb48556d880>
      self.custom_param_importer = <local> <CustomParamImporter 'subset' on layer 'output_prob'>
      self.custom_param_importer.assign_var = <local> <bound method CustomCheckpointLoader.CustomParamImporter.assign_var of <CustomParamImporter 'subset' on layer 'output_prob'>>
      var = <local> <tf.Variable 'output/rec/output_prob/W:0' shape=(10525, 10025) dtype=float32_ref>
      session = <local> <tensorflow.python.client.session.Session object at 0x7fb57756fa30>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 2776, in assign_var
    line: self.layer.set_param_values_by_dict(values_dict=values_dict, session=session)
    locals:
      self = <local> <CustomParamImporter 'subset' on layer 'output_prob'>
      self.layer = <local> <SoftmaxLayer output/'output_prob' out_type=Data(shape=(None, 10025), batch_dim_axis=1, batch_shape_meta=[T|'time:var:extern_data:classes',B,F|10025])>
      self.layer.set_param_values_by_dict = <local> <bound method LayerBase.set_param_values_by_dict of <SoftmaxLayer output/'output_prob' out_type=Data(shape=(None, 10025), batch_dim_axis=1, batch_shape_meta=[T|'time:var:extern_data:classes',B,F|10025])>>
      values_dict = <local> {'W': array([[ 0.07109621,  0.01965252,  0.01508382, ...,  0.01965284,
                                     0.02135382,  0.01964182],
                                   [-0.04140154, -0.12580605, -0.14898466, ..., -0.12578817,
                                    -0.12482522, -0.1258727 ],
                                   [ 0.11936312,  0.05931459,  0.03965379, ...,  0.05929022,
                                     0.06279352,  0.05...
      session = <local> <tensorflow.python.client.session.Session object at 0x7fb57756fa30>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/layers/base.py", line 897, in set_param_values_by_dict
    line: old_axes_splits = tf_util.transform_param_axes_split_info_to_new_shape(
    locals:
      old_axes_splits = <not found>
      tf_util = <global> <module 'returnn.tf.util.basic' from '/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/util/basic.py'>
      tf_util.transform_param_axes_split_info_to_new_shape = <global> <function transform_param_axes_split_info_to_new_shape at 0x7fb541d46ee0>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/util/basic.py", line 184, in transform_param_axes_split_info_to_new_shape
    line: assert new_parts[j] > 0, debug_name
    locals:
      new_parts = <local> [10025, -9525]
      j = <local> 1
      debug_name = <local> "param 'output/rec/output_prob/W:0'", len = 34
AssertionError: param 'output/rec/output_prob/W:0'

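So in the loop

for new_dim, parts in zip(new_shape, axes_split_info):

the first parts = [10025, 500] matches neither condition, and only the second parts = [10025] sets dim_diff = {10025: 10025}. For the first axis this gives new_parts = [10025, None], so we end up in

if any([d is None for d in new_parts]):

where the heuristic fails: the missing part becomes 500 - 10025 = -9525, i.e. new_parts = [10025, -9525] as in the trace above, and the assertion new_parts[j] > 0 fails.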
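To make this failure easier to follow, here is a hypothetical, heavily simplified model of the two loops (toy_transform is a made-up name, not the RETURNN function). It keeps only the two original branches of the first loop and assumes that a single missing part is filled with new_dim minus the known parts, which is consistent with new_parts = [10025, -9525] in the trace above:

```python
def toy_transform(axes_split_info, new_shape):
    """Hypothetical simplified model of transform_param_axes_split_info_to_new_shape."""
    dim_diff = {}
    # First loop: learn a mapping "old part size -> new part size".
    for new_dim, parts in zip(new_shape, axes_split_info):
        if len(parts) == 1:
            dim_diff[parts[0]] = new_dim
        elif len(set(parts)) == 1:  # all parts the same
            if new_dim % len(parts) == 0:
                dim_diff[parts[0]] = new_dim // len(parts)
    # Second loop: apply the mapping per axis (other branches omitted).
    result = []
    for new_dim, parts in zip(new_shape, axes_split_info):
        new_parts = [dim_diff.get(d) for d in parts]
        if any(d is None for d in new_parts):
            missing = [j for j, d in enumerate(new_parts) if d is None]
            assert len(missing) == 1, "toy model handles one unknown part only"
            j = missing[0]
            new_parts[j] = new_dim - sum(d for d in new_parts if d is not None)
            assert new_parts[j] > 0, f"new_parts={new_parts}, j={j}"
        result.append(new_parts)
    return result


try:
    toy_transform([[10025, 500], [10025]], (500, 10025))
except AssertionError as exc:
    print("AssertionError:", exc)  # AssertionError: new_parts=[10025, -9525], j=1
```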

Second Attempt

Then I added two branches for my case at the end of the first loop:

```python
for new_dim, parts in zip(new_shape, axes_split_info):
  if len(parts) == 1:
    dim_diff[parts[0]] = new_dim
  elif len(set(parts)) == 1:  # all the same
    if new_dim % len(parts) == 0:
      dim_diff[parts[0]] = new_dim // len(parts)  # just a heuristic
  elif sum(parts[1:]) == new_dim:  # added one input in front (see heuristic below)
    dim_diff[parts[0]] = 0
    dim_diff.update({dim: dim for dim in parts[1:]})
  elif new_dim in parts:  # all inputs are new except one
    dim_diff.update({dim: 0 for dim in parts})
    dim_diff[new_dim] = new_dim
```
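Running just this extended first loop on the values from above (a standalone copy wrapped in a function named first_loop for illustration) shows the resulting dim_diff, and also shows that the result depends on the order of the axes: listing the same two axes in reverse order leaves 10025 mapped to 0 instead of 10025.

```python
def first_loop(axes_split_info, new_shape):
    """Standalone copy of the extended first loop, returning dim_diff."""
    dim_diff = {}
    for new_dim, parts in zip(new_shape, axes_split_info):
        if len(parts) == 1:
            dim_diff[parts[0]] = new_dim
        elif len(set(parts)) == 1:  # all the same
            if new_dim % len(parts) == 0:
                dim_diff[parts[0]] = new_dim // len(parts)  # just a heuristic
        elif sum(parts[1:]) == new_dim:  # added one input in front
            dim_diff[parts[0]] = 0
            dim_diff.update({dim: dim for dim in parts[1:]})
        elif new_dim in parts:  # all inputs are new except one
            dim_diff.update({dim: 0 for dim in parts})
            dim_diff[new_dim] = new_dim
    return dim_diff


# Axis order as in the layer: the split axis [10025, 500] comes first,
# so the later single-part axis [10025] overwrites the entry for 10025.
print(first_loop([[10025, 500], [10025]], (500, 10025)))  # {10025: 10025, 500: 500}
# Same axes in reverse order: the split axis is processed last and wins.
print(first_loop([[10025], [10025, 500]], (10025, 500)))  # {10025: 0, 500: 500}
```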

Now I get the following variables:

```python
axes_split_info = [[10025, 500], [10025]]
new_shape = (500, 10025)
dim_diff = {10025: 10025, 500: 500}
new_parts = [10025, 500]
```

And here is the new stack trace:
Unhandled exception <class 'AssertionError'> in thread <_MainThread(MainThread, started 139845223368448)>, proc 28243.

Thread current, main, <_MainThread(MainThread, started 139845223368448)>:
(Excluded thread.)

That were all threads.
EXCEPTION
Traceback (most recent call last):
  File "/work/asr4/michel/sandbow/returnn_meyer/rnn.py", line 11, in <module>
    line: main()
    locals:
      main = <local> <function main at 0x7f303253b820>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/__main__.py", line 645, in main
    line: execute_main_task()
    locals:
      execute_main_task = <global> <function execute_main_task at 0x7f303253b700>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/__main__.py", line 451, in execute_main_task
    line: engine.init_train_from_config(config, train_data, dev_data, eval_data)
    locals:
      engine = <global> <returnn.tf.engine.Engine object at 0x7f3032404430>
      engine.init_train_from_config = <global> <bound method Engine.init_train_from_config of <returnn.tf.engine.Engine object at 0x7f3032404430>>
      config = <global> <returnn.config.Config object at 0x7f3040c32a30>
      train_data = <global> <LibriSpeechCorpus 'train' epoch=1>
      dev_data = <global> <LibriSpeechCorpus 'dev' epoch=1>
      eval_data = <global> None
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/engine.py", line 1026, in init_train_from_config
    line: self.init_network_from_config(config)
    locals:
      self = <local> <returnn.tf.engine.Engine object at 0x7f3032404430>
      self.init_network_from_config = <local> <bound method Engine.init_network_from_config of <returnn.tf.engine.Engine object at 0x7f3032404430>>
      config = <local> <returnn.config.Config object at 0x7f3040c32a30>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/engine.py", line 1127, in init_network_from_config
    line: loader.load_now(session=self.tf_session)
    locals:
      loader = <local> CustomCheckpointLoader(filename='/work/asr3/michel/meyer/work/crnn/training/CRNNTrainingJob.737fYxfbkoCz/output/models/epoch.250', params_prefix='', load_if_prefix='', ignore_missing=True, network=<TFNetwork 'root' train=<tf.Tensor 'globals/train_flag:0' shape=() dtype=bool>>)
      loader.load_now = <local> <bound method CustomCheckpointLoader.load_now of CustomCheckpointLoader(filename='/work/asr3/michel/meyer/work/crnn/training/CRNNTrainingJob.737fYxfbkoCz/output/models/epoch.250', params_prefix='', load_if_prefix='', ignore_missing=True, network=<TFNetwork 'root' train=<tf.Tensor 'globals/train_f...
      session = <not found>
      self = <local> <returnn.tf.engine.Engine object at 0x7f3032404430>
      self.tf_session = <local> <tensorflow.python.client.session.Session object at 0x7f2ffcd90fd0>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 3151, in load_now
    line: value.assign_var(var=var, session=session)
    locals:
      value = <local> <returnn.tf.network.CustomCheckpointLoader.VariableValue object at 0x7f2f40690670>
      value.assign_var = <local> <bound method CustomCheckpointLoader.VariableValue.assign_var of <returnn.tf.network.CustomCheckpointLoader.VariableValue object at 0x7f2f40690670>>
      var = <local> <tf.Variable 'output/rec/output_prob/W:0' shape=(10525, 10025) dtype=float32_ref>
      session = <local> <tensorflow.python.client.session.Session object at 0x7f2ffcd90fd0>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 2832, in assign_var
    line: self.custom_param_importer.assign_var(var=var, session=session)
    locals:
      self = <local> <returnn.tf.network.CustomCheckpointLoader.VariableValue object at 0x7f2f40690670>
      self.custom_param_importer = <local> <CustomParamImporter 'subset' on layer 'output_prob'>
      self.custom_param_importer.assign_var = <local> <bound method CustomCheckpointLoader.CustomParamImporter.assign_var of <CustomParamImporter 'subset' on layer 'output_prob'>>
      var = <local> <tf.Variable 'output/rec/output_prob/W:0' shape=(10525, 10025) dtype=float32_ref>
      session = <local> <tensorflow.python.client.session.Session object at 0x7f2ffcd90fd0>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 2776, in assign_var
    line: self.layer.set_param_values_by_dict(values_dict=values_dict, session=session)
    locals:
      self = <local> <CustomParamImporter 'subset' on layer 'output_prob'>
      self.layer = <local> <SoftmaxLayer output/'output_prob' out_type=Data(shape=(None, 10025), batch_dim_axis=1, batch_shape_meta=[T|'time:var:extern_data:classes',B,F|10025])>
      self.layer.set_param_values_by_dict = <local> <bound method LayerBase.set_param_values_by_dict of <SoftmaxLayer output/'output_prob' out_type=Data(shape=(None, 10025), batch_dim_axis=1, batch_shape_meta=[T|'time:var:extern_data:classes',B,F|10025])>>
      values_dict = <local> {'b': array([ 1.8866221 , -0.3177627 ,  3.6245446 , ..., -0.3159841 ,
                                   -0.31281865, -0.31379807], dtype=float32), 'W': array([[ 0.07109621,  0.01965252,  0.01508382, ...,  0.01965284,
                                     0.02135382,  0.01964182],
                                   [-0.04140154, -0.12580605, -0.14898466, ..., -0.12578817,
                                   ...
      session = <local> <tensorflow.python.client.session.Session object at 0x7f2ffcd90fd0>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/layers/base.py", line 897, in set_param_values_by_dict
    line: old_axes_splits = tf_util.transform_param_axes_split_info_to_new_shape(
    locals:
      old_axes_splits = <not found>
      tf_util = <global> <module 'returnn.tf.util.basic' from '/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/util/basic.py'>
      tf_util.transform_param_axes_split_info_to_new_shape = <global> <function transform_param_axes_split_info_to_new_shape at 0x7f2ffcdbdaf0>
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/util/basic.py", line 189, in transform_param_axes_split_info_to_new_shape
    line: assert new_parts[0] > 0, debug_name
    locals:
      new_parts = <local> [0, 500]
      debug_name = <local> "param 'output/rec/output_prob/W:0'", len = 34
AssertionError: param 'output/rec/output_prob/W:0'

Now the first parts = [10025, 500] hits my newly added third condition and sets dim_diff = {10025: 0, 500: 500}; then the second parts = [10025] overwrites it to dim_diff = {10025: 10025, 500: 500}.
new_parts for the first axis should have been [0, 500], but due to the overwriting it is now [10025, 500], and we trigger

elif sum(new_parts) != new_dim:

which fortunately saves us and sets new_parts = [0, 500].

Now only the assertion new_parts[0] > 0 fails, as expected.
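A hypothetical sketch of that second-loop step for the first axis. The exact fallback in RETURNN may differ; here it is assumed to set the first part to new_dim minus the remaining parts, which reproduces new_parts = [0, 500] from the trace. It shows that the sums do not match, that the fallback yields a zero-sized first part, and that only a >= 0 check accepts it:

```python
def second_loop_axis(parts, new_dim, dim_diff, allow_zero):
    """Hypothetical model of the second loop for one axis (not RETURNN's code)."""
    new_parts = [dim_diff[d] for d in parts]
    if sum(new_parts) != new_dim:
        # Assumed fallback: one input was added in front, so the first part
        # gets whatever is left after the remaining parts.
        new_parts[0] = new_dim - sum(new_parts[1:])
        if allow_zero:
            assert new_parts[0] >= 0, new_parts
        else:
            assert new_parts[0] > 0, new_parts
    return new_parts


dim_diff = {10025: 10025, 500: 500}
print(second_loop_axis([10025, 500], 500, dim_diff, allow_zero=True))  # [0, 500]
try:
    second_loop_axis([10025, 500], 500, dim_diff, allow_zero=False)
except AssertionError as exc:
    print("AssertionError:", exc)  # AssertionError: [0, 500]
```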

Questions

  • Is it safe to allow the case new_parts[0] == 0? For me, yes: the logic in copy_with_new_split_axes works as I expect, but maybe transform_param_axes_split_info_to_new_shape is used elsewhere.
  • Can we rewrite the whole mechanism to not use a dict dim_diff but to infer the new shapes in order as we progress through the lists? This would alleviate the issues with square layers.
@albertz
Member

albertz commented Jan 11, 2021

The names transform_param_axes_split_info_to_new_shape and new_shape are maybe confusing: new_shape is actually the old shape here, because the function is used the other way around. So in your case, it means you start with 2 input layers and then remove one.

The way the output of transform_param_axes_split_info_to_new_shape is used in copy_with_new_split_axes actually requires that we put the 0 in there.

So what we want:

  assert_equal(transform_param_axes_split_info_to_new_shape([[10025, 500], [10025]], (500, 10025)), [[0, 500], [10025]])

You should add that to test_transform_param_axes_split_info_to_new_shape.
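Why the 0 matters, as a simplified and hypothetical model of what copy_with_new_split_axes does (this is not its actual implementation): each axis of the old and new parameter is cut into the same number of parts, and each old part is copied into the start of the matching new part. Old and new split info must therefore have the same number of parts per axis, and a zero-sized old part just means "nothing to copy, keep the new initialization" for the added lm_output_prob rows:

```python
import itertools

import numpy as np


def _offsets(parts):
    """Start offset of each part along one axis."""
    return [sum(parts[:i]) for i in range(len(parts))]


def copy_with_new_split_axes_sketch(old, old_splits, new_splits, new_init):
    """Hypothetical model: copy every old part into the matching new part."""
    assert len(old_splits) == len(new_splits) == old.ndim
    assert all(len(o) == len(n) for o, n in zip(old_splits, new_splits))
    assert old.shape == tuple(sum(p) for p in old_splits)
    assert new_init.shape == tuple(sum(p) for p in new_splits)
    new = new_init.copy()
    old_offs = [_offsets(p) for p in old_splits]
    new_offs = [_offsets(p) for p in new_splits]
    # Visit every combination of part indices, one index per axis.
    for idx in itertools.product(*[range(len(p)) for p in old_splits]):
        old_sl, new_sl = [], []
        for axis, i in enumerate(idx):
            n = min(old_splits[axis][i], new_splits[axis][i])
            old_sl.append(slice(old_offs[axis][i], old_offs[axis][i] + n))
            new_sl.append(slice(new_offs[axis][i], new_offs[axis][i] + n))
        new[tuple(new_sl)] = old[tuple(old_sl)]
    return new


# Toy sizes: V=4 stands in for 10025, R=3 for 500.
V, R = 4, 3
W_old = np.arange(R * V, dtype=float).reshape(R, V)  # checkpoint W, shape (R, V)
W_new = copy_with_new_split_axes_sketch(
    W_old,
    old_splits=[[0, R], [V]],  # analogous to [[0, 500], [10025]] from the test above
    new_splits=[[V, R], [V]],  # analogous to [[10025, 500], [10025]]
    new_init=np.zeros((V + R, V)),
)
print(W_new[:V].any(), np.array_equal(W_new[V:], W_old))  # False True
```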

maybe transform_param_axes_split_info_to_new_shape is used elsewhere

Just check that. I think not.

Is it safe to allow the case new_parts[0] == 0 ...

Why not? Also, allowing something that was not allowed before cannot possibly break anything.

Now the first parts ... hits my newly defined third condition ..., then the second parts ... overwrites it ....

That's way too ugly. You should not depend on a specific order of the axes for your heuristic to work. (Yes, you can already create strange edge cases where it will break depending on the order, but ignore those. Do not make it explicitly depend on such behavior.)

I would add something like this to the first loop:

elif new_dim in parts:
  dim_diff[new_dim] = new_dim

And in the second loop, in the sum ... != new_dim if-branch, I would add another check:

if new_dim in new_parts:
  new_parts = ...
else:
  ... (as before)

Also change assert new_parts[j] > 0 to assert new_parts[j] >= 0.
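Putting these suggestions together, here is a toy sketch that returns the result requested in the test above and gives the same answer when the axes are listed in the opposite order. transform_sketch is a hypothetical name, and the concrete replacement for the elided new_parts = ... is an assumption of this sketch (keep the part equal to new_dim and set all others to 0); the "as before" fallback is not modeled.

```python
def transform_sketch(axes_split_info, new_shape):
    """Toy model of the suggested heuristic; not RETURNN's implementation."""
    dim_diff = {}
    for new_dim, parts in zip(new_shape, axes_split_info):
        if len(parts) == 1:
            dim_diff[parts[0]] = new_dim
        elif len(set(parts)) == 1:  # all the same
            if new_dim % len(parts) == 0:
                dim_diff[parts[0]] = new_dim // len(parts)
        elif new_dim in parts:  # suggested: this part keeps its size
            dim_diff[new_dim] = new_dim
    result = []
    for new_dim, parts in zip(new_shape, axes_split_info):
        new_parts = [dim_diff.get(d) for d in parts]
        if any(d is None for d in new_parts):
            missing = [j for j, d in enumerate(new_parts) if d is None]
            assert len(missing) == 1, "toy model handles one unknown part only"
            j = missing[0]
            new_parts[j] = new_dim - sum(d for d in new_parts if d is not None)
            assert new_parts[j] >= 0, new_parts  # relaxed from > 0
        elif sum(new_parts) != new_dim:
            if new_dim in new_parts:
                # Assumed choice: only the part matching new_dim survives.
                new_parts = [d if d == new_dim else 0 for d in new_parts]
            else:
                raise NotImplementedError("other fallbacks omitted in this sketch")
        result.append(new_parts)
    return result


print(transform_sketch([[10025, 500], [10025]], (500, 10025)))  # [[0, 500], [10025]]
print(transform_sketch([[10025], [10025, 500]], (10025, 500)))  # [[10025], [0, 500]]
```

Because the new first-loop branch only ever writes dim_diff[new_dim] = new_dim, no later axis can overwrite it with a conflicting value for this example.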

Can we rewrite the whole mechanism to not use a dict dim_diff ...

I'm quite sure this would break other cases.

Maybe instead of dict[int,int], it could be dict[int,set[int]].
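A rough sketch of that idea (all names are hypothetical): the first loop records every candidate new size for each old part size in a set instead of overwriting a single value, and the second loop then picks, per axis, the first combination of candidates whose sum matches new_dim. The candidate rules are the ones discussed in this thread; since nothing is overwritten, the order in which the axes are visited no longer matters.

```python
import itertools
from collections import defaultdict


def collect_candidates(axes_split_info, new_shape):
    """First loop: old part size -> set of candidate new sizes."""
    cands = defaultdict(set)
    for new_dim, parts in zip(new_shape, axes_split_info):
        if len(parts) == 1:
            cands[parts[0]].add(new_dim)
        elif len(set(parts)) == 1:  # all the same
            if new_dim % len(parts) == 0:
                cands[parts[0]].add(new_dim // len(parts))
        elif new_dim in parts:  # one input kept, the others removed
            cands[new_dim].add(new_dim)
            for d in parts:
                if d != new_dim:
                    cands[d].add(0)
    return cands


def resolve(axes_split_info, new_shape):
    """Second loop: per axis, take the first candidate combination with the right sum."""
    cands = collect_candidates(axes_split_info, new_shape)
    result = []
    for new_dim, parts in zip(new_shape, axes_split_info):
        options = [sorted(cands[d]) or [d] for d in parts]  # no candidate: keep old size
        match = next((list(c) for c in itertools.product(*options) if sum(c) == new_dim), None)
        assert match is not None and all(d >= 0 for d in match), (parts, new_dim)
        result.append(match)
    return result


print(resolve([[10025, 500], [10025]], (500, 10025)))  # [[0, 500], [10025]]
print(resolve([[10025], [10025, 500]], (10025, 500)))  # [[10025], [0, 500]]
```

The product over candidates grows quickly with the number of parts, and "first match" is still a heuristic when several combinations have the right sum.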

However, I would also try not to make this too complicated.

Note that there will always be cases which this does not fully cover, no matter what you do. I think it's fine now to add one or two more heuristics to cover your case, but we should not make it too complicated.

Maybe there is a better way to infer old_axes_splits or to copy over the parameters.

@albertz
Member

albertz commented Jan 20, 2021

Is this resolved?

Note that we also don't really need to invest so much energy into making the heuristic work correctly in all cases (which is not possible anyway). Instead, we could invest some energy into doing it the right way, so that it is always correct, in a clean and predictable way, without relying on such a heuristic. This should certainly be possible, right?

E.g., the problem this heuristic tries to solve is recovering the old shape split information, which is no longer available at this point. Maybe we could just store it in the checkpoint or somewhere else?
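As a sketch of the "store it somewhere" idea, assuming a JSON file next to the checkpoint (the .split_info.json suffix and both function names are hypothetical, not an existing RETURNN feature): when saving, write the axes split info of each parameter; when loading into a changed layer, read the stored old split info back instead of guessing it.

```python
import json
import os
import tempfile


def save_split_info(checkpoint_path, split_info_by_param):
    """Write {param_name: axes_split_info} next to the checkpoint."""
    with open(checkpoint_path + ".split_info.json", "w") as f:
        json.dump(split_info_by_param, f, indent=2, sort_keys=True)


def load_split_info(checkpoint_path):
    """Return the stored split info, or None for checkpoints saved without it."""
    path = checkpoint_path + ".split_info.json"
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "epoch.250")
    # What the original training run would have stored for the softmax weight.
    save_split_info(ckpt, {"output/rec/output_prob/W": [[500], [10025]]})
    old = load_split_info(ckpt)["output/rec/output_prob/W"]
    print(old)  # [[500], [10025]]
```

Sizes alone still leave it to the loader to decide which old input corresponds to which new one; storing the input layer names alongside their sizes would make that matching explicit.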
