Merge pull request #1 from andrewk1/andrew/examples
Andrew/examples
andrewk1 authored Jan 14, 2022
2 parents 7f3b8ca + 18fc501 commit 178b2e8
Showing 5 changed files with 100 additions and 80 deletions.
5 changes: 3 additions & 2 deletions cands/cands.py
@@ -2,12 +2,12 @@
Correct&Smooth Implementation in torch
"""
import torch
import torch.nn.functional as F

from functools import lru_cache
from tqdm import tqdm

from .utils import edge_index_to_sparse, to_dense_adj

from torch_geometric.utils import to_dense_adj

@lru_cache(maxsize=None)
def normalize_adj_matrix(edge_index):
@@ -81,6 +81,7 @@ def correct_and_smooth(y, yhat,
"""
c&s full pipeline
"""
y = F.one_hot(y, max(y) + 1)
train_split_idxs = [ ix for ix in range(len(y)) if ix not in val_split_idxs ]
print_if_verbose("Normalizing Adj Matrix...", verbose)
S = normalize_adj_matrix(edge_index)
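With labels now one-hot encoded inside the pipeline, callers can pass integer class labels directly. A minimal sketch of invoking it, mirroring the usage in examples/lastfm/lastfm.py below (illustrative only, not the full signature):

from cands import correct_and_smooth

# Sketch, assuming: y holds integer class labels for all nodes, yhat the
# base model's soft predictions, edge_index a 2 x num_edges tensor, and
# val_split_idxs the indices of held-out (val + test) nodes.
yhat_cs = correct_and_smooth(y, yhat, edge_index, val_split_idxs)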
77 changes: 0 additions & 77 deletions cands/utils.py

This file was deleted.

73 changes: 73 additions & 0 deletions examples/lastfm/lastfm.py
@@ -0,0 +1,73 @@
"""
Using cands to implement and evaluate C&S on the LastFM Asia dataset
Additional dependencies:
- torch-geometric
"""

import numpy as np
import torch; torch.manual_seed(42)
import torch.nn.functional as F

from cands import correct_and_smooth
from torch_geometric.datasets import Twitch, LastFMAsia
from tqdm import tqdm

from model import BasePredictor

graph = LastFMAsia("./data")  # alternatively: Twitch("./data", "EN")
X = graph.data.x
edge_index = graph.data.edge_index
y = graph.data.y.squeeze()

EDGE_KEEP_FRAC = 0.3  # fraction of edges to keep when subsampling the graph

# sample without replacement so the kept edges are distinct
kept_edges = np.random.choice(edge_index.shape[1], int(edge_index.shape[1] * EDGE_KEEP_FRAC), replace=False)
edge_index = edge_index[:, kept_edges]
kept_nodes = list(set(edge_index.flatten().tolist()))

# We use a 0.8 / 0.1 / 0.1 train/val/test split
SPLIT_FRACTIONS = (0.8, 0.1, 0.1)
splits_sizes = (int(SPLIT_FRACTIONS[0] * len(X[kept_nodes])),
                int(SPLIT_FRACTIONS[1] * len(X[kept_nodes])),
                len(X[kept_nodes]) - int(SPLIT_FRACTIONS[0] * len(X[kept_nodes])) - int(SPLIT_FRACTIONS[1] * len(X[kept_nodes])))
train_split, val_split, test_split = splits = torch.utils.data.random_split(X[kept_nodes], splits_sizes)
(X_train, y_train), (X_val, y_val), (X_test, y_test) = [(X[split.indices], y[split.indices]) for split in splits]

num_labels = int(max(y) + 1)
print(f"Dataset: { X_train.shape[0] } training, { X_val.shape[0] } val, { X_test.shape[0] } test samples with { X.shape[1] } dim embeddings")
print(f"{ edge_index.shape[1] } total followerships (edges)")
print(f"{ num_labels } total classes")

net = BasePredictor(in_size=X.shape[1], n_hidden_layers=1, out_size=num_labels)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())

def train(X, y):
    optimizer.zero_grad()
    yhat = net(X)
    l = loss(yhat, y)
    l.backward()
    optimizer.step()
    return l

NUM_EPOCHS = 500

pbar = tqdm(range(NUM_EPOCHS))
for ep in pbar:
    l = train(X_train, y_train)
    pred = torch.argmax(net(X_val), -1)
    pbar.set_postfix({'loss': float(l), "val_acc": float(torch.sum(pred == y_val) / len(pred))})

yhat = torch.softmax(net(X), -1)
val_split_idxs = val_split.indices + test_split.indices
yhat_cands = correct_and_smooth(y, yhat, edge_index, val_split_idxs)

yhat_mlp = torch.argmax(net(X), -1)
print(f"Val accuracy MLP: { torch.mean((yhat_mlp[val_split.indices] == y[val_split.indices]).type(torch.float32)) }")
print(f"Test accuracy MLP: { torch.mean((yhat_mlp[test_split.indices] == y[test_split.indices]).type(torch.float32)) }\n")

yhat_cands = torch.argmax(yhat_cands, -1)
print(f"Val accuracy CandS: { torch.mean((yhat_cands[val_split.indices] == y[val_split.indices]).type(torch.float32)) }")
print(f"Test accuracy CandS: { torch.mean((yhat_cands[test_split.indices] == y[test_split.indices]).type(torch.float32)) }\n")

22 changes: 22 additions & 0 deletions examples/lastfm/model.py
@@ -0,0 +1,22 @@
"""
C&S first needs a base classifier that can output a probability distribution over the classes; we train a shallow MLP in PyTorch.
"""
import torch

class BasePredictor(torch.nn.Module):
    """
    A simple MLP class to serve as the base predictor
    """
    def __init__(self, n_hidden_layers=1, in_size=128, hidden_size=64, out_size=1):
        super(BasePredictor, self).__init__()
        if n_hidden_layers == 0:
            self.net = torch.nn.Linear(in_size, out_size)
        else:
            net = [torch.nn.Linear(in_size, hidden_size), torch.nn.ReLU()]
            net += [torch.nn.Linear(hidden_size, hidden_size), torch.nn.ReLU()] * (n_hidden_layers - 1)
            net += [torch.nn.Linear(hidden_size, out_size)]
            self.net = torch.nn.Sequential(*net)

    def forward(self, X):
        out = self.net(X)
        return out.squeeze()
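As a quick sanity check, a usage sketch of BasePredictor; the batch size, embedding width, and class count here are illustrative placeholders, not values taken from the example:

import torch
from model import BasePredictor

net = BasePredictor(in_size=128, n_hidden_layers=1, out_size=18)
X = torch.randn(32, 128)           # 32 hypothetical node embeddings
logits = net(X)                    # shape (32, 18): raw class scores
probs = torch.softmax(logits, -1)  # rows sum to 1, as C&S expects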
3 changes: 2 additions & 1 deletion setup.py
@@ -6,6 +6,7 @@
install_requires=[
    "torch>=1.6.0",
    "tqdm>=4.27",
    "numpy"
    "numpy",
    "torch-geometric"
],
)
