-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeep_structure.py
38 lines (29 loc) · 1004 Bytes
/
deep_structure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
This script follows an approach similar to the reference implementation of the paper *Deep Image Prior*.
"""
import torch
from models.PointNetFCAE import *
from models.MLP import *
from utils.train_utils import chamfer
from multiprocessing import Queue
from utils.parse_args import parse_args
from utils.train_utils import create_optimizer, train
from data_manager.shapenet import ShapenetDataProcess
from data_manager.data_process import kill_data_processes
# Number of training epochs; each epoch index is passed to ``train``.
epoch = 200

if __name__ == "__main__":
    # The data loaders are separate processes. Under the "spawn" start method
    # every child re-imports this module, so the script body must be guarded:
    # otherwise each child would re-parse args and try to start workers again.
    args = parse_args()

    # Alternatives previously tried here (kept as a note, not as dead code):
    #   args.model = PointNetFCAE_create_model(args)
    #   args.error = torch.nn.MSELoss()   or   args.error = ChamferLoss()
    args.model = MLP()

    # Every worker is handed the same queue. maxsize=1 means the queue holds
    # at most one pending item at a time.
    data_queue = Queue(1)
    data_processes = []
    try:
        for _ in range(args.nworkers):
            worker = ShapenetDataProcess(
                data_queue, args, split='train', repeat=False)
            worker.start()
            # Record the worker only once it has actually started, so the
            # cleanup below never receives an unstarted process.
            data_processes.append(worker)

        args.optimizer = create_optimizer(args, args.model)
        for i in range(epoch):
            # NOTE(review): ``i`` looks like the epoch index consumed by
            # ``train`` (e.g. for logging/scheduling) — confirm.
            train(args, data_queue, data_processes, i)
    finally:
        # Run cleanup even if worker startup, optimizer creation or training
        # raises (or is interrupted), so kill_data_processes always gets the
        # chance to stop the loader workers that were started.
        kill_data_processes(data_queue, data_processes)