About training of GoGame using GymGo #1

Open
intenseG opened this issue Jan 27, 2020 · 0 comments

I am trying to train a Go agent using this wonderful project and GymGo.
However, training.py raises an error at line 103:
https://github.com/intenseG/MuZero/blob/5cab8b6c89b652ad33e277318bfd4389de1315a9/muzero/training/training.py#L103

Is there a solution to this error?
Thank you.

[Additional debugging information]

value_batch.shape : (2048, 6, 9, 24)
targets.shape : (2048, 24)

[Stack trace]

Traceback (most recent call last):
  File "muzero.py", line 36, in <module>
    muzero(config)
  File "muzero.py", line 24, in muzero
    train_network(config, storage, replay_buffer, config.nb_epochs)
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 19, in train_network
    update_weights(optimizer, network, batch)
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 84, in update_weights
    optimizer.minimize(loss=loss, var_list=network.cb_get_variables())
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\keras\optimizer_v2\optimizer_v2.py", line 317, in minimize
    loss, var_list=var_list, grad_loss=grad_loss)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\keras\optimizer_v2\optimizer_v2.py", line 351, in _compute_gradients
    loss_value = loss()
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 42, in loss
    loss += tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size))
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 103, in loss_value
    return tf.nn.softmax_cross_entropy_with_logits(logits=value_batch, labels=targets)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 3105, in softmax_cross_entropy_with_logits_v2
    labels=labels, logits=logits, axis=axis, name=name)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 3206, in softmax_cross_entropy_with_logits_v2_helper
    precise_logits, labels, name=name)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\gen_nn_ops.py", line 11458, in softmax_cross_entropy_with_logits
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be broadcastable: logits_size=[110592,24] labels_size=[2048,24] [Op:SoftmaxCrossEntropyWithLogits] name: softmax_cross_entropy_with_logits/
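The logits_size=[110592,24] in the error appears to follow directly from the shapes printed above. My understanding (an assumption worth checking against the TensorFlow source for the installed version) is that when logits have rank greater than 2, tf.nn.softmax_cross_entropy_with_logits collapses every dimension except the class axis into a single batch dimension before calling the 2-D kernel. Under that assumption, (2048, 6, 9, 24) becomes (2048 * 6 * 9, 24) = (110592, 24), which no longer lines up with the (2048, 24) targets. The pure-Python sketch below only reproduces that shape arithmetic; the helper name flattened_2d_shape is hypothetical and not part of the project or of TensorFlow.

```python
from math import prod


def flattened_2d_shape(shape):
    """Hypothetical helper: collapse every dimension except the last (class) axis
    into one batch dimension, mirroring the 2-D reshape assumed to be applied
    to rank > 2 inputs of the softmax cross-entropy op."""
    return (prod(shape[:-1]), shape[-1])


value_batch_shape = (2048, 6, 9, 24)  # value_batch.shape from the debug output (logits)
targets_shape = (2048, 24)            # targets.shape from the debug output (labels)

logits_2d = flattened_2d_shape(value_batch_shape)
labels_2d = flattened_2d_shape(targets_shape)

print(logits_2d)               # (110592, 24), matching logits_size in the error
print(labels_2d)               # (2048, 24), matching labels_size in the error
print(logits_2d == labels_2d)  # False: the leading (batch) dimensions disagree
```

If this reading is correct, the value head would need to reduce the extra (6, 9) dimensions (for example by flattening and then applying a dense layer of size value_support_size) so that value_batch comes out as (2048, 24). This is a suggested direction, not a fix confirmed in the repository.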