You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Currently, calling `B.jit` on a `logpdf`-based objective works for a model with a single context set and a single target set. However, a `ValueError` is raised when the model takes multiple context sets as input. See the MWE on Google Colab and below:
import neuralprocesses.tensorflow as nps
import tensorflow as tf
import lab.tensorflow as B
import time
def test_jit(n_context_sets=1, its=10):
    """Time a ConvGNP ``logpdf`` objective with and without ``B.jit``.

    Builds a ConvGNP with ``n_context_sets`` context sets (each with
    ``dim_yc=1``) and one target set (``dim_yt=1``). It then prints the mean
    time per evaluation of the negative-logpdf objective, first run eagerly
    and then after compiling it with ``B.jit``. Per the issue text, the
    compiled call raises ``ValueError`` when ``n_context_sets=2``.

    Args:
        n_context_sets: Number of context sets passed to the model.
        its: Number of timed evaluations per mode (should be positive).
    """
    model = nps.construct_convgnp(
        dim_x=1, dim_yc=(1,) * n_context_sets, dim_yt=1
    )

    def objective(xt, yt, *context_data):
        """Return ``-model(contexts, xt).logpdf(yt)``.

        Context data is passed flat as ``xc1, yc1, xc2, yc2, ...`` so that
        every positional argument handed to the (compiled) function is a
        tensor.
        """
        # Regroup the flat arguments into the list of (x, y) tuples the
        # model is called with.
        contexts = [
            (context_data[2 * i], context_data[2 * i + 1])
            for i in range(n_context_sets)
        ]
        return -model(contexts, xt).logpdf(yt)

    def test(fn):
        """Call ``fn`` on freshly generated random inputs."""
        # Draw all xc's, then all yc's, then xt, yt (same draw order as the
        # original MWE).
        xcs = [B.randn(tf.float32, 16, 1, 10) for _ in range(n_context_sets)]
        ycs = [B.randn(tf.float32, 16, 1, 10) for _ in range(n_context_sets)]
        # Interleave as xc1, yc1, xc2, yc2, ... to match ``objective``.
        context_data = [t for pair in zip(xcs, ycs) for t in pair]
        xt = B.randn(tf.float32, 16, 1, 20)
        yt = B.randn(tf.float32, 16, 1, 20)
        return fn(xt, yt, *context_data)

    def benchmark(fn):
        """Return the mean seconds per ``test(fn)`` call over ``its`` runs."""
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start = time.perf_counter()
        for _ in range(its):
            test(fn)
        return (time.perf_counter() - start) / its

    print(f"Without JIT ({n_context_sets} context sets):", benchmark(objective))

    objective_compiled = B.jit(objective)
    test(objective_compiled)  # Run once to compile, so compilation isn't timed.
    print(
        f"With JIT ({n_context_sets} context sets):",
        benchmark(objective_compiled),
    )
# Reproduce the issue: one context set works under JIT, two do not.
for n_sets in (1, 2):
    test_jit(n_context_sets=n_sets)
Running the above produces:
Without JIT (1 context sets): 0.27810795307159425
With JIT (1 context sets): 0.027799010276794434
Without JIT (2 context sets): 0.2577114820480347
However, when the model with two context sets is run with JIT, it raises a `ValueError`:
Currently, calling `B.jit` on a `logpdf`-based objective works for a model with a single context set and a single target set. However, a `ValueError` is raised when the model takes multiple context sets as input. See the MWE on Google Colab and below:

Running the above produces:
However, when the model with two context sets is run with JIT, it raises a `ValueError`:

The text was updated successfully, but these errors were encountered: