Step 1: Data gathering.
The file image_downloader.py has the code used to download images from the Bing website.
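image_downloader.py is not reproduced here; the snippet below is a minimal sketch of downloading class images from Bing, assuming the bing_image_downloader package is used (the queries, limit, and output folder are placeholders, not taken from the repository):

from bing_image_downloader import downloader

# Download images for each class into its own subfolder under "dataset/".
for query in ["class_a", "class_b", "class_c"]:   # hypothetical class names
    downloader.download(query, limit=100, output_dir="dataset",
                        adult_filter_off=True, force_replace=False, timeout=60)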
Step 2: Data information and labels.
The file image_downloader.py also has the code that lists the contents of each image folder and writes the entries into that folder's CSV file.
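A minimal sketch of that listing step, under an assumed folder layout and CSV naming scheme (both are placeholders):

import csv
import os

DATA_DIR = "dataset"   # hypothetical root folder with one subfolder per class

for label in os.listdir(DATA_DIR):
    folder = os.path.join(DATA_DIR, label)
    if not os.path.isdir(folder):
        continue
    # Write one row per image into that folder's CSV file.
    with open(os.path.join(folder, f"{label}.csv"), "w", newline="") as f:
        writer = csv.writer(f)
        for filename in os.listdir(folder):
            if filename.lower().endswith((".jpg", ".jpeg", ".png")):
                writer.writerow([filename, label])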
Step 3: Preparation of the training set.
Of the entire dataset, 70% was used to train the model, 20% was used for validation, and the remaining 10% was used for testing.
Data augmentation was then applied to the training images, so that each photo accounted for 6 images in the final training data.
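A minimal sketch of the 70/20/10 split, assuming an ImageFolder-style dataset and torch.utils.data.random_split (the dataset path is a placeholder); the augmentation described above is then applied only to the training split:

from torch.utils.data import random_split
from torchvision import datasets, transforms

dataset = datasets.ImageFolder("dataset", transform=transforms.ToTensor())  # hypothetical path

# 70% train, 20% validation, 10% test.
n = len(dataset)
n_train = int(0.7 * n)
n_val = int(0.2 * n)
train_set, val_set, test_set = random_split(dataset, [n_train, n_val, n - n_train - n_val])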
Models.
There are 4 models available here.
1. Build a CNN model from scratch --> cnn.py
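cnn.py is not reproduced here; the sketch below only illustrates what a small CNN for 3 classes could look like (the layer sizes are assumptions):

import torch.nn as nn

class SimpleCNN(nn.Module):
    """A small, illustrative CNN for 3-class image classification."""
    def __init__(self, num_classes=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))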
2. Use a pre-trained model (ResNet-50) + finetune only the last layer --> pretrain_last_layer.py
The snippet below, used to retrain just the last layer, replaces the final fully connected layer of ResNet-50 with a new 3-class classifier head:

import torch.nn as nn
from torchvision import models

resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Replace the final fully connected layer with a new 3-class head.
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(fc_inputs, 3),
    nn.LogSoftmax(dim=1),   # for use with NLLLoss()
)
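Because the new head ends in LogSoftmax, it is meant to be trained with NLLLoss. The repository's training loop is not shown here; the sketch below is a minimal, assumed setup in which only the parameters of the new head are passed to the optimizer (the optimizer choice, learning rate, and dummy batch are illustrative only):

import torch
import torch.nn as nn

criterion = nn.NLLLoss()
# Only the new fc head is optimized, so the pre-trained backbone is not updated.
optimizer = torch.optim.Adam(resnet50.fc.parameters(), lr=1e-3)

log_probs = resnet50(torch.randn(4, 3, 224, 224))        # dummy batch of 4 RGB images
loss = criterion(log_probs, torch.tensor([0, 1, 2, 0]))  # dummy labels for 3 classes
loss.backward()
optimizer.step()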
3. Use a pre-trained model + finetune the last few layers --> PreTrain_few.py
The model is downloaded from torchvision, every parameter is frozen, and then only the last two blocks of ResNet-50 ("layer4" and "fc") are unfrozen:

import torch.nn as nn
from torchvision import models

resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Freeze every parameter first.
for params in resnet50.parameters():
    params.requires_grad = False

# Unfreeze the last two blocks of the ResNet-50 model.
unfreeze_layers = ["layer4", "fc"]
for name, param in resnet50.named_parameters():
    if any(layer_name in name for layer_name in unfreeze_layers):
        param.requires_grad = True

# Replace the classifier head with a 3-class output layer.
in_features = resnet50.fc.in_features
out_features = 3
resnet50.fc = nn.Linear(in_features, out_features)

The code above allows training of just the last 2 layers.
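PreTrain_few.py's optimizer setup is not shown here; a common pattern for this variant, sketched below under assumed settings, is to hand the optimizer only the unfrozen parameters:

import torch

# Collect only the unfrozen parameters ("layer4" and the new fc head).
trainable_params = [p for p in resnet50.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)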
4. Build a ViT model --> Vit.py
The model is downloaded from torchvision.
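Vit.py is not reproduced here; the sketch below shows one way to load a torchvision ViT and swap its classification head for 3 classes (the ViT-B/16 variant is an assumption):

import torch.nn as nn
from torchvision import models

# Load a pre-trained Vision Transformer (assumed variant: ViT-B/16).
vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)

# Replace the classification head with a 3-class output layer.
vit.heads.head = nn.Linear(vit.heads.head.in_features, 3)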
Files:
albumentations.py
Uses the albumentations library for data augmentation.
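The exact transforms live in albumentations.py; the sketch below only illustrates the general shape of an albumentations pipeline (the specific transforms and probabilities are assumptions, and the input image is synthetic):

import albumentations as A
import numpy as np

# Example augmentation pipeline; the transforms here are illustrative only.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
])

image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # synthetic image
augmented = transform(image=image)["image"]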
main.py is just a file to run if you want to upload an image through a web server via an API:

uvicorn main:app --reload

You can then upload a file and wait for the class probabilities to be displayed. It is a very simple demo of API integration, not a sophisticated one.
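main.py itself is not reproduced here; the sketch below shows the general shape of such a FastAPI upload endpoint (the route name, preprocessing, and the commented-out model call are assumptions):

import io
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from torchvision import transforms

app = FastAPI()
# model = torch.load("model.pth")  # hypothetical: load the trained model here

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    batch = preprocess(image).unsqueeze(0)
    # log_probs = model(batch)                # hypothetical model call
    # probs = torch.exp(log_probs).tolist()   # log-probabilities -> probabilities
    probs = [[0.0, 0.0, 0.0]]                 # placeholder so the sketch runs without a model
    return {"probabilities": probs}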