-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathhiperparam_tuning_without_cv.py
81 lines (59 loc) · 4 KB
/
hiperparam_tuning_without_cv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!git clone -b Colab https://github.com/AntonioTepsich/Convolutional-KANs.git
#%cd Convolutional-KANs/
#!git pull
import sys
sys.path.insert(1,'Convolutional-KANs')
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from architectures_28x28.KKAN import *
from architectures_28x28.conv_and_kan import NormalConvsKAN,NormalConvsKAN_Medium
from architectures_28x28.KANConvs_MLP import *
from architectures_28x28.KANConvs_MLP_2 import *
from architectures_28x28.SimpleModels import *
from evaluations import *
from hiperparam_tuning import *
# Fix the global RNG seed before any model is built so weight initialisation
# is reproducible across runs.
torch.manual_seed(42)
# NOTE(review): `device` is not referenced by the calls in this script;
# presumably the tuning helpers choose their own device — confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them
# to [-1, 1] via (x - 0.5) / 0.5 on the single grayscale channel.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load the full MNIST train/test splits (all ten classes; no filtering).
mnist_train = MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = MNIST(root='./data', train=False, download=True, transform=transform)
# The training set is handed to the tuning routine as a Dataset (it presumably
# carves out its own validation split — confirm); only the test split gets a
# fixed-order loader here.
test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

dataset_name = "MNIST"
# Output directory passed as `path=` to every tuning run.
path = f"models/{dataset_name}"
# Per-dataset results directory; created here but not referenced again in this
# script — presumably the tuning/evaluation helpers write into it (confirm).
results_path = os.path.join("results", dataset_name)

# makedirs creates the missing parents ("models", "results") as well, and
# exist_ok=True makes re-running the script a no-op instead of a race-prone
# exists()/mkdir() pair.
os.makedirs(path, exist_ok=True)
os.makedirs(results_path, exist_ok=True)
def train_all_kans(grid_size, *, max_epochs=20, search_grid_combinations=8, folds=1):
    """Run the hyperparameter search and final training for every KAN-based model.

    Each architecture below is passed, in the listed order, to
    ``search_hiperparams_and_get_final_model``. Every call uses the same
    module-level ``mnist_train``, ``test_loader``, ``path`` and
    ``dataset_name``.

    Args:
        grid_size: Forwarded unchanged as ``grid_size=`` to every call;
            presumably the spline grid size of the KAN layers (confirm in
            ``hiperparam_tuning``).
        max_epochs: Forwarded as ``max_epochs=`` (default 20, as before).
        search_grid_combinations: Forwarded as ``search_grid_combinations=``
            (default 8, as before).
        folds: Forwarded as ``folds=`` (default 1, as before; this script is
            the "without CV" variant).

    Returns:
        None. Return values of the search calls are discarded, as in the
        original.
    """
    # Names are looked up at call time, so they only need to exist once the
    # module's imports have run.
    kan_models = (
        KANC_MLP,
        KANC_MLP_Big,
        KANC_MLP_Medium,
        KKAN_Convolutional_Network,
        KKAN_Small,
        NormalConvsKAN,
        NormalConvsKAN_Medium,
    )
    for model_class in kan_models:
        # The second positional argument is True here. The plain-CNN baselines
        # elsewhere in this script pass False, so it is presumably an
        # "is KAN" flag — confirm against hiperparam_tuning.
        search_hiperparams_and_get_final_model(
            model_class,
            True,
            mnist_train,
            test_loader,
            max_epochs=max_epochs,
            path=path,
            search_grid_combinations=search_grid_combinations,
            folds=folds,
            dataset_name=dataset_name,
            grid_size=grid_size,
        )
# KAN architectures: run the whole family once per grid_size value.
for kan_grid_size in (10, 20):
    train_all_kans(grid_size=kan_grid_size)

# Plain-CNN baselines: same search budget as the KAN runs, flag set to False,
# and no grid_size argument.
cnn_baselines = (SimpleCNN, MediumCNN, CNN_Big, CNN_more_convs)
for baseline_cls in cnn_baselines:
    search_hiperparams_and_get_final_model(
        baseline_cls,
        False,
        mnist_train,
        test_loader,
        max_epochs=20,
        path=path,
        search_grid_combinations=8,
        folds=1,
        dataset_name=dataset_name,
    )