Thanks for sharing this code. When I run it, I get an error indicating that the inputs to numpy's multinomial function are not valid:
File "mtrand.pyx", line 4639, in mtrand.RandomState.multinomial
ValueError: object too deep for desired array
It seems that the distribution variable in softmax_sample has a shape of [7,2], where its [:,0] values are the visit counts and its [:,1] values are the action indexes.
I tried a simple fix of adjusting distribution in the softmax_sample function:
def softmax_sample(distribution, temperature: float):
    if temperature == 0:
        temperature = 1
    distribution = numpy.array(distribution)**(1/temperature)
    distribution = distribution[:,0]
    p_sum = distribution.sum()
    sample_temp = distribution/p_sum
    return 0, numpy.argmax(numpy.random.multinomial(1, sample_temp, 1))
Is that a correct way to solve the issue? Does the problem stem from a different numpy version (I am using 1.17.1)?
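For comparison, here is a minimal sketch of another way the sampling could be written, under a few assumptions that are mine and not confirmed by the repository: that distribution is a sequence of (visit_count, action_index) pairs with integer actions, that at least one visit count is positive, and that a temperature of 0 is meant to pick the most visited action greedily. The function name sample_action_from_visit_counts is hypothetical. It passes only the 1-D column of visit counts to numpy.random.multinomial and maps the sampled position back to the action stored in the second column.

```python
import numpy


def sample_action_from_visit_counts(distribution, temperature: float):
    """Hypothetical helper: sample an action from (visit_count, action) pairs."""
    pairs = numpy.array(distribution)
    counts = pairs[:, 0].astype(float)  # column 0: visit counts
    actions = pairs[:, 1]               # column 1: action indexes

    if temperature == 0:
        # Assumed convention: zero temperature means greedy selection.
        position = int(numpy.argmax(counts))
    else:
        # Sharpen or flatten the counts, then normalize to probabilities.
        weights = counts ** (1 / temperature)
        probs = weights / weights.sum()
        # multinomial expects a 1-D probability vector; it returns a
        # one-hot count vector whose argmax is the sampled position.
        position = int(numpy.argmax(numpy.random.multinomial(1, probs)))

    return int(actions[position])
```

Unlike my quick fix above, this returns the action index itself instead of the row position, which matters if the actions are not simply numbered 0 to 6 in order. Whether the caller expects a tuple like the original return value is something I could not verify.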