Skip to content

Commit

Permalink
Add PyTorch model to fit/predict
Browse files Browse the repository at this point in the history
  • Loading branch information
ivartb committed Feb 16, 2024
1 parent fa73080 commit a4af998
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 8 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ jobs:
export PATH=bin:$PATH
metafx fit -f wd_unique_pca/feature_table.tsv -i wd_unique_pca/samples_categories.tsv -w wd_fit_rf
metafx fit -f wd_unique_pca/feature_table.tsv -i wd_unique_pca/samples_categories.tsv -w wd_fit_xgb -e XGB
metafx fit -f wd_unique_pca/feature_table.tsv -i wd_unique_pca/samples_categories.tsv -w wd_fit_torch -e Torch
- name: metafx cv
run: |
export PATH=bin:$PATH
Expand Down
24 changes: 21 additions & 3 deletions bin/metafx-scripts/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from sklearn.metrics import classification_report
from xgboost import XGBClassifier
from sklearn import preprocessing
from metafx_torch import TorchLinearModel
import torch

if __name__ == "__main__":
features = pd.read_csv(sys.argv[1], header=0, index_col=0, sep="\t")
Expand All @@ -23,19 +25,35 @@
M = features.shape[0] # features count
N = features.shape[1] # samples count

model = RandomForestClassifier(n_estimators=100) if sys.argv[4] == "RF" else XGBClassifier(n_estimators=100)
X = features.T
y = np.array([metadata.loc[i, 1] for i in X.index])

model = None
if sys.argv[4] == "RF":
model = RandomForestClassifier(n_estimators=100)
elif sys.argv[4] == "XGB":
model = XGBClassifier(n_estimators=100)
else:
model = TorchLinearModel(n_features=M, n_classes=len(set(y)))

if sys.argv[4] == "XGB":
le = preprocessing.LabelEncoder()
le.fit(y)
y = le.transform(y)
elif sys.argv[4] == "Torch":
le = preprocessing.LabelEncoder()
le.fit(y)
y = le.transform(y)

model.fit(X, y)
dump(model, outName + ".joblib")

if sys.argv[4] == "XGB":
if sys.argv[4] == "RF":
dump(model, outName + ".joblib")
elif sys.argv[4] == "XGB":
dump(model, outName + ".joblib")
dump(le, outName + "_le.joblib")
elif sys.argv[4] == "Torch":
torch.save(model, outName + ".joblib")
dump(le, outName + "_le.joblib")

print("Model accuracy after training:")
Expand Down
48 changes: 48 additions & 0 deletions bin/metafx-scripts/metafx_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# PyTorch Linear Classification Model
import torch
from torch import nn, optim
import numpy as np


class TorchLinearModel():
    """Small feed-forward PyTorch classifier with one hidden layer.

    The network is ``Linear(n_features, 32) -> Sigmoid -> Linear(32, n_classes)``
    and is trained full-batch with SGD (lr=0.001, momentum=0.9) on a
    cross-entropy loss.  The output layer emits raw logits: ``nn.CrossEntropyLoss``
    applies log-softmax itself, so no activation follows the last linear layer.

    Args:
        n_features: Number of input features per sample.
        n_classes: Number of target classes.
        n_epochs: Number of full-batch optimisation steps performed by ``fit``.
    """

    def __init__(self, n_features, n_classes, n_epochs=1000):
        self.n_features = n_features
        self.n_classes = n_classes
        self.n_epochs = n_epochs
        self.model = nn.Sequential(
            nn.Linear(self.n_features, 32),
            nn.Sigmoid(),
            # No trailing activation: CrossEntropyLoss expects unnormalized
            # logits.  Class predictions are unchanged because argmax is
            # invariant under a monotonic squashing function.
            nn.Linear(32, self.n_classes),
        )
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

    @staticmethod
    def _as_float_tensor(X):
        """Convert a DataFrame or array-like of numbers into a float32 tensor."""
        return torch.as_tensor(np.asarray(X), dtype=torch.float32)

    def fit(self, X, y):
        """Train the network on all samples at once for ``n_epochs`` steps.

        Args:
            X: Samples-by-features table (pandas DataFrame or array-like).
            y: Integer class labels in ``range(n_classes)``, one per row of X.

        The current loss is printed every 100 epochs.
        """
        inputs = self._as_float_tensor(X)
        # Class indices are the native target format of CrossEntropyLoss and
        # yield the same loss as the equivalent one-hot encoding.
        targets = torch.as_tensor(np.asarray(y), dtype=torch.long)

        for epoch in range(self.n_epochs):
            self.optimizer.zero_grad()

            logits = self.model(inputs)
            loss = self.criterion(logits, targets)
            loss.backward()
            self.optimizer.step()

            if (epoch+1) % 100 == 0:
                print("Epoch", epoch+1, "/", self.n_epochs, ":", round(loss.item(), 5), "loss", flush=True)

    def predict(self, X):
        """Return the index of the highest-scoring class for each row of X.

        Args:
            X: Samples-by-features table (pandas DataFrame or array-like).

        Returns:
            1-D numpy array of integer class indices.
        """
        # Inference needs no autograd graph.
        with torch.no_grad():
            logits = self.model(self._as_float_tensor(X))
        return np.argmax(logits.cpu().numpy(), axis=1)

    def get_model(self):
        """Return the underlying ``nn.Sequential`` network."""
        return self.model
16 changes: 12 additions & 4 deletions bin/metafx-scripts/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@
import pandas as pd
from joblib import load
from sklearn.metrics import classification_report
import torch


if __name__ == "__main__":
features = pd.read_csv(sys.argv[1], header=0, index_col=0, sep="\t")
outName = sys.argv[2]
model = load(sys.argv[3])
metadata = None
model_type = sys.argv[4]
if model_type == "XGB":

if model_type == "RF":
model = load(sys.argv[3])
elif model_type == "XGB":
model = load(sys.argv[3])
le = load(sys.argv[3][:-7] + "_le.joblib")
elif model_type == "Torch":
model = torch.load(sys.argv[3])
le = load(sys.argv[3][:-7] + "_le.joblib")

metadata = None
if len(sys.argv) == 6:
metadata = pd.read_csv(sys.argv[5], sep="\t", header=None, index_col=0, dtype=str)
metadata.index = metadata.index.astype(str)
Expand All @@ -24,7 +32,7 @@
X = features.T
y_pred = model.predict(X)

if model_type == "XGB":
if model_type == "XGB" or model_type == "Torch":
y_pred = le.inverse_transform(y_pred)

outFile = open(outName + ".tsv", "w")
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ scikit-learn==1.3.0
matplotlib==3.8.2
joblib==1.2.0
ete3==3.1.3
xgboost==2.0.3
xgboost==2.0.3
torch==2.2.0

0 comments on commit a4af998

Please sign in to comment.