forked from FerranAlet/graph_element_networks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path neural_processes.py
27 lines (25 loc) · 894 Bytes
/
neural_processes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import torch
from torch import nn
class NeuralProcesses(nn.Module):
    """Neural process over several functions sharing one context summary.

    Each context set ``(X, y_i)`` is encoded by its own encoder. All per-point
    encodings of all sets are averaged into a single summary vector per batch
    element. That summary is concatenated to every query point of each query
    set and fed to that set's decoder.
    """

    def __init__(self, encoders, decoders):
        """
        Args:
            encoders: sequence of modules, one per context set in ``Inp``.
            decoders: sequence of modules, one per query set in ``Q``.
        """
        super().__init__()
        # A plain list/tuple of modules is invisible to nn.Module: the
        # parameters would be missing from .parameters() (so an optimizer
        # never updates them), from state_dict, and from .to(device).
        self.encoders = self._as_module_list(encoders)
        self.decoders = self._as_module_list(decoders)

    @staticmethod
    def _as_module_list(modules):
        """Wrap a list/tuple of nn.Modules in nn.ModuleList so they register.

        Anything else (an existing nn.Module container such as nn.ModuleList,
        or a sequence holding non-Module callables that nn.ModuleList cannot
        store) is returned unchanged.
        """
        if isinstance(modules, (list, tuple)) and all(
            isinstance(m, nn.Module) for m in modules
        ):
            return nn.ModuleList(modules)
        return modules

    def forward(self, Inp, Q):
        """Decode every query set from a shared summary of all context sets.

        Args:
            Inp: list of context sets ``(X, y_i)`` for function i. ``X`` and
                ``y_i`` are concatenated along the last axis before encoding.
                Each encoder output is expected to be (BS, #inp, feat).
            Q: list of query tensors ``X`` for function j. Each is 3D, with
                the number of query points on axis 1, e.g.
                (BS, #queries, dim_q).

        Returns:
            list with one decoder output per (query set, decoder) pair.
        """
        # (BS, #inp_i, feat) per context set.
        encodings = [
            enc(torch.cat((inp[0], inp[1]), dim=-1))
            for inp, enc in zip(Inp, self.encoders)
        ]
        # Pool over every point of every set -> [BS, 1, feat].
        summary = torch.cat(encodings, dim=1).mean(dim=1, keepdim=True)

        outputs = []
        for q, dec in zip(Q, self.decoders):
            # expand() broadcasts the summary to every query point as a view;
            # repeat() would allocate a full copy that cat then copies again.
            tiled = summary.expand(-1, q.shape[1], -1)
            outputs.append(dec(torch.cat((tiled, q), dim=-1)))
        return outputs