Skip to content

Commit

Permalink
Test ReuseParams with different variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
Zettelkasten committed Feb 19, 2021
1 parent d664175 commit bf0adbd
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/test_TFNetworkLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3059,6 +3059,28 @@ def make_feed_dict(seq_len=10):
session.run(network.get_default_output_layer().output.placeholder, feed_dict=feed)


def test_ReuseParams_different_names():
n_batch, n_time, n_total, n_heads = 7, 3, 40, 2
assert n_total % n_heads == 0
config = Config({
"extern_data": {"data": {"dim": n_total}},
"debug_print_layer_output_template": True,
})
with make_scope():
net = TFNetwork(config=config)

def custom(reuse_layer, *args, **kwargs):
return reuse_layer.params['QKV']

net.construct_from_dict({
"self_att": {"class": "self_attention", "num_heads": n_heads, "total_key_dim": n_total, "n_out": n_total},
"linear": {"class": "linear", "n_out": n_total * 3, "activation": None, "with_bias": False,
"reuse_params": {
"auto_create_missing": False, # should not matter as we do not have any bias
"map": {"W": {"reuse_layer": "self_att", "custom": custom}}}},
"output": {"class": "copy", "from": "linear"}})


def test_LossAsIs_custom_dim():
config = Config()
config.update({
Expand Down

0 comments on commit bf0adbd

Please sign in to comment.