Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix code #369

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions custom_u2net_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB


def test_model(model):
    """Evaluate ``model`` on the TDP test set and print the mean pixel accuracy.

    Test images are read from ``<cwd>/my_data/TDP_test_dataset/TDP_IMAGES/*.jpg``
    and each image is paired with the ``.png`` file of the same base name in
    ``<cwd>/my_data/TDP_test_dataset/TDP_MASKS/``.

    Args:
        model: network to evaluate. It is moved to the GPU when CUDA is
            available and switched to eval mode; the first element of its
            output is thresholded at 0.5 to form the predicted mask.

    Returns:
        The mean per-batch pixel accuracy in percent, or ``None`` when no
        test image was found (evaluation is skipped instead of dividing by
        zero, so a caller that evaluates before saving does not lose its
        trained weights).
    """
    # ------- 1. set the directory of test dataset --------
    test_data_dir = os.path.join(os.getcwd(), 'my_data' + os.sep)
    test_image_dir = os.path.join('TDP_test_dataset', 'TDP_IMAGES' + os.sep)
    test_label_dir = os.path.join('TDP_test_dataset', 'TDP_MASKS' + os.sep)

    image_ext = '.jpg'
    label_ext = '.png'

    batch_size_val = 1

    test_img_name_list = glob.glob(test_data_dir + test_image_dir + '*' + image_ext)
    if not test_img_name_list:
        print(f'No test images found in {test_data_dir + test_image_dir}; skipping evaluation.')
        return None

    # A label shares its image's base name (everything before the last dot).
    test_lbl_name_list = [
        test_data_dir + test_label_dir
        + os.path.splitext(os.path.basename(img_path))[0] + label_ext
        for img_path in test_img_name_list
    ]

    test_salobj_dataset = SalObjDataset(
        img_name_list=test_img_name_list,
        lbl_name_list=test_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=batch_size_val,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 2. test process ---------
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
    model.eval()

    accuracy = 0.0
    with torch.no_grad():
        for data in test_salobj_dataloader:
            inputs = data['image'].type(torch.FloatTensor)
            labels = data['label'].type(torch.FloatTensor)

            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            outputs = model(inputs)
            predicted_masks = (outputs[0] > 0.5).float()

            # NOTE(review): exact equality assumes the labels are already
            # binary {0, 1} after ToTensorLab - confirm against data_loader.
            total_pixels = labels.numel()
            correct_pixels = (predicted_masks == labels).sum().item()
            accuracy += (correct_pixels / total_pixels) * 100

    avr_accuracy = accuracy / len(test_salobj_dataloader)
    print(f'avarage_accuracy: {avr_accuracy}%')
    return avr_accuracy

if __name__ == "__main__":
pass
199 changes: 199 additions & 0 deletions custom_u2net_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

import numpy as np
import glob
import os

from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET
from model import U2NETP
from tqdm import tqdm
from model_processing.prepare_model import get_latest_model, get_latest_version
from model_processing.convert_model import convert_model_to_onnx
from custom_u2net_test import test_model
from model_processing.upload_model_to_S3bucket import upload_folder_to_s3



def main():
    """Fine-tune U^2-Net on the TDP training set, then evaluate, save, export and upload it.

    Steps, in order:
      1. Build the multi-output BCE loss.
      2. Collect ``<cwd>/my_data/TDP_train_dataset/TDP_IMAGES/*.jpg`` and the
         matching ``.png`` masks from ``TDP_MASKS``.
      3. Create the network and, if ``get_latest_model`` points at an existing
         checkpoint in ``saved_models/u2net``, load it.
      4. Train with Adam for ``epoch_num`` epochs.
      5. Run ``test_model``, save the weights as the next
         ``u2net_version_<n>.pth``, convert that checkpoint to ONNX and upload
         the model folder to the ``tdp-model`` S3 bucket.

    Raises:
        FileNotFoundError: if no training image is found.
        ValueError: if ``model_name`` is not a supported architecture.
    """
    # ------- 1. define loss function --------

    # `size_average=True` is deprecated; `reduction='mean'` is its equivalent.
    bce_loss = nn.BCELoss(reduction='mean')

    def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
        """Return ``(loss of d0, sum of the BCE losses of all seven outputs)``."""
        # Convert all tensors to torch.float32
        target = labels_v.float()
        losses = [bce_loss(d.float(), target) for d in (d0, d1, d2, d3, d4, d5, d6)]

        # Start the sum from losses[0] so the result stays a tensor and the
        # addition order matches loss0 + loss1 + ... + loss6.
        loss = sum(losses[1:], losses[0])
        print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"
              % tuple(single.item() for single in losses))

        return losses[0], loss

    # ------- 2. set the directory of training dataset --------

    model_name = 'u2net'  # 'u2netp'

    data_dir = os.path.join(os.getcwd(), 'my_data' + os.sep)
    tra_image_dir = os.path.join('TDP_train_dataset', 'TDP_IMAGES' + os.sep)
    tra_label_dir = os.path.join('TDP_train_dataset', 'TDP_MASKS' + os.sep)

    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)

    epoch_num = 100
    batch_size_train = 32

    tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)
    if not tra_img_name_list:
        # Without this guard the run would "train" on nothing and still
        # save, export and upload a checkpoint.
        raise FileNotFoundError(
            f"No training images matching *{image_ext} found in {data_dir + tra_image_dir}")

    # A label shares its image's base name (everything before the last dot).
    tra_lbl_name_list = [
        data_dir + tra_label_dir
        + os.path.splitext(os.path.basename(img_path))[0] + label_ext
        for img_path in tra_img_name_list
    ]

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)

    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
            ToTensorLab(flag=0)]))
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)

    # ------- 3. define model --------
    # define the net
    if model_name == 'u2net':
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        net = U2NETP(3, 1)
    else:
        raise ValueError(f"Unsupported model_name: {model_name!r}")

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")
    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0

    # The checkpoint folder must exist both for the lookup below and for the
    # final torch.save(); create it up front so a fresh checkout works.
    os.makedirs(model_dir, exist_ok=True)

    # Check if there is a pre-trained model to load
    pretrained_model_path = get_latest_model("saved_models/u2net")

    if os.path.exists(pretrained_model_path):
        # Load the pre-trained model. Mapping to CPU first lets a checkpoint
        # saved on a GPU load on a CPU-only host; load_state_dict then copies
        # the values onto whatever device the network already lives on.
        net.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
        print(f"Pre-trained model loaded from {pretrained_model_path}")
    else:
        print("No pre-trained model found. Training from scratch.")

    for epoch in tqdm(range(0, epoch_num)):
        net.train()

        for i, data in tqdm(enumerate(salobj_dataloader), total=len(salobj_dataloader)):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            inputs = data['image'].type(torch.FloatTensor)
            labels = data['label'].type(torch.FloatTensor)

            # Plain tensors are enough here; inputs and labels never need
            # gradients, so no Variable wrapper is required.
            if use_cuda:
                inputs_v, labels_v = inputs.cuda(), labels.cuda()
            else:
                inputs_v, labels_v = inputs, labels

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # del temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num,
                running_loss / ite_num4val, running_tar_loss / ite_num4val))

    # test_model
    test_model(net)

    # Save model under the next version number. A separate local is used so
    # the architecture name in `model_name` is not overwritten.
    latest_version = get_latest_version(pretrained_model_path)
    if latest_version.isdigit():
        checkpoint_name = f"u2net_version_{int(latest_version)+1}.pth"
    else:
        checkpoint_name = "u2net_version_1.pth"
    latest_model = os.path.join(model_dir, checkpoint_name)
    torch.save(net.state_dict(), latest_model)
    print(f"Final model saved as {checkpoint_name}")

    # convert model to onnx
    convert_model_to_onnx(latest_model)

    # upload model to AWS S3 bucket
    upload_folder_to_s3(model_dir, 'tdp-model')

if __name__ == '__main__':
main()

30 changes: 30 additions & 0 deletions model_processing/convert_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import io
import torch.onnx
from model import U2NET
import os
from model_processing.prepare_model import get_latest_version

def convert_model_to_onnx(model_path):
    """Export the U2NET checkpoint at ``model_path`` to ONNX.

    The weights are loaded on the CPU into a fresh ``U2NET(3, 1)`` and traced
    with a random ``(1, 3, 512, 512)`` input. The result is written to
    ``saved_models/ABR_model/ARB_version_<n>.onnx``, where ``<n>`` is the
    version reported by ``get_latest_version(model_path)``. The batch axis of
    the input and output is exported as dynamic.

    Args:
        model_path: path to a ``.pth`` state dict whose file name ends in a
            numeric version (for example ``u2net_version_3.pth``).

    Raises:
        ValueError: if no numeric version can be derived from ``model_path``.
    """
    # Validate the version before the expensive load so a badly named
    # checkpoint fails fast with an explicit message.
    last_character = get_latest_version(model_path)
    if not last_character.isdigit():
        raise ValueError(
            f"Cannot derive a numeric model version from {model_path!r}")

    torch_model = U2NET(3, 1)
    batch_size = 1

    torch_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    torch_model.eval()

    x = torch.randn(batch_size, 3, 512, 512, requires_grad=True)

    model_dir = "saved_models/ABR_model"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_dir, exist_ok=True)

    torch.onnx.export(torch_model, x, os.path.join(model_dir, f"ARB_version_{int(last_character)}.onnx"),
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
                      )


if __name__ == '__main__':
pass
16 changes: 16 additions & 0 deletions model_processing/prepare_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os

def get_latest_model(model_folder):
    """Return the path of the highest-versioned checkpoint in ``model_folder``.

    Only files named exactly ``u2net_version_<integer>.pth`` are considered;
    other files that merely share the prefix (an ``.onnx`` export, a
    ``u2net_version_final.pth``) are ignored instead of crashing the integer
    parse.

    Args:
        model_folder: directory to scan.

    Returns:
        ``model_folder`` joined with the checkpoint that has the largest
        version number, or with ``'u2net.pth'`` when no versioned checkpoint
        exists or the folder itself is missing. The returned path is not
        guaranteed to exist; callers are expected to check.
    """
    prefix = "u2net_version_"
    suffix = ".pth"

    try:
        entries = os.listdir(model_folder)
    except FileNotFoundError:
        # A missing folder simply means there is nothing to resume from.
        entries = []

    model_list = [
        name for name in entries
        if name.startswith(prefix)
        and name.endswith(suffix)
        and name[len(prefix):-len(suffix)].isdecimal()
    ]

    if not model_list:
        latest_model = 'u2net.pth'
    else:
        # Compare numerically so that version 10 sorts after version 9.
        latest_model = sorted(
            model_list, key=lambda name: int(name[len(prefix):-len(suffix)]))[-1]
    return os.path.join(model_folder, latest_model)


def get_latest_version(model_path):
    """Return the trailing version number of a checkpoint path, as a string.

    The version is the run of digits at the end of the file name, taken
    before its first dot, so ``saved_models/u2net/u2net_version_12.pth``
    gives ``"12"``. Reading only the last character would give ``"2"`` and
    make the next save reuse an old version number.

    Args:
        model_path: path or file name of a model checkpoint.

    Returns:
        The trailing digits of the file name, or ``""`` when the name does
        not end in a digit (for example ``u2net.pth``). Callers test the
        result with ``str.isdigit()`` before converting it.
    """
    model_name = os.path.basename(model_path).partition(".")[0]
    digit_count = len(model_name) - len(model_name.rstrip("0123456789"))
    if digit_count == 0:
        return ""
    return model_name[-digit_count:]
27 changes: 27 additions & 0 deletions model_processing/upload_model_to_S3bucket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
from dotenv import load_dotenv
import boto3

# Load environment variables from .env
load_dotenv()

def upload_folder_to_s3(local_folder, bucket_name):
    """Upload every file under ``local_folder`` to the S3 bucket ``bucket_name``.

    The S3 client is built from the ``AWS_ACCESS_KEY_ID``,
    ``AWS_SECRET_ACCESS_KEY`` and ``AWS_REGION`` environment variables. Each
    object key is the file's path relative to ``local_folder`` with
    backslashes turned into forward slashes. A failed upload is printed and
    the walk continues with the next file.

    Args:
        local_folder: root directory whose files are uploaded recursively.
        bucket_name: name of the destination bucket.
    """
    client = boto3.client(
        's3',
        aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
        region_name=os.environ.get('AWS_REGION'),
    )

    for current_dir, _subdirs, filenames in os.walk(local_folder):
        for filename in filenames:
            source_path = os.path.join(current_dir, filename)
            # Keys always use "/" regardless of the local path separator.
            relative_path = os.path.relpath(source_path, local_folder)
            object_key = relative_path.replace("\\", "/")

            try:
                client.upload_file(source_path, bucket_name, object_key)
            except Exception as err:
                # Best effort: report the failure and keep uploading the rest.
                print(f"Error uploading {source_path}: {err}")
Loading