Traceback (most recent call last):
  File "muzero.py", line 36, in <module>
    muzero(config)
  File "muzero.py", line 24, in muzero
    train_network(config, storage, replay_buffer, config.nb_epochs)
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 19, in train_network
    update_weights(optimizer, network, batch)
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 84, in update_weights
    optimizer.minimize(loss=loss, var_list=network.cb_get_variables())
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\keras\optimizer_v2\optimizer_v2.py", line 317, in minimize
    loss, var_list=var_list, grad_loss=grad_loss)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\keras\optimizer_v2\optimizer_v2.py", line 351, in _compute_gradients
    loss_value = loss()
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 42, in loss
    loss += tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size))
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 103, in loss_value
    return tf.nn.softmax_cross_entropy_with_logits(logits=value_batch, labels=targets)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 3105, in softmax_cross_entropy_with_logits_v2
    labels=labels, logits=logits, axis=axis, name=name)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 3206, in softmax_cross_entropy_with_logits_v2_helper
    precise_logits, labels, name=name)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\gen_nn_ops.py", line 11458, in softmax_cross_entropy_with_logits
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be broadcastable: logits_size=[110592,24] labels_size=[2048,24] [Op:SoftmaxCrossEntropyWithLogits] name: softmax_cross_entropy_with_logits/
I am trying to train a Go agent using this wonderful project and GymGo.
However, training fails with an error raised at line 103 of training.py:
https://github.com/intenseG/MuZero/blob/5cab8b6c89b652ad33e277318bfd4389de1315a9/muzero/training/training.py#L103
Is there a solution to this error?
Thank you.
[Additional debugging information]
value_batch.shape: (2048, 6, 9, 24)
targets.shape: (2048, 24)
[Stack trace] (shown above)
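
The two shapes above seem to account for the sizes in the error message. A minimal sketch of the arithmetic follows, assuming the cross-entropy op collapses every axis except the last (class) axis into a single leading axis before comparing logits and labels. The helper names `flatten_outer_dims` and `broadcastable` are hypothetical and written only for this illustration; they are not TensorFlow APIs.

```python
from math import prod


def flatten_outer_dims(shape):
    """Collapse all axes except the last (class) axis into one leading axis."""
    *outer, classes = shape
    return (prod(outer), classes)


def broadcastable(a, b):
    """Two 2-D shapes are compatible if each pair of dims is equal or one is 1."""
    return len(a) == len(b) and all(x == y or x == 1 or y == 1 for x, y in zip(a, b))


# Shapes reported in the debugging information.
value_batch_shape = (2048, 6, 9, 24)
targets_shape = (2048, 24)

logits_2d = flatten_outer_dims(value_batch_shape)  # 2048 * 6 * 9 rows
labels_2d = flatten_outer_dims(targets_shape)

print(logits_2d, labels_2d, broadcastable(logits_2d, labels_2d))
# prints: (110592, 24) (2048, 24) False
```

Since 2048 * 6 * 9 = 110592, the extra (6, 9) axes of value_batch look like the source of the mismatch: the value output would need shape (2048, 24) to line up with targets. How to get there (for example, flattening those axes inside the value head before its final layer) depends on the network definition, which is not shown here, so treat that as an assumption rather than a confirmed fix.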