Skip to content

Commit

Permalink
Fix batching network inputs of mixed type
Browse files Browse the repository at this point in the history
  • Loading branch information
Robin Manhaeve committed Jul 5, 2024
1 parent 04fe7a1 commit 2c4df75
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 28 deletions.
33 changes: 18 additions & 15 deletions src/deepproblog/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class Network(object):
"""Wraps a PyTorch neural network for use with DeepProblog"""

def __init__(
self,
network_module: torch.nn.Module,
name: str,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler=None,
k: Optional[int] = None,
batching: bool = False,
self,
network_module: torch.nn.Module,
name: str,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler=None,
k: Optional[int] = None,
batching: bool = False,
):
"""Create a Network object
Expand Down Expand Up @@ -121,13 +121,17 @@ def __call__(self, to_evaluate: list) -> list:
:return:
"""
if self.batching:
batched_inputs: List[torch.Tensor] = [
self.function(*e)[0] for e in to_evaluate
]
stacked_inputs = torch.stack(batched_inputs)
if self.is_cuda:
stacked_inputs = stacked_inputs.cuda(device=self.device)
evaluated = self.network_module(stacked_inputs)
inputs = (self.function(*e) for e in to_evaluate)
stacked_inputs = list()
for inputs in zip(*inputs):
try:
inputs = torch.stack(inputs)
if self.is_cuda:
inputs.cuda(device=self.device)
except TypeError:
inputs = list(inputs)
stacked_inputs.append(inputs)
evaluated = self.network_module(*stacked_inputs)
else:
evaluated = [self.network_module(*self.function(*e)) for e in to_evaluate]
return evaluated
Expand Down Expand Up @@ -169,7 +173,6 @@ def get_hyperparameters(self):
}
return parameters


# class NetworkEvaluation(object):
# """
# An object that keeps track of which inputs the neural networks need to be evaluated on.
Expand Down
35 changes: 22 additions & 13 deletions src/deepproblog/tests/test_neural_predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
nn(dummy1,[X],Y,[a,b,c]) :: net1(X,Y).
nn(dummy2,[X]) :: net2(X).
nn(dummy3,[X],Y) :: net3(X,Y).
nn(dummy4,[X,Y],Z,[a,b]) :: net4(X,Y,Z).
nn(dummy4,[X,Y]) :: net4(X,Y).
test1(X1,Y1,X2,Y2) :- net1(X1,Y1), net1(X2,Y2).
test2(X1,X2) :- net2(X1), net2(X2).
Expand All @@ -28,9 +28,19 @@
dummy_values3 = {Term("i1"): [1.0, 2.0, 3.0, 4.0], Term("i2"): [-1.0, 0.0, 1.0]}
dummy_net3 = Network(DummyNet(dummy_values3), "dummy3")

dummy_net4 = Network(DummyTensorNet(batching=True), "dummy4", batching=True)

tensors = {(Constant(0),): torch.Tensor([0.2]), (Constant(1),): torch.Tensor([0.8])}
dummy_tensors = {(Term("a"),): torch.Tensor([0.1, 0.2, 0.3, 0.4]), (Term("b"),): torch.Tensor([0.25, 0.25, 0.25, 0.25])}


class IndexNet(torch.nn.Module):

def forward(self, t, index):
# index = int(index)
index = torch.LongTensor([int(i) for i in index])
return t.index_select(dim=1, index=index)


dummy_net4 = Network(IndexNet(), "dummy4", batching=True)


@pytest.fixture(
Expand All @@ -53,7 +63,7 @@ def model(request) -> Model:
model = Model(program, [dummy_net1, dummy_net2, dummy_net3, dummy_net4], load=False)
engine = request.param["engine_factory"](model)
model.set_engine(engine, cache=request.param["cache"])
model.add_tensor_source('dummy', tensors)
model.add_tensor_source('dummy', dummy_tensors)
return model


Expand Down Expand Up @@ -108,13 +118,12 @@ def test_det_network_substitution(model: Model):
assert all(r1.detach().numpy() == [1.0, 2.0, 3.0, 4.0])
assert all(r2.detach().numpy() == [-1.0, 0.0, 1.0])

def test_double_input(model: Model):
terms = lambda x: Term("net4",
Term("tensor",Term("dummy", Constant(0))),
Term("tensor",Term("dummy", Constant(1))),
x)
results = model.solve([Query(terms(Var("X")))])
r1 = float(results[0].result[terms(Term("a"))])
r2 = float(results[0].result[terms(Term("b"))])
def test_multi_input_network(model: Model):
dummy_tensor = lambda x: Term("tensor", Term("dummy", x))
q1 = Query(Term("net4", dummy_tensor(Term("a")), Constant(1)))
q2 = Query(Term("net4", dummy_tensor(Term("b")), Constant(2)))
results = model.solve([q1, q2])
r1 = float(results[0].result[q1.query])
r2 = float(results[1].result[q2.query])
assert pytest.approx(0.2) == r1
assert pytest.approx(0.8) == r2
assert pytest.approx(0.25) == r2

0 comments on commit 2c4df75

Please sign in to comment.