diff --git a/torchstudio/datasetanalyze.py b/torchstudio/datasetanalyze.py
index 929bf1a..6b9239e 100644
--- a/torchstudio/datasetanalyze.py
+++ b/torchstudio/datasetanalyze.py
@@ -1,117 +1,117 @@
-import sys
-
-import torchstudio.tcpcodec as tc
-from torchstudio.modules import safe_exec
-import os
-import io
-from collections.abc import Iterable
-from tqdm.auto import tqdm
-import pickle
-
-original_path=sys.path
-
-app_socket = tc.connect()
-print("Analyze script connected\n", file=sys.stderr)
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
-
- if msg_type == 'SetAnalyzerCode':
- print("Setting analyzer code...\n", file=sys.stderr)
- analyzer = None
- analyzer_code = tc.decode_strings(msg_data)[0]
- error_msg, analyzer_env = safe_exec(analyzer_code, description='analyzer definition')
- if error_msg is not None or 'analyzer' not in analyzer_env:
- print("Unknown analyzer definition error" if error_msg is None else error_msg, file=sys.stderr)
-
- if msg_type == 'StartAnalysisServer' and 'analyzer' in analyzer_env:
- print("Analyzing...\n", file=sys.stderr)
-
- analysis_server, address = tc.generate_server()
-
- if analyzer_env['analyzer'].train is None:
- request_msg='AnalysisServerRequestingAllSamples'
- elif analyzer_env['analyzer'].train==True:
- request_msg='AnalysisServerRequestingTrainingSamples'
- elif analyzer_env['analyzer'].train==False:
- request_msg='AnalysisServerRequestingValidationSamples'
- tc.send_msg(app_socket, request_msg, tc.encode_strings(address))
- dataset_socket=tc.start_server(analysis_server)
-
- while True:
- dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket)
-
- if dataset_msg_type == 'NumSamples':
- num_samples=tc.decode_ints(dataset_msg_data)[0]
- pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters
-
- if dataset_msg_type == 'InputTensorsID':
- input_tensors_id=tc.decode_ints(dataset_msg_data)
-
- if dataset_msg_type == 'OutputTensorsID':
- output_tensors_id=tc.decode_ints(dataset_msg_data)
-
- if dataset_msg_type == 'Labels':
- labels=tc.decode_strings(dataset_msg_data)
-
- if dataset_msg_type == 'StartSending':
- error_msg, return_value = safe_exec(analyzer_env['analyzer'].start_analysis, (num_samples, input_tensors_id, output_tensors_id, labels), description='analyzer definition')
- if error_msg is not None:
- pbar.close()
- print(error_msg, file=sys.stderr)
- dataset_socket.close()
- analysis_server.close()
- break
-
- if dataset_msg_type == 'TrainingSample':
- pbar.update(1)
- error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), True), description='analyzer definition')
- if error_msg is not None:
- pbar.close()
- print(error_msg, file=sys.stderr)
- dataset_socket.close()
- analysis_server.close()
- break
-
- if dataset_msg_type == 'ValidationSample':
- pbar.update(1)
- error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), False), description='analyzer definition')
- if error_msg is not None:
- pbar.close()
- print(error_msg, file=sys.stderr)
- dataset_socket.close()
- analysis_server.close()
- break
-
- if dataset_msg_type == 'DoneSending':
- pbar.close()
- error_msg, return_value = safe_exec(analyzer_env['analyzer'].finish_analysis, description='analyzer definition')
- tc.send_msg(dataset_socket, 'DoneReceiving')
- dataset_socket.close()
- analysis_server.close()
- if error_msg is not None:
- print(error_msg, file=sys.stderr)
- else:
- buffer=io.BytesIO()
- pickle.dump(analyzer_env['analyzer'].state_dict(), buffer)
- tc.send_msg(app_socket, 'AnalyzerState',buffer.getvalue())
- tc.send_msg(app_socket, 'AnalysisWeights',tc.encode_floats(analyzer_env['analyzer'].weights))
- print("Analysis complete")
- break
-
- if msg_type == 'LoadAnalyzerState':
- if 'analyzer' in analyzer_env:
- buffer=io.BytesIO(msg_data)
- analyzer_env['analyzer'].load_state_dict(pickle.load(buffer))
- print("Analyzer state loaded")
-
- if msg_type == 'RequestAnalysisReport':
- resolution = tc.decode_ints(msg_data)
- if 'analyzer' in analyzer_env:
- error_msg, return_value = safe_exec(analyzer_env['analyzer'].generate_report, (resolution[0:2],resolution[2]), description='analyzer definition')
- if error_msg is not None:
- print(error_msg, file=sys.stderr)
- if return_value is not None:
- tc.send_msg(app_socket, 'ReportImage', tc.encode_image(return_value))
-
- if msg_type == 'Exit':
- break
+import sys
+
+import torchstudio.tcpcodec as tc
+from torchstudio.modules import safe_exec
+import os
+import io
+from collections.abc import Iterable
+from tqdm.auto import tqdm
+import pickle
+
+original_path=sys.path
+
+app_socket = tc.connect()
+print("Analyze script connected\n", file=sys.stderr)
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+
+ if msg_type == 'SetAnalyzerCode':
+ print("Setting analyzer code...\n", file=sys.stderr)
+ analyzer = None
+ analyzer_code = tc.decode_strings(msg_data)[0]
+ error_msg, analyzer_env = safe_exec(analyzer_code, description='analyzer definition')
+ if error_msg is not None or 'analyzer' not in analyzer_env:
+ print("Unknown analyzer definition error" if error_msg is None else error_msg, file=sys.stderr)
+
+ if msg_type == 'StartAnalysisServer' and 'analyzer' in analyzer_env:
+ print("Analyzing...\n", file=sys.stderr)
+
+ analysis_server, address = tc.generate_server()
+
+ if analyzer_env['analyzer'].train is None:
+ request_msg='AnalysisServerRequestingAllSamples'
+ elif analyzer_env['analyzer'].train==True:
+ request_msg='AnalysisServerRequestingTrainingSamples'
+ elif analyzer_env['analyzer'].train==False:
+ request_msg='AnalysisServerRequestingValidationSamples'
+ tc.send_msg(app_socket, request_msg, tc.encode_strings(address))
+ dataset_socket=tc.start_server(analysis_server)
+
+ while True:
+ dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket)
+
+ if dataset_msg_type == 'NumSamples':
+ num_samples=tc.decode_ints(dataset_msg_data)[0]
+ pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters
+
+ if dataset_msg_type == 'InputTensorsID':
+ input_tensors_id=tc.decode_ints(dataset_msg_data)
+
+ if dataset_msg_type == 'OutputTensorsID':
+ output_tensors_id=tc.decode_ints(dataset_msg_data)
+
+ if dataset_msg_type == 'Labels':
+ labels=tc.decode_strings(dataset_msg_data)
+
+ if dataset_msg_type == 'StartSending':
+ error_msg, return_value = safe_exec(analyzer_env['analyzer'].start_analysis, (num_samples, input_tensors_id, output_tensors_id, labels), description='analyzer definition')
+ if error_msg is not None:
+ pbar.close()
+ print(error_msg, file=sys.stderr)
+ dataset_socket.close()
+ analysis_server.close()
+ break
+
+ if dataset_msg_type == 'TrainingSample':
+ pbar.update(1)
+ error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), True), description='analyzer definition')
+ if error_msg is not None:
+ pbar.close()
+ print(error_msg, file=sys.stderr)
+ dataset_socket.close()
+ analysis_server.close()
+ break
+
+ if dataset_msg_type == 'ValidationSample':
+ pbar.update(1)
+ error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), False), description='analyzer definition')
+ if error_msg is not None:
+ pbar.close()
+ print(error_msg, file=sys.stderr)
+ dataset_socket.close()
+ analysis_server.close()
+ break
+
+ if dataset_msg_type == 'DoneSending':
+ pbar.close()
+ error_msg, return_value = safe_exec(analyzer_env['analyzer'].finish_analysis, description='analyzer definition')
+ tc.send_msg(dataset_socket, 'DoneReceiving')
+ dataset_socket.close()
+ analysis_server.close()
+ if error_msg is not None:
+ print(error_msg, file=sys.stderr)
+ else:
+ buffer=io.BytesIO()
+ pickle.dump(analyzer_env['analyzer'].state_dict(), buffer)
+ tc.send_msg(app_socket, 'AnalyzerState',buffer.getvalue())
+ tc.send_msg(app_socket, 'AnalysisWeights',tc.encode_floats(analyzer_env['analyzer'].weights))
+ print("Analysis complete")
+ break
+
+ if msg_type == 'LoadAnalyzerState':
+ if 'analyzer' in analyzer_env:
+ buffer=io.BytesIO(msg_data)
+ analyzer_env['analyzer'].load_state_dict(pickle.load(buffer))
+ print("Analyzer state loaded")
+
+ if msg_type == 'RequestAnalysisReport':
+ resolution = tc.decode_ints(msg_data)
+ if 'analyzer' in analyzer_env:
+ error_msg, return_value = safe_exec(analyzer_env['analyzer'].generate_report, (resolution[0:2],resolution[2]), description='analyzer definition')
+ if error_msg is not None:
+ print(error_msg, file=sys.stderr)
+ if return_value is not None:
+ tc.send_msg(app_socket, 'ReportImage', tc.encode_image(return_value))
+
+ if msg_type == 'Exit':
+ break
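The analyzer object driven by the loop above is expected to expose a train flag, start_analysis, analyze_sample, finish_analysis, state_dict/load_state_dict, generate_report and a weights attribute. The following is a minimal sketch of such an object, inferred only from the calls made in datasetanalyze.py; the class name and the counting logic are illustrative, not part of the repository.

import PIL.Image

class CountingAnalyzer:
    def __init__(self):
        self.train = None      # None: request all samples; True/False: training/validation samples only
        self.weights = []      # per-sample weights, sent back as 'AnalysisWeights'
        self.count = 0

    def start_analysis(self, num_samples, input_tensors_id, output_tensors_id, labels):
        self.count = 0
        self.weights = [1.0] * num_samples

    def analyze_sample(self, tensors, training):
        self.count += 1        # tensors is the tuple of NumPy arrays decoded for one sample

    def finish_analysis(self):
        pass

    def state_dict(self):
        return {'count': self.count, 'weights': self.weights}

    def load_state_dict(self, state):
        self.count = state['count']
        self.weights = state['weights']

    def generate_report(self, size, dpi):
        return PIL.Image.new('RGB', tuple(size))   # a PIL image, encoded with tc.encode_image()

analyzer = CountingAnalyzer()  # the executed code must leave a global named 'analyzer'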
diff --git a/torchstudio/datasetload.py b/torchstudio/datasetload.py
index 6e0845a..9b9f9cc 100644
--- a/torchstudio/datasetload.py
+++ b/torchstudio/datasetload.py
@@ -1,286 +1,286 @@
-#workaround until PyTorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490
-import os
-os.environ['KMP_DUPLICATE_LIB_OK']='True'
-
-import sys
-print("Loading PyTorch...\n", file=sys.stderr)
-
-import torch
-from torch.utils.data import Dataset
-from torchvision.transforms.functional import to_tensor
-import torchstudio.tcpcodec as tc
-from torchstudio.modules import safe_exec
-import random
-import os
-import io
-import time
-from collections.abc import Iterable
-from tqdm.auto import tqdm
-
-#monkey patch ssl to fix ssl certificate verification failures when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error
-import ssl
-ssl._create_default_https_context = ssl._create_unverified_context
-
-meta_dataset = None
-input_tensors_id = []
-output_tensors_id = []
-
-class MetaDataset(Dataset):
- def __init__(self, train, valid=None):
- self.train_dataset=train
- self.valid_dataset=valid
- self.train_count=None
- self.shuffle=0
- self.smp_usage=1.0
- self.training=True
- self.classes=train.classes if hasattr(train,'classes') else []
- self._gen_index()
-
- def _gen_index(self):
- self.index = []
- for i in range(len(self.train_dataset)):
- self.index.append((self.train_dataset,i))
- if self.valid_dataset is not None:
- for i in range(len(self.valid_dataset)):
- self.index.append((self.valid_dataset,i))
- if self.train_count is None:
- self.train_count=len(self.train_dataset)
- if self.valid_dataset is None:
- self.train_count=round(self.train_count*0.8)
-
- if self.shuffle>0:
- #Fisher–Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
- random.seed(0)
- shuffle_count=self.train_count if self.shuffle==1 else len(self.index)
- for sample in range(shuffle_count):
- target_sample=random.randrange(sample,shuffle_count)
- self.index[sample], self.index[target_sample] = self.index[target_sample], self.index[sample]
-
- def set_num_train(self, num):
- self.train_count=min(num,len(self.index))
- self._gen_index()
-
- def set_smp_usage(self, ratio):
- self.smp_usage=min(max(ratio,0.0),1.0)
- self._gen_index()
-
- def set_shuffle(self, mode):
- self.shuffle=mode
- self._gen_index()
-
- def train(self, mode=True):
- self.training=mode
- return self
-
- def valid(self):
- self.training=False
- return self
-
- def __len__(self):
- if self.training==True:
- return round(self.train_count*self.smp_usage)
- else:
- return round((len(self.index)-self.train_count)*self.smp_usage)
-
- def __getitem__(self, id):
- if id<0 or id>=len(self):
- raise IndexError
- if self.training==True:
- sample_ref=self.index[id]
- else:
- sample_ref=self.index[id+self.train_count]
- sample=sample_ref[0][sample_ref[1]]
-
- #convert to list if needed
- if isinstance(sample, Iterable):
- if type(sample) is dict:
- sample=list(sample.values())
- else:
- sample=list(sample)
- else:
- sample=[sample]
-
- #convert each element of the list to a tensor if needed
- sample_tensors=[]
- for i in range(len(sample)):
- if type(sample[i]) is not torch.Tensor:
- if 'PIL' in str(type(sample[i])) or 'numpy' in str(type(sample[i])):
- sample_tensors.append(to_tensor(sample[i]))
- else:
- try:
- sample_tensors.append(torch.tensor(sample[i]))
- except:
- pass
- else:
- sample_tensors.append(sample[i])
-
- #and finally solidify into a tuple
- sample_tensors=tuple(sample_tensors)
-
- return sample_tensors
-
-original_path=sys.path
-original_dir=os.getcwd()
-
-app_socket = tc.connect()
-print("Dataset script connected\n", file=sys.stderr)
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
-
- if msg_type == 'SetCurrentDir':
- new_dir=tc.decode_strings(msg_data)[0]
- sys.path=original_path
- os.chdir(original_dir)
- if new_dir:
- sys.path.append(new_dir)
- os.chdir(new_dir)
-
- if msg_type == 'SetDatasetCode':
- print("Loading dataset...\n", file=sys.stderr)
-
- meta_dataset = None
- error_msg, dataset_env = safe_exec(tc.decode_strings(msg_data)[0], description='dataset definition')
- if error_msg is not None or 'train' not in dataset_env:
- print("Unknown dataset definition error" if error_msg is None else error_msg, file=sys.stderr)
- else:
- meta_dataset=MetaDataset(dataset_env['train'], dataset_env['valid'] if 'valid' in dataset_env else None)
- tc.send_msg(app_socket, 'Labels', tc.encode_strings(meta_dataset.classes))
- tc.send_msg(app_socket, 'NumSamples', tc.encode_ints([len(meta_dataset.train()),len(meta_dataset.valid())]))
- sample=meta_dataset.train()[0]
-
- #suggest default formats
- type_id=[1 for i in range(len(sample))] #inputs
- if len(sample)==1:
- type_id[-1]=3 #input/output
- if len(sample)>1:
- type_id[-1]=2 #output
- tc.send_msg(app_socket, 'SetTypes', tc.encode_ints(type_id))
-
- renderer_name=[]
- for tensor in sample:
- if len(tensor.shape)==4:
- renderer_name.append("Volume")
- elif len(tensor.shape)==3 and tensor.dtype==torch.complex64:
- renderer_name.append("Spectrogram")
- elif len(tensor.shape)==3:
- renderer_name.append("Bitmap")
- elif len(tensor.shape)==2:
- renderer_name.append("Signal")
- elif len(tensor.shape)<2:
- renderer_name.append("Labels")
- else:
- renderer_name.append("Custom")
- tc.send_msg(app_socket, 'SetRendererNames', tc.encode_strings(renderer_name))
-
- if sample and len(sample[-1].shape)==0 and "int" in str(sample[-1].dtype):
- analyzer_name="Multiclass"
- elif sample and len(sample[-1].shape)==1:
- analyzer_name="MultiLabel"
- else:
- analyzer_name="ValuesDistribution"
- tc.send_msg(app_socket, 'SetAnalyzerName', tc.encode_strings(analyzer_name))
-
- print("Loading complete")
-
- if msg_type == 'RequestTrainingSamples' or msg_type == 'RequestValidationSamples':
- if meta_dataset is not None:
- meta_dataset.train(msg_type == 'RequestTrainingSamples')
- samples_id = tc.decode_ints(msg_data)
- for id in samples_id:
- tc.send_msg(app_socket, 'TensorData', tc.encode_torch_tensors(meta_dataset[id]))
-
- if msg_type == 'SetNumTrainingSamples':
- if meta_dataset is not None:
- meta_dataset.set_num_train(tc.decode_ints(msg_data)[0])
-
- if msg_type == 'SetSampleUsage':
- if meta_dataset is not None:
- meta_dataset.set_smp_usage(tc.decode_floats(msg_data)[0])
-
- if msg_type == 'SetShuffleMode':
- if meta_dataset is not None:
- meta_dataset.set_shuffle(tc.decode_ints(msg_data)[0])
-
- if msg_type == 'InputTensorsID':
- input_tensors_id = tc.decode_ints(msg_data)
-
- if msg_type == 'OutputTensorsID':
- output_tensors_id = tc.decode_ints(msg_data)
-
- if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples':
- train_set=True if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendAllSamples' else False
- valid_set=True if msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples' else False
- name, sshaddress, sshport, username, password, keydata, address, port = tc.decode_strings(msg_data)
- port=int(port)
-
- print('Connecting to '+name+'...\n', file=sys.stderr)
-
- if sshaddress and sshport and username:
- import socket
- import paramiko
- import torchstudio.sshtunnel as sshtunnel
-
- if not password:
- password=None
- if not keydata:
- pkey=None
- else:
- import io
- keybuffer=io.StringIO(keydata)
- pkey=paramiko.RSAKey.from_private_key(keybuffer)
-
- sshclient = paramiko.SSHClient()
- sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5)
- worker_socket = socket.socket()
- worker_socket.bind(('localhost', 0))
- freeport=worker_socket.getsockname()[1]
- worker_socket.close()
- forward_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ForwardTunnel, 'localhost', freeport, address if address else 'localhost', port)
- port=freeport
-
- try:
- worker_socket = tc.connect((address,port))
- num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0)
- tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples))
- tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id))
- tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id))
- tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes))
-
- tc.send_msg(worker_socket, 'StartSending')
- with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar:
- if train_set:
- meta_dataset.train()
- for i in range(len(meta_dataset)):
- tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i]))
- pbar.update(1)
- if valid_set:
- meta_dataset.valid()
- for i in range(len(meta_dataset)):
- tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i]))
- pbar.update(1)
-
- tc.send_msg(worker_socket, 'DoneSending')
- train_msg_type, train_msg_data = tc.recv_msg(worker_socket)
- if train_msg_type == 'DoneReceiving':
- worker_socket.close()
- print('Samples transfer to '+name+' completed')
-
- except:
- if sshaddress and sshport and username:
- time.sleep(.5) #give threaded ssh error messages some time to print first
- print('Samples transfer to '+name+' interrupted', file=sys.stderr)
-
- if sshaddress and sshport and username:
- try:
- forward_tunnel.stop()
- except:
- pass
- try:
- sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong
- except:
- pass
-
- if msg_type == 'Exit':
- break
-
+#workaround until PyTorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490
+import os
+os.environ['KMP_DUPLICATE_LIB_OK']='True'
+
+import sys
+print("Loading PyTorch...\n", file=sys.stderr)
+
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms.functional import to_tensor
+import torchstudio.tcpcodec as tc
+from torchstudio.modules import safe_exec
+import random
+import os
+import io
+import time
+from collections.abc import Iterable
+from tqdm.auto import tqdm
+
+#monkey patch ssl to fix ssl certificate verification failures when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+
+meta_dataset = None
+input_tensors_id = []
+output_tensors_id = []
+
+class MetaDataset(Dataset):
+ def __init__(self, train, valid=None):
+ self.train_dataset=train
+ self.valid_dataset=valid
+ self.train_count=None
+ self.shuffle=0
+ self.smp_usage=1.0
+ self.training=True
+ self.classes=train.classes if hasattr(train,'classes') else []
+ self._gen_index()
+
+ def _gen_index(self):
+ self.index = []
+ for i in range(len(self.train_dataset)):
+ self.index.append((self.train_dataset,i))
+ if self.valid_dataset is not None:
+ for i in range(len(self.valid_dataset)):
+ self.index.append((self.valid_dataset,i))
+ if self.train_count is None:
+ self.train_count=len(self.train_dataset)
+ if self.valid_dataset is None:
+ self.train_count=round(self.train_count*0.8)
+
+ if self.shuffle>0:
+ #Fisher–Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
+ random.seed(0)
+ shuffle_count=self.train_count if self.shuffle==1 else len(self.index)
+ for sample in range(shuffle_count):
+ target_sample=random.randrange(sample,shuffle_count)
+ self.index[sample], self.index[target_sample] = self.index[target_sample], self.index[sample]
+
+ def set_num_train(self, num):
+ self.train_count=min(num,len(self.index))
+ self._gen_index()
+
+ def set_smp_usage(self, ratio):
+ self.smp_usage=min(max(ratio,0.0),1.0)
+ self._gen_index()
+
+ def set_shuffle(self, mode):
+ self.shuffle=mode
+ self._gen_index()
+
+ def train(self, mode=True):
+ self.training=mode
+ return self
+
+ def valid(self):
+ self.training=False
+ return self
+
+ def __len__(self):
+ if self.training==True:
+ return round(self.train_count*self.smp_usage)
+ else:
+ return round((len(self.index)-self.train_count)*self.smp_usage)
+
+ def __getitem__(self, id):
+ if id<0 or id>=len(self):
+ raise IndexError
+ if self.training==True:
+ sample_ref=self.index[id]
+ else:
+ sample_ref=self.index[id+self.train_count]
+ sample=sample_ref[0][sample_ref[1]]
+
+ #convert to list if needed
+ if isinstance(sample, Iterable):
+ if type(sample) is dict:
+ sample=list(sample.values())
+ else:
+ sample=list(sample)
+ else:
+ sample=[sample]
+
+ #convert each element of the list to a tensor if needed
+ sample_tensors=[]
+ for i in range(len(sample)):
+ if type(sample[i]) is not torch.Tensor:
+ if 'PIL' in str(type(sample[i])) or 'numpy' in str(type(sample[i])):
+ sample_tensors.append(to_tensor(sample[i]))
+ else:
+ try:
+ sample_tensors.append(torch.tensor(sample[i]))
+ except:
+ pass
+ else:
+ sample_tensors.append(sample[i])
+
+ #and finally solidify into a tuple
+ sample_tensors=tuple(sample_tensors)
+
+ return sample_tensors
+
+original_path=sys.path
+original_dir=os.getcwd()
+
+app_socket = tc.connect()
+print("Dataset script connected\n", file=sys.stderr)
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+
+ if msg_type == 'SetCurrentDir':
+ new_dir=tc.decode_strings(msg_data)[0]
+ sys.path=original_path
+ os.chdir(original_dir)
+ if new_dir:
+ sys.path.append(new_dir)
+ os.chdir(new_dir)
+
+ if msg_type == 'SetDatasetCode':
+ print("Loading dataset...\n", file=sys.stderr)
+
+ meta_dataset = None
+ error_msg, dataset_env = safe_exec(tc.decode_strings(msg_data)[0], description='dataset definition')
+ if error_msg is not None or 'train' not in dataset_env:
+ print("Unknown dataset definition error" if error_msg is None else error_msg, file=sys.stderr)
+ else:
+ meta_dataset=MetaDataset(dataset_env['train'], dataset_env['valid'] if 'valid' in dataset_env else None)
+ tc.send_msg(app_socket, 'Labels', tc.encode_strings(meta_dataset.classes))
+ tc.send_msg(app_socket, 'NumSamples', tc.encode_ints([len(meta_dataset.train()),len(meta_dataset.valid())]))
+ sample=meta_dataset.train()[0]
+
+ #suggest default formats
+ type_id=[1 for i in range(len(sample))] #inputs
+ if len(sample)==1:
+ type_id[-1]=3 #input/output
+ if len(sample)>1:
+ type_id[-1]=2 #output
+ tc.send_msg(app_socket, 'SetTypes', tc.encode_ints(type_id))
+
+ renderer_name=[]
+ for tensor in sample:
+ if len(tensor.shape)==4:
+ renderer_name.append("Volume")
+ elif len(tensor.shape)==3 and tensor.dtype==torch.complex64:
+ renderer_name.append("Spectrogram")
+ elif len(tensor.shape)==3:
+ renderer_name.append("Bitmap")
+ elif len(tensor.shape)==2:
+ renderer_name.append("Signal")
+ elif len(tensor.shape)<2:
+ renderer_name.append("Labels")
+ else:
+ renderer_name.append("Custom")
+ tc.send_msg(app_socket, 'SetRendererNames', tc.encode_strings(renderer_name))
+
+ if sample and len(sample[-1].shape)==0 and "int" in str(sample[-1].dtype):
+ analyzer_name="Multiclass"
+ elif sample and len(sample[-1].shape)==1:
+ analyzer_name="MultiLabel"
+ else:
+ analyzer_name="ValuesDistribution"
+ tc.send_msg(app_socket, 'SetAnalyzerName', tc.encode_strings(analyzer_name))
+
+ print("Loading complete")
+
+ if msg_type == 'RequestTrainingSamples' or msg_type == 'RequestValidationSamples':
+ if meta_dataset is not None:
+ meta_dataset.train(msg_type == 'RequestTrainingSamples')
+ samples_id = tc.decode_ints(msg_data)
+ for id in samples_id:
+ tc.send_msg(app_socket, 'TensorData', tc.encode_torch_tensors(meta_dataset[id]))
+
+ if msg_type == 'SetNumTrainingSamples':
+ if meta_dataset is not None:
+ meta_dataset.set_num_train(tc.decode_ints(msg_data)[0])
+
+ if msg_type == 'SetSampleUsage':
+ if meta_dataset is not None:
+ meta_dataset.set_smp_usage(tc.decode_floats(msg_data)[0])
+
+ if msg_type == 'SetShuffleMode':
+ if meta_dataset is not None:
+ meta_dataset.set_shuffle(tc.decode_ints(msg_data)[0])
+
+ if msg_type == 'InputTensorsID':
+ input_tensors_id = tc.decode_ints(msg_data)
+
+ if msg_type == 'OutputTensorsID':
+ output_tensors_id = tc.decode_ints(msg_data)
+
+ if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples':
+ train_set=True if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendAllSamples' else False
+ valid_set=True if msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples' else False
+ name, sshaddress, sshport, username, password, keydata, address, port = tc.decode_strings(msg_data)
+ port=int(port)
+
+ print('Connecting to '+name+'...\n', file=sys.stderr)
+
+ if sshaddress and sshport and username:
+ import socket
+ import paramiko
+ import torchstudio.sshtunnel as sshtunnel
+
+ if not password:
+ password=None
+ if not keydata:
+ pkey=None
+ else:
+ import io
+ keybuffer=io.StringIO(keydata)
+ pkey=paramiko.RSAKey.from_private_key(keybuffer)
+
+ sshclient = paramiko.SSHClient()
+ sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5)
+ worker_socket = socket.socket()
+ worker_socket.bind(('localhost', 0))
+ freeport=worker_socket.getsockname()[1]
+ worker_socket.close()
+ forward_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ForwardTunnel, 'localhost', freeport, address if address else 'localhost', port)
+ port=freeport
+
+ try:
+ worker_socket = tc.connect((address,port))
+ num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0)
+ tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples))
+ tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id))
+ tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id))
+ tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes))
+
+ tc.send_msg(worker_socket, 'StartSending')
+ with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar:
+ if train_set:
+ meta_dataset.train()
+ for i in range(len(meta_dataset)):
+ tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i]))
+ pbar.update(1)
+ if valid_set:
+ meta_dataset.valid()
+ for i in range(len(meta_dataset)):
+ tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i]))
+ pbar.update(1)
+
+ tc.send_msg(worker_socket, 'DoneSending')
+ train_msg_type, train_msg_data = tc.recv_msg(worker_socket)
+ if train_msg_type == 'DoneReceiving':
+ worker_socket.close()
+ print('Samples transfer to '+name+' completed')
+
+ except:
+ if sshaddress and sshport and username:
+ time.sleep(.5) #give threaded ssh error messages some time to print first
+ print('Samples transfer to '+name+' interrupted', file=sys.stderr)
+
+ if sshaddress and sshport and username:
+ try:
+ forward_tunnel.stop()
+ except:
+ pass
+ try:
+ sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong
+ except:
+ pass
+
+ if msg_type == 'Exit':
+ break
+
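The SetDatasetCode handler above only requires the executed code to define a dataset named train (plus, optionally, valid and a classes attribute); MetaDataset wraps them and falls back to an 80/20 split of train when valid is absent. A minimal, hypothetical dataset definition using the RandomGenerator shipped in this same change:

from torchstudio.datasets.randomgenerator import RandomGenerator

train = RandomGenerator(size=256, tensors=[(3, 64, 64), (int, [0, 9])])
valid = RandomGenerator(size=64, tensors=[(3, 64, 64), (int, [0, 9])])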
diff --git a/torchstudio/datasets/randomgenerator.py b/torchstudio/datasets/randomgenerator.py
index 482d72b..fbb7b9b 100644
--- a/torchstudio/datasets/randomgenerator.py
+++ b/torchstudio/datasets/randomgenerator.py
@@ -1,50 +1,50 @@
-import torch
-from torch.utils.data import Dataset
-import inspect
-
-class RandomGenerator(Dataset):
- """A random generator that returns randomly generated tensors
-
- Args:
- size (int):
- Size of the dataset (number of samples)
- tensors:
- A list of tuples defining tensor properties: shape, type, range
- All properties are optional. Defaults are an empty shape, float, and [0,1]
- """
-
- def __init__(self, size:int=256, tensors=[(3,64,64), (int,[0,9])]):
- torch.manual_seed(0)
- self.size = size
- self.tensors = tensors
-
- def __len__(self):
- return self.size
-
- def __getitem__(self, idx):
- """
- Returns:
- A tuple of tensors.
- """
- sample = []
- for properties in self.tensors:
- shape=[]
- dtype=float
- drange=[0,1]
- for property in properties:
- if type(property)==int:
- shape.append(property)
- elif inspect.isclass(property):
- dtype=property
- elif type(property) is list:
- drange=property
- shape=tuple(shape)
-
- if 'int' in str(dtype):
- tensor=torch.randint(low=drange[0], high=drange[1]+1, size=shape, dtype=dtype)
- else:
- tensor=torch.rand(size=shape,dtype=dtype)*(drange[1]-drange[0])+drange[0]
-
- sample.append(tensor)
-
- return tuple(sample)
+import torch
+from torch.utils.data import Dataset
+import inspect
+
+class RandomGenerator(Dataset):
+ """A random generator that returns randomly generated tensors
+
+ Args:
+ size (int):
+ Size of the dataset (number of samples)
+ tensors:
+ A list of tuples defining tensor properties: shape, type, range
+ All properties are optional. Defaults are an empty shape, float, and [0,1]
+ """
+
+ def __init__(self, size:int=256, tensors=[(3,64,64), (int,[0,9])]):
+ torch.manual_seed(0)
+ self.size = size
+ self.tensors = tensors
+
+ def __len__(self):
+ return self.size
+
+ def __getitem__(self, idx):
+ """
+ Returns:
+ A tuple of tensors.
+ """
+ sample = []
+ for properties in self.tensors:
+ shape=[]
+ dtype=float
+ drange=[0,1]
+ for property in properties:
+ if type(property)==int:
+ shape.append(property)
+ elif inspect.isclass(property):
+ dtype=property
+ elif type(property) is list:
+ drange=property
+ shape=tuple(shape)
+
+ if 'int' in str(dtype):
+ tensor=torch.randint(low=drange[0], high=drange[1]+1, size=shape, dtype=dtype)
+ else:
+ tensor=torch.rand(size=shape,dtype=dtype)*(drange[1]-drange[0])+drange[0]
+
+ sample.append(tensor)
+
+ return tuple(sample)
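A short usage sketch (assumed, not part of the file): ints in a property tuple become shape dimensions, a class becomes the dtype and a list becomes the value range, in any order.

from torchstudio.datasets.randomgenerator import RandomGenerator

ds = RandomGenerator(size=8, tensors=[(3, 64, 64, float, [0, 1]), (int, [0, 9])])
image, label = ds[0]
print(image.shape, image.dtype)   # torch.Size([3, 64, 64]) torch.float64
print(label.shape, label.dtype)   # torch.Size([]) torch.int64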
diff --git a/torchstudio/graphdraw.py b/torchstudio/graphdraw.py
index cc9c34e..fffc353 100644
--- a/torchstudio/graphdraw.py
+++ b/torchstudio/graphdraw.py
@@ -1,638 +1,638 @@
-import torchstudio.tcpcodec as tc
-import os
-import graphviz
-import copy
-
-#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/torch.rst
-#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/nn.rst
-#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/nn.functional.rst
-
-creation_ops="""
-tensor
-sparse_coo_tensor
-as_tensor
-as_strided
-from_numpy
-frombuffer
-zeros
-zeros_like
-ones
-ones_like
-arange
-range
-linspace
-logspace
-eye
-empty
-empty_like
-empty_strided
-full
-full_like
-quantize_per_tensor
-quantize_per_channel
-dequantize
-complex
-polar
-heaviside
-"""
-
-manipulation_ops="""
-cat
-concat
-conj
-chunk
-dsplit
-column_stack
-dstack
-gather
-hsplit
-hstack
-index_select
-masked_select
-movedim
-moveaxis
-narrow
-nonzero
-permute
-reshape
-row_stack
-scatter
-scatter_add
-split
-squeeze
-stack
-swapaxes
-swapdims
-t
-take
-take_along_dim
-tensor_split
-tile
-transpose
-unbind
-unsqueeze
-vsplit
-vstack
-where
-"""
-
-convolution_ops="""
-nn.Conv1d
-nn.Conv2d
-nn.Conv3d
-nn.ConvTranspose1d
-nn.ConvTranspose2d
-nn.ConvTranspose3d
-nn.LazyConv1d
-nn.LazyConv2d
-nn.LazyConv3d
-nn.LazyConvTranspose1d
-nn.LazyConvTranspose2d
-nn.LazyConvTranspose3d
-nn.Unfold
-nn.Fold
-conv1d
-conv2d
-conv3d
-conv_transpose1d
-conv_transpose2d
-conv_transpose3d
-unfold
-fold
-"""
-
-pooling_ops="""
-nn.MaxPool1d
-nn.MaxPool2d
-nn.MaxPool3d
-nn.MaxUnpool1d
-nn.MaxUnpool2d
-nn.MaxUnpool3d
-nn.AvgPool1d
-nn.AvgPool2d
-nn.AvgPool3d
-nn.FractionalMaxPool2d
-nn.FractionalMaxPool3d
-nn.LPPool1d
-nn.LPPool2d
-nn.AdaptiveMaxPool1d
-nn.AdaptiveMaxPool2d
-nn.AdaptiveMaxPool3d
-nn.AdaptiveAvgPool1d
-nn.AdaptiveAvgPool2d
-nn.AdaptiveAvgPool3d
-avg_pool1d
-avg_pool2d
-avg_pool3d
-max_pool1d
-max_pool2d
-max_pool3d
-max_unpool1d
-max_unpool2d
-max_unpool3d
-lp_pool1d
-lp_pool2d
-adaptive_max_pool1d
-adaptive_max_pool2d
-adaptive_max_pool3d
-adaptive_avg_pool1d
-adaptive_avg_pool2d
-adaptive_avg_pool3d
-fractional_max_pool2d
-fractional_max_pool3d
-"""
-
-activation_ops="""
-nn.ELU
-nn.Hardshrink
-nn.Hardsigmoid
-nn.Hardtanh
-nn.Hardswish
-nn.LeakyReLU
-nn.LogSigmoid
-nn.MultiheadAttention
-nn.PReLU
-nn.ReLU
-nn.ReLU6
-nn.RReLU
-nn.SELU
-nn.CELU
-nn.GELU
-nn.Sigmoid
-nn.SiLU
-nn.Mish
-nn.Softplus
-nn.Softshrink
-nn.Softsign
-nn.Tanh
-nn.Tanhshrink
-nn.Threshold
-nn.GLU
-nn.Softmin
-nn.Softmax
-nn.Softmax2d
-nn.LogSoftmax
-nn.AdaptiveLogSoftmaxWithLoss
-threshold
-threshold_
-relu
-relu_
-hardtanh
-hardtanh_
-hardswish
-relu6
-elu
-elu_
-selu
-celu
-leaky_relu
-leaky_relu_
-prelu
-rrelu
-rrelu_
-glu
-gelu
-logsigmoid
-hardshrink
-tanhshrink
-softsign
-softplus
-softmin
-softmax
-softshrink
-gumbel_softmax
-log_softmax
-tanh
-sigmoid
-hardsigmoid
-silu
-mish
-batch_norm
-group_norm
-instance_norm
-layer_norm
-local_response_norm
-normalize
-"""
-
-normalization_ops="""
-nn.BatchNorm1d
-nn.BatchNorm2d
-nn.BatchNorm3d
-nn.LazyBatchNorm1d
-nn.LazyBatchNorm2d
-nn.LazyBatchNorm3d
-nn.GroupNorm
-nn.SyncBatchNorm
-nn.InstanceNorm1d
-nn.InstanceNorm2d
-nn.InstanceNorm3d
-nn.LazyInstanceNorm1d
-nn.LazyInstanceNorm2d
-nn.LazyInstanceNorm3d
-nn.LayerNorm
-nn.LocalResponseNorm
-"""
-
-linear_ops="""
-nn.Identity
-nn.Linear
-nn.Bilinear
-nn.LazyLinear
-linear
-bilinear
-"""
-
-dropout_ops="""
-nn.Dropout
-nn.Dropout2d
-nn.Dropout3d
-nn.AlphaDropout
-nn.FeatureAlphaDropout
-dropout
-alpha_dropout
-feature_alpha_dropout
-dropout2d
-dropout3d
-"""
-
-vision_ops="""
-nn.PixelShuffle
-nn.PixelUnshuffle
-nn.Upsample
-nn.UpsamplingNearest2d
-nn.UpsamplingBilinear2d
-pixel_shuffle
-pixel_unshuffle
-pad
-interpolate
-upsample
-upsample_nearest
-upsample_bilinear
-grid_sample
-affine_grid
-"""
-
-math_ops="""
-abs
-absolute
-acos
-arccos
-acosh
-arccosh
-add
-addcdiv
-addcmul
-angle
-asin
-arcsin
-asinh
-arcsinh
-atan
-arctan
-atanh
-arctanh
-atan2
-bitwise_not
-bitwise_and
-bitwise_or
-bitwise_xor
-bitwise_left_shift
-bitwise_right_shift
-ceil
-clamp
-clip
-conj_physical
-copysign
-cos
-cosh
-deg2rad
-div
-divide
-digamma
-erf
-erfc
-erfinv
-exp
-exp2
-expm1
-fake_quantize_per_channel_affine
-fake_quantize_per_tensor_affine
-fix
-float_power
-floor
-floor_divide
-fmod
-frac
-frexp
-gradient
-imag
-ldexp
-lerp
-lgamma
-log
-log10
-log1p
-log2
-logaddexp
-logaddexp2
-logical_and
-logical_not
-logical_or
-logical_xor
-logit
-hypot
-i0
-igamma
-igammac
-mul
-multiply
-mvlgamma
-nan_to_num
-neg
-negative
-nextafter
-polygamma
-positive
-pow
-quantized_batch_norm
-quantized_max_pool1d
-quantized_max_pool2d
-rad2deg
-real
-reciprocal
-remainder
-round
-rsqrt
-sigmoid
-sign
-sgn
-signbit
-sin
-sinc
-sinh
-sqrt
-square
-sub
-subtract
-tan
-tanh
-true_divide
-trunc
-xlogy
-"""
-
-reduction_ops="""
-argmax
-argmin
-amax
-amin
-aminmax
-all
-any
-max
-min
-dist
-logsumexp
-mean
-nanmean
-median
-nanmedian
-mode
-norm
-nansum
-prod
-quantile
-nanquantile
-std
-std_mean
-sum
-unique
-unique_consecutive
-var
-var_mean
-count_nonzero
-"""
-
-comparison_ops="""
-allclose
-argsort
-eq
-equal
-ge
-greater_equal
-gt
-greater
-isclose
-isfinite
-isin
-isinf
-isposinf
-isneginf
-isnan
-isreal
-kthvalue
-le
-less_equal
-lt
-less
-maximum
-minimum
-fmax
-fmin
-ne
-not_equal
-sort
-topk
-msort
-"""
-
-other_ops="""
-atleast_1d
-atleast_2d
-atleast_3d
-bincount
-block_diag
-broadcast_tensors
-broadcast_to
-broadcast_shapes
-bucketize
-cartesian_prod
-cdist
-clone
-combinations
-corrcoef
-cov
-cross
-cummax
-cummin
-cumprod
-cumsum
-diag
-diag_embed
-diagflat
-diagonal
-diff
-einsum
-flatten
-flip
-fliplr
-flipud
-kron
-rot90
-gcd
-histc
-histogram
-meshgrid
-lcm
-logcumsumexp
-ravel
-renorm
-repeat_interleave
-roll
-searchsorted
-tensordot
-trace
-tril
-tril_indices
-triu
-triu_indices
-vander
-view_as_real
-view_as_complex
-resolve_conj
-resolve_neg
-"""
-
-creation_ops=[op.split('.')[-1] for op in creation_ops.split('\n') if op]
-manipulation_ops=[op.split('.')[-1] for op in manipulation_ops.split('\n') if op]
-convolution_ops=[op.split('.')[-1] for op in convolution_ops.split('\n') if op]
-pooling_ops=[op.split('.')[-1] for op in pooling_ops.split('\n') if op]
-activation_ops=[op.split('.')[-1] for op in activation_ops.split('\n') if op]
-normalization_ops=[op.split('.')[-1] for op in normalization_ops.split('\n') if op]
-linear_ops=[op.split('.')[-1] for op in linear_ops.split('\n') if op]
-dropout_ops=[op.split('.')[-1] for op in dropout_ops.split('\n') if op]
-vision_ops=[op.split('.')[-1] for op in vision_ops.split('\n') if op]
-math_ops=[op.split('.')[-1] for op in math_ops.split('\n') if op]
-reduction_ops=[op.split('.')[-1] for op in reduction_ops.split('\n') if op]
-comparison_ops=[op.split('.')[-1] for op in comparison_ops.split('\n') if op]
-other_ops=[op.split('.')[-1] for op in other_ops.split('\n') if op]
-
-ops_color={}
-ops_color.update({op : '#707070' for op in creation_ops})
-ops_color.update({op : '#803080' for op in manipulation_ops})
-ops_color.update({op : '#3080c0' for op in convolution_ops})
-ops_color.update({op : '#109010' for op in pooling_ops})
-ops_color.update({op : '#b03030' for op in activation_ops})
-ops_color.update({op : '#6080a0' for op in normalization_ops})
-ops_color.update({op : '#30b060' for op in linear_ops})
-ops_color.update({op : '#c09020' for op in dropout_ops})
-ops_color.update({op : '#509090' for op in vision_ops})
-ops_color.update({op : '#d06000' for op in math_ops})
-ops_color.update({op : '#906000' for op in reduction_ops})
-ops_color.update({op : '#90a060' for op in comparison_ops})
-ops_color.update({op : '#b03070' for op in other_ops})
-
-text_color="#f0f0f0"
-default_color="#908070"
-input_color="#606060"
-output_color="#808080"
-
-link_text_color="#d0d0d0"
-link_color="#a0a0a0"
-
-app_socket = tc.connect()
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
-
- if msg_type == 'SetGraph':
- nodes=eval(str(msg_data,'utf-8'))
-
- if msg_type == 'Render':
- batch, legend = tc.decode_ints(msg_data)
- filtered_nodes=copy.deepcopy(nodes)
-
- #merge referenced getitems
- for id, node in nodes.items():
- filtered_nodes[id]['input_shape']={}
- for input in node['inputs']:
- if nodes[input]['op_module']=='operator' and nodes[input]['op']=='getitem':
- for sub_input in nodes[input]['inputs']:
- filtered_nodes[id]['inputs'].remove(input)
- filtered_nodes[id]['inputs'].append(sub_input)
- filtered_nodes[id]['input_shape'][sub_input]=nodes[input]['output_shape']
- del filtered_nodes[input]
- #del non-referenced getitems
- nodes=copy.deepcopy(filtered_nodes)
- for id, node in nodes.items():
- if node['op_module']=='operator' and node['op']=='getitem':
- del filtered_nodes[id]
- graph = graphviz.Digraph(graph_attr={'peripheries':'0', 'dpi': '0.0', 'bgcolor': 'transparent', 'ranksep': '0.25', 'margin': '0'},
- node_attr={'style': 'filled', 'shape': 'Mrecord', 'fillcolor': default_color, 'penwidth':'0', 'fontcolor': text_color,'fontsize':'20', 'fontname':'Source Code Pro'},
- edge_attr={'color': link_color, 'fontcolor': link_text_color,'fontsize':'16', 'fontname':'Source Code Pro'})
- inputs_graph = graphviz.Digraph(name='cluster_input', node_attr={'shape': 'oval', 'fillcolor': input_color, 'margin': '0'})
- outputs_graph = graphviz.Digraph(name='cluster_output', node_attr={'shape': 'oval', 'fillcolor': output_color, 'margin': '0'})
-
- for id, node in filtered_nodes.items():
- if node['type']=='input':
- inputs_graph.node(id, '<'+node['name']+'>', tooltip=node['name'])
- elif node['type']=='output':
- outputs_graph.node(id, '<'+node['name']+'>')
- else:
- if node['op'] in ops_color:
- node_color=ops_color[node['op']]
- else:
- node_color=default_color
- label=node['op']
- label_start, label_end=('<','>') if node['type']=='function' else ('','')
- if node['op']=='':
- node_tooltip=node['name']
- else:
- node_tooltip=(node['name']+' = ' if node['type']=='module' else '')+node['op_module']+"."+node['op']+'('+node['params']+')'
- graph.node(id, label_start+label+label_end, {'fillcolor': node_color, 'tooltip': node_tooltip})
-
- graph.subgraph(inputs_graph)
- graph.subgraph(outputs_graph)
-
- for id, node in filtered_nodes.items():
- for input in node['inputs']:
- output_shape=filtered_nodes[input]['output_shape']
- if input in node['input_shape']:
- output_shape=node['input_shape'][input]
- if batch==1:
- output_shape=('N,' if output_shape else 'N')+output_shape
- graph.edge(input,id," "+output_shape.replace(',','\u00d7')) #replace comma by multiplication sign
-
- if legend==1:
- with graph.subgraph(name='cluster_legend', node_attr={'shape': 'box', 'margin':'0', 'fontsize':'16', 'style':''}) as legend:
- table= '<table border="0" cellspacing="5"><tr><td bgcolor="'+input_color+'">Input</td>'
- table+='<td bgcolor="#707070">Creation</td></tr>'
- table+='<tr><td bgcolor="#803080">Manipulation</td>'
- table+='<td bgcolor="#3080c0">Convolution</td></tr>'
- table+='<tr><td bgcolor="#109010">Pooling</td>'
- table+='<td bgcolor="#b03030">Activation</td></tr>'
- table+='<tr><td bgcolor="#6080a0">Normalization</td>'
- table+='<td bgcolor="#30b060">Linear</td></tr>'
- table+='<tr><td bgcolor="#c09020">Dropout</td>'
- table+='<td bgcolor="#509090">Vision</td></tr>'
- table+='<tr><td bgcolor="#d06000">Math</td>'
- table+='<td bgcolor="#906000">Reduction</td></tr>'
- table+='<tr><td bgcolor="#90a060">Comparison</td>'
- table+='<td bgcolor="#b03070">Other</td></tr>'
- table+='<tr><td bgcolor="'+default_color+'">Unknown</td>'
- table+='<td bgcolor="'+output_color+'">Output</td></tr></table>'
- legend.node('legend', '<'+table+'>')
-
- svg=graph.pipe(format='svg')
- tc.send_msg(app_socket, 'SVGData', svg)
-
-# with open('/Users/divide/Documents/output.txt','w') as file:
-# print(graph.source, file=file)
-# with open('/Users/divide/Documents/output.svg','w') as file:
-# print(str(svg, 'utf-8'), file=file)
-# with open('/Users/divide/Documents/output.png','wb') as file:
-# file.write(graph.pipe(format='png'))
-
- if msg_type == 'Exit':
- break
-
+import torchstudio.tcpcodec as tc
+import os
+import graphviz
+import copy
+
+#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/torch.rst
+#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/nn.rst
+#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/nn.functional.rst
+
+creation_ops="""
+tensor
+sparse_coo_tensor
+as_tensor
+as_strided
+from_numpy
+frombuffer
+zeros
+zeros_like
+ones
+ones_like
+arange
+range
+linspace
+logspace
+eye
+empty
+empty_like
+empty_strided
+full
+full_like
+quantize_per_tensor
+quantize_per_channel
+dequantize
+complex
+polar
+heaviside
+"""
+
+manipulation_ops="""
+cat
+concat
+conj
+chunk
+dsplit
+column_stack
+dstack
+gather
+hsplit
+hstack
+index_select
+masked_select
+movedim
+moveaxis
+narrow
+nonzero
+permute
+reshape
+row_stack
+scatter
+scatter_add
+split
+squeeze
+stack
+swapaxes
+swapdims
+t
+take
+take_along_dim
+tensor_split
+tile
+transpose
+unbind
+unsqueeze
+vsplit
+vstack
+where
+"""
+
+convolution_ops="""
+nn.Conv1d
+nn.Conv2d
+nn.Conv3d
+nn.ConvTranspose1d
+nn.ConvTranspose2d
+nn.ConvTranspose3d
+nn.LazyConv1d
+nn.LazyConv2d
+nn.LazyConv3d
+nn.LazyConvTranspose1d
+nn.LazyConvTranspose2d
+nn.LazyConvTranspose3d
+nn.Unfold
+nn.Fold
+conv1d
+conv2d
+conv3d
+conv_transpose1d
+conv_transpose2d
+conv_transpose3d
+unfold
+fold
+"""
+
+pooling_ops="""
+nn.MaxPool1d
+nn.MaxPool2d
+nn.MaxPool3d
+nn.MaxUnpool1d
+nn.MaxUnpool2d
+nn.MaxUnpool3d
+nn.AvgPool1d
+nn.AvgPool2d
+nn.AvgPool3d
+nn.FractionalMaxPool2d
+nn.FractionalMaxPool3d
+nn.LPPool1d
+nn.LPPool2d
+nn.AdaptiveMaxPool1d
+nn.AdaptiveMaxPool2d
+nn.AdaptiveMaxPool3d
+nn.AdaptiveAvgPool1d
+nn.AdaptiveAvgPool2d
+nn.AdaptiveAvgPool3d
+avg_pool1d
+avg_pool2d
+avg_pool3d
+max_pool1d
+max_pool2d
+max_pool3d
+max_unpool1d
+max_unpool2d
+max_unpool3d
+lp_pool1d
+lp_pool2d
+adaptive_max_pool1d
+adaptive_max_pool2d
+adaptive_max_pool3d
+adaptive_avg_pool1d
+adaptive_avg_pool2d
+adaptive_avg_pool3d
+fractional_max_pool2d
+fractional_max_pool3d
+"""
+
+activation_ops="""
+nn.ELU
+nn.Hardshrink
+nn.Hardsigmoid
+nn.Hardtanh
+nn.Hardswish
+nn.LeakyReLU
+nn.LogSigmoid
+nn.MultiheadAttention
+nn.PReLU
+nn.ReLU
+nn.ReLU6
+nn.RReLU
+nn.SELU
+nn.CELU
+nn.GELU
+nn.Sigmoid
+nn.SiLU
+nn.Mish
+nn.Softplus
+nn.Softshrink
+nn.Softsign
+nn.Tanh
+nn.Tanhshrink
+nn.Threshold
+nn.GLU
+nn.Softmin
+nn.Softmax
+nn.Softmax2d
+nn.LogSoftmax
+nn.AdaptiveLogSoftmaxWithLoss
+threshold
+threshold_
+relu
+relu_
+hardtanh
+hardtanh_
+hardswish
+relu6
+elu
+elu_
+selu
+celu
+leaky_relu
+leaky_relu_
+prelu
+rrelu
+rrelu_
+glu
+gelu
+logsigmoid
+hardshrink
+tanhshrink
+softsign
+softplus
+softmin
+softmax
+softshrink
+gumbel_softmax
+log_softmax
+tanh
+sigmoid
+hardsigmoid
+silu
+mish
+batch_norm
+group_norm
+instance_norm
+layer_norm
+local_response_norm
+normalize
+"""
+
+normalization_ops="""
+nn.BatchNorm1d
+nn.BatchNorm2d
+nn.BatchNorm3d
+nn.LazyBatchNorm1d
+nn.LazyBatchNorm2d
+nn.LazyBatchNorm3d
+nn.GroupNorm
+nn.SyncBatchNorm
+nn.InstanceNorm1d
+nn.InstanceNorm2d
+nn.InstanceNorm3d
+nn.LazyInstanceNorm1d
+nn.LazyInstanceNorm2d
+nn.LazyInstanceNorm3d
+nn.LayerNorm
+nn.LocalResponseNorm
+"""
+
+linear_ops="""
+nn.Identity
+nn.Linear
+nn.Bilinear
+nn.LazyLinear
+linear
+bilinear
+"""
+
+dropout_ops="""
+nn.Dropout
+nn.Dropout2d
+nn.Dropout3d
+nn.AlphaDropout
+nn.FeatureAlphaDropout
+dropout
+alpha_dropout
+feature_alpha_dropout
+dropout2d
+dropout3d
+"""
+
+vision_ops="""
+nn.PixelShuffle
+nn.PixelUnshuffle
+nn.Upsample
+nn.UpsamplingNearest2d
+nn.UpsamplingBilinear2d
+pixel_shuffle
+pixel_unshuffle
+pad
+interpolate
+upsample
+upsample_nearest
+upsample_bilinear
+grid_sample
+affine_grid
+"""
+
+math_ops="""
+abs
+absolute
+acos
+arccos
+acosh
+arccosh
+add
+addcdiv
+addcmul
+angle
+asin
+arcsin
+asinh
+arcsinh
+atan
+arctan
+atanh
+arctanh
+atan2
+bitwise_not
+bitwise_and
+bitwise_or
+bitwise_xor
+bitwise_left_shift
+bitwise_right_shift
+ceil
+clamp
+clip
+conj_physical
+copysign
+cos
+cosh
+deg2rad
+div
+divide
+digamma
+erf
+erfc
+erfinv
+exp
+exp2
+expm1
+fake_quantize_per_channel_affine
+fake_quantize_per_tensor_affine
+fix
+float_power
+floor
+floor_divide
+fmod
+frac
+frexp
+gradient
+imag
+ldexp
+lerp
+lgamma
+log
+log10
+log1p
+log2
+logaddexp
+logaddexp2
+logical_and
+logical_not
+logical_or
+logical_xor
+logit
+hypot
+i0
+igamma
+igammac
+mul
+multiply
+mvlgamma
+nan_to_num
+neg
+negative
+nextafter
+polygamma
+positive
+pow
+quantized_batch_norm
+quantized_max_pool1d
+quantized_max_pool2d
+rad2deg
+real
+reciprocal
+remainder
+round
+rsqrt
+sigmoid
+sign
+sgn
+signbit
+sin
+sinc
+sinh
+sqrt
+square
+sub
+subtract
+tan
+tanh
+true_divide
+trunc
+xlogy
+"""
+
+reduction_ops="""
+argmax
+argmin
+amax
+amin
+aminmax
+all
+any
+max
+min
+dist
+logsumexp
+mean
+nanmean
+median
+nanmedian
+mode
+norm
+nansum
+prod
+quantile
+nanquantile
+std
+std_mean
+sum
+unique
+unique_consecutive
+var
+var_mean
+count_nonzero
+"""
+
+comparison_ops="""
+allclose
+argsort
+eq
+equal
+ge
+greater_equal
+gt
+greater
+isclose
+isfinite
+isin
+isinf
+isposinf
+isneginf
+isnan
+isreal
+kthvalue
+le
+less_equal
+lt
+less
+maximum
+minimum
+fmax
+fmin
+ne
+not_equal
+sort
+topk
+msort
+"""
+
+other_ops="""
+atleast_1d
+atleast_2d
+atleast_3d
+bincount
+block_diag
+broadcast_tensors
+broadcast_to
+broadcast_shapes
+bucketize
+cartesian_prod
+cdist
+clone
+combinations
+corrcoef
+cov
+cross
+cummax
+cummin
+cumprod
+cumsum
+diag
+diag_embed
+diagflat
+diagonal
+diff
+einsum
+flatten
+flip
+fliplr
+flipud
+kron
+rot90
+gcd
+histc
+histogram
+meshgrid
+lcm
+logcumsumexp
+ravel
+renorm
+repeat_interleave
+roll
+searchsorted
+tensordot
+trace
+tril
+tril_indices
+triu
+triu_indices
+vander
+view_as_real
+view_as_complex
+resolve_conj
+resolve_neg
+"""
+
+creation_ops=[op.split('.')[-1] for op in creation_ops.split('\n') if op]
+manipulation_ops=[op.split('.')[-1] for op in manipulation_ops.split('\n') if op]
+convolution_ops=[op.split('.')[-1] for op in convolution_ops.split('\n') if op]
+pooling_ops=[op.split('.')[-1] for op in pooling_ops.split('\n') if op]
+activation_ops=[op.split('.')[-1] for op in activation_ops.split('\n') if op]
+normalization_ops=[op.split('.')[-1] for op in normalization_ops.split('\n') if op]
+linear_ops=[op.split('.')[-1] for op in linear_ops.split('\n') if op]
+dropout_ops=[op.split('.')[-1] for op in dropout_ops.split('\n') if op]
+vision_ops=[op.split('.')[-1] for op in vision_ops.split('\n') if op]
+math_ops=[op.split('.')[-1] for op in math_ops.split('\n') if op]
+reduction_ops=[op.split('.')[-1] for op in reduction_ops.split('\n') if op]
+comparison_ops=[op.split('.')[-1] for op in comparison_ops.split('\n') if op]
+other_ops=[op.split('.')[-1] for op in other_ops.split('\n') if op]
+
+ops_color={}
+ops_color.update({op : '#707070' for op in creation_ops})
+ops_color.update({op : '#803080' for op in manipulation_ops})
+ops_color.update({op : '#3080c0' for op in convolution_ops})
+ops_color.update({op : '#109010' for op in pooling_ops})
+ops_color.update({op : '#b03030' for op in activation_ops})
+ops_color.update({op : '#6080a0' for op in normalization_ops})
+ops_color.update({op : '#30b060' for op in linear_ops})
+ops_color.update({op : '#c09020' for op in dropout_ops})
+ops_color.update({op : '#509090' for op in vision_ops})
+ops_color.update({op : '#d06000' for op in math_ops})
+ops_color.update({op : '#906000' for op in reduction_ops})
+ops_color.update({op : '#90a060' for op in comparison_ops})
+ops_color.update({op : '#b03070' for op in other_ops})
+
+text_color="#f0f0f0"
+default_color="#908070"
+input_color="#606060"
+output_color="#808080"
+
+link_text_color="#d0d0d0"
+link_color="#a0a0a0"
+
+app_socket = tc.connect()
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+
+ if msg_type == 'SetGraph':
+ nodes=eval(str(msg_data,'utf-8'))
+
+ if msg_type == 'Render':
+ batch, legend = tc.decode_ints(msg_data)
+ filtered_nodes=copy.deepcopy(nodes)
+
+ #merge referenced getitems
+ for id, node in nodes.items():
+ filtered_nodes[id]['input_shape']={}
+ for input in node['inputs']:
+ if nodes[input]['op_module']=='operator' and nodes[input]['op']=='getitem':
+ for sub_input in nodes[input]['inputs']:
+ filtered_nodes[id]['inputs'].remove(input)
+ filtered_nodes[id]['inputs'].append(sub_input)
+ filtered_nodes[id]['input_shape'][sub_input]=nodes[input]['output_shape']
+ del filtered_nodes[input]
+ #del non-referenced getitems
+ nodes=copy.deepcopy(filtered_nodes)
+ for id, node in nodes.items():
+ if node['op_module']=='operator' and node['op']=='getitem':
+ del filtered_nodes[id]
+ graph = graphviz.Digraph(graph_attr={'peripheries':'0', 'dpi': '0.0', 'bgcolor': 'transparent', 'ranksep': '0.25', 'margin': '0'},
+ node_attr={'style': 'filled', 'shape': 'Mrecord', 'fillcolor': default_color, 'penwidth':'0', 'fontcolor': text_color,'fontsize':'20', 'fontname':'Source Code Pro'},
+ edge_attr={'color': link_color, 'fontcolor': link_text_color,'fontsize':'16', 'fontname':'Source Code Pro'})
+ inputs_graph = graphviz.Digraph(name='cluster_input', node_attr={'shape': 'oval', 'fillcolor': input_color, 'margin': '0'})
+ outputs_graph = graphviz.Digraph(name='cluster_output', node_attr={'shape': 'oval', 'fillcolor': output_color, 'margin': '0'})
+
+ for id, node in filtered_nodes.items():
+ if node['type']=='input':
+ inputs_graph.node(id, '<'+node['name']+'>', tooltip=node['name'])
+ elif node['type']=='output':
+ outputs_graph.node(id, '<'+node['name']+'>')
+ else:
+ if node['op'] in ops_color:
+ node_color=ops_color[node['op']]
+ else:
+ node_color=default_color
+ label=node['op']
+ label_start, label_end=('<','>') if node['type']=='function' else ('','')
+ if node['op']=='':
+ node_tooltip=node['name']
+ else:
+ node_tooltip=(node['name']+' = ' if node['type']=='module' else '')+node['op_module']+"."+node['op']+'('+node['params']+')'
+ graph.node(id, label_start+label+label_end, {'fillcolor': node_color, 'tooltip': node_tooltip})
+
+ graph.subgraph(inputs_graph)
+ graph.subgraph(outputs_graph)
+
+ for id, node in filtered_nodes.items():
+ for input in node['inputs']:
+ output_shape=filtered_nodes[input]['output_shape']
+ if input in node['input_shape']:
+ output_shape=node['input_shape'][input]
+ if batch==1:
+ output_shape=('N,' if output_shape else 'N')+output_shape
+ graph.edge(input,id," "+output_shape.replace(',','\u00d7')) #replace comma by multiplication sign
+
+ if legend==1:
+ with graph.subgraph(name='cluster_legend', node_attr={'shape': 'box', 'margin':'0', 'fontsize':'16', 'style':''}) as legend:
+ table= '<table border="0" cellspacing="5"><tr><td bgcolor="'+input_color+'">Input</td>'
+ table+='<td bgcolor="#707070">Creation</td></tr>'
+ table+='<tr><td bgcolor="#803080">Manipulation</td>'
+ table+='<td bgcolor="#3080c0">Convolution</td></tr>'
+ table+='<tr><td bgcolor="#109010">Pooling</td>'
+ table+='<td bgcolor="#b03030">Activation</td></tr>'
+ table+='<tr><td bgcolor="#6080a0">Normalization</td>'
+ table+='<td bgcolor="#30b060">Linear</td></tr>'
+ table+='<tr><td bgcolor="#c09020">Dropout</td>'
+ table+='<td bgcolor="#509090">Vision</td></tr>'
+ table+='<tr><td bgcolor="#d06000">Math</td>'
+ table+='<td bgcolor="#906000">Reduction</td></tr>'
+ table+='<tr><td bgcolor="#90a060">Comparison</td>'
+ table+='<td bgcolor="#b03070">Other</td></tr>'
+ table+='<tr><td bgcolor="'+default_color+'">Unknown</td>'
+ table+='<td bgcolor="'+output_color+'">Output</td></tr></table>'
+ legend.node('legend', '<'+table+'>')
+
+ svg=graph.pipe(format='svg')
+ tc.send_msg(app_socket, 'SVGData', svg)
+
+# with open('/Users/divide/Documents/output.txt','w') as file:
+# print(graph.source, file=file)
+# with open('/Users/divide/Documents/output.svg','w') as file:
+# print(str(svg, 'utf-8'), file=file)
+# with open('/Users/divide/Documents/output.png','wb') as file:
+# file.write(graph.pipe(format='png'))
+
+ if msg_type == 'Exit':
+ break
+
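SetGraph evaluates a Python dict literal mapping node ids to their properties; the renderer above reads the fields type, name, op, op_module, params, inputs and output_shape. A hypothetical three-node payload illustrating that shape, inferred from those reads (shapes are comma-separated strings, and an 'N' batch dimension is prepended to edge labels when batch==1):

nodes = {
    'x':    {'type': 'input',  'name': 'x',    'op': '', 'op_module': '', 'params': '',
             'inputs': [], 'output_shape': '3,64,64'},
    'conv': {'type': 'module', 'name': 'conv', 'op': 'Conv2d', 'op_module': 'nn',
             'params': '3, 16, 3', 'inputs': ['x'], 'output_shape': '16,62,62'},
    'out':  {'type': 'output', 'name': 'out',  'op': '', 'op_module': '', 'params': '',
             'inputs': ['conv'], 'output_shape': '16,62,62'},
}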
diff --git a/torchstudio/metricsplot.py b/torchstudio/metricsplot.py
index 0b6c665..6c29630 100644
--- a/torchstudio/metricsplot.py
+++ b/torchstudio/metricsplot.py
@@ -1,185 +1,185 @@
-import torchstudio.tcpcodec as tc
-import inspect
-import sys
-import os
-
-import matplotlib as mpl
-import matplotlib.pyplot as plt
-from matplotlib.ticker import MaxNLocator
-import PIL
-
-def plot_metrics(prefix, size, dpi, samples=100, labels=[],
- loss=[], loss_colors=[], loss_shift=(0,0), loss_scale=(1,1),
- metric=[], metric_colors=[], metric_shift=(0,0), metric_scale=(1,1)):
- """Metrics Plot
-
- Usage:
- Drag: pan
- Scroll: zoom vertically
- """
- #set up matplotlib renderer, style, figure and axis
- mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/
- plt.style.use('dark_background')
- plt.rcParams.update({'font.size': 7})
-
- fig, [ax1, ax2] = plt.subplots(1 if size[0]>size[1] else 2, 2 if size[0]>size[1] else 1, figsize=(size[0]/dpi, size[1]/dpi), dpi=dpi)
-
- #LOSS
- ax1.set_title(prefix+"Loss")
-
- #fit
- loss_xmin=0
- loss_xmax=samples
- loss_ymin=0
- loss_ymax=1
- for l in loss:
- loss_xmax=max(loss_xmax,len(l))
-# if(len(l)>0):
-# loss_ymax=max(loss_ymax,max(l))
-
-# #shift
-# render_size=(loss_xmax-loss_xmin,loss_ymax-loss_ymin)
-# loss_xmin-=loss_shift[0]/loss_scale[0]*render_size[0]
-# loss_xmax-=loss_shift[0]/loss_scale[0]*render_size[0]
-# loss_ymin-=loss_shift[1]/loss_scale[1]*render_size[1]
-# loss_ymax-=loss_shift[1]/loss_scale[1]*render_size[1]
-
-# #scale
-# render_center=(loss_xmin+render_size[0]/2,loss_ymin+render_size[1]/2)
-# loss_xmin=render_center[0]-(render_size[0]/loss_scale[0]/2)
-# loss_xmax=render_center[0]+(render_size[0]/loss_scale[0]/2)
-# loss_ymin=render_center[1]-(render_size[1]/loss_scale[1]/2)
-# loss_ymax=render_center[1]+(render_size[1]/loss_scale[1]/2)
-
-# loss_xmin=max(0,loss_xmin)
-# loss_ymin=max(0,loss_ymin)
-
- loss_ymin-=loss_shift[1]/loss_scale[1]
- loss_ymax-=loss_shift[1]/loss_scale[1]
- loss_ymax=loss_ymax/loss_scale[1]
-
- ax1.axis(xmin=loss_xmin,xmax=loss_xmax,ymin=loss_ymin,ymax=loss_ymax)
- ax1.spines['top'].set_visible(False)
- ax1.spines['right'].set_visible(False)
- ax1.spines['left'].set_color('#707070')
- ax1.spines['bottom'].set_color('#707070')
- for i in range(len(loss)):
- ax1.plot(loss[i],label=str(i) if i>=len(labels) else labels[i],color=loss_colors[i%len(loss_colors)])
- if labels and loss and loss[0]:
- ax1.legend(bbox_to_anchor=(1, 1), loc='upper right', ncol=1, prop={'size': 8})
- ax1.grid(color = '#303030')
- ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
- ax1.ticklabel_format(axis="y", style="sci", scilimits=(0,0))
-
- #METRIC
- ax2.set_title(prefix+"Metric")
-
- #fit
- metric_xmin=0
- metric_xmax=samples
- metric_ymin=0
- metric_ymax=1
- for m in metric:
- metric_xmax=max(metric_xmax,len(m))
-
-# #shift
-# render_size=(metric_xmax-metric_xmin,metric_ymax-metric_ymin)
-# metric_xmin-=metric_shift[0]/metric_scale[0]*render_size[0]
-# metric_xmax-=metric_shift[0]/metric_scale[0]*render_size[0]
-# metric_ymin-=metric_shift[1]/metric_scale[1]*render_size[1]
-# metric_ymax-=metric_shift[1]/metric_scale[1]*render_size[1]
-
-# #scale
-# render_center=(metric_xmin+render_size[0]/2,metric_ymin+render_size[1]/2)
-# metric_xmin=render_center[0]-(render_size[0]/metric_scale[0]/2)
-# metric_xmax=render_center[0]+(render_size[0]/metric_scale[0]/2)
-# metric_ymin=render_center[1]-(render_size[1]/metric_scale[1]/2)
-# metric_ymax=render_center[1]+(render_size[1]/metric_scale[1]/2)
-
-# metric_xmin=max(0,metric_xmin)
-
- metric_ymin-=metric_shift[1]/metric_scale[1]
- metric_ymax-=metric_shift[1]/metric_scale[1]
- metric_ymin=(metric_ymin-metric_ymax)/metric_scale[1]+metric_ymax
-
- ax2.axis(xmin=metric_xmin,xmax=metric_xmax,ymin=metric_ymin,ymax=metric_ymax)
- ax2.spines['top'].set_visible(False)
- ax2.spines['right'].set_visible(False)
- ax2.spines['left'].set_color('#707070')
- ax2.spines['bottom'].set_color('#707070')
- for i in range(len(metric)):
- ax2.plot(metric[i],color=metric_colors[i%len(metric_colors)])
- ax2.grid(color = '#303030')
- ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
-
- plt.tight_layout(pad=0)
-
- canvas = plt.get_current_fig_manager().canvas
- canvas.draw()
- img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
- plt.close()
- return img
-
-
-prefix = ''
-resolution = (256,256, 96)
-samples=100
-labels = []
-
-loss=[]
-loss_colors=[]
-loss_shift = (0,0)
-loss_scale = (1,1)
-
-metric=[]
-metric_colors=[]
-metric_labels = []
-metric_shift = (0,0)
-metric_scale = (1,1)
-
-app_socket = tc.connect()
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
-
- if msg_type == 'RequestDocumentation':
- tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_metrics.__doc__)))
- if msg_type == 'SetPrefix':
- prefix=tc.decode_strings(msg_data)[0]
-
- if msg_type == 'SetResolution':
- resolution = tc.decode_ints(msg_data)
-
- if msg_type == 'NumSamples':
- samples = tc.decode_ints(msg_data)[0]
- if msg_type == 'SetLabels':
- labels=tc.decode_strings(msg_data)
-
- if msg_type == 'ClearLoss':
- loss=[]
- if msg_type == 'AppendLoss':
- loss.append(tc.decode_floats(msg_data))
- if msg_type == 'SetLossColors':
- loss_colors=tc.decode_strings(msg_data)
- if msg_type == 'SetLossShift':
- loss_shift = tc.decode_floats(msg_data)
- if msg_type == 'SetLossScale':
- loss_scale = tc.decode_floats(msg_data)
-
- if msg_type == 'ClearMetric':
- metric=[]
- if msg_type == 'AppendMetric':
- metric.append(tc.decode_floats(msg_data))
- if msg_type == 'SetMetricColors':
- metric_colors=tc.decode_strings(msg_data)
- if msg_type == 'SetMetricShift':
- metric_shift = tc.decode_floats(msg_data)
- if msg_type == 'SetMetricScale':
- metric_scale = tc.decode_floats(msg_data)
-
- if msg_type == 'Render':
- if resolution[0]>0 and resolution[1]>0:
- img=plot_metrics(prefix,resolution[0:2],resolution[2],samples,labels,loss,loss_colors,loss_shift,loss_scale,metric,metric_colors,metric_shift,metric_scale)
- tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
-
- if msg_type == 'Exit':
- break
+import torchstudio.tcpcodec as tc
+import inspect
+import sys
+import os
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+from matplotlib.ticker import MaxNLocator
+import PIL
+
+def plot_metrics(prefix, size, dpi, samples=100, labels=[],
+ loss=[], loss_colors=[], loss_shift=(0,0), loss_scale=(1,1),
+ metric=[], metric_colors=[], metric_shift=(0,0), metric_scale=(1,1)):
+ """Metrics Plot
+
+ Usage:
+ Drag: pan
+ Scroll: zoom vertically
+ """
+ #set up matplotlib renderer, style, figure and axis
+ mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/
+ plt.style.use('dark_background')
+ plt.rcParams.update({'font.size': 7})
+
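+    #side-by-side plots for landscape canvases, stacked plots for portrait ones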
+ fig, [ax1, ax2] = plt.subplots(1 if size[0]>size[1] else 2, 2 if size[0]>size[1] else 1, figsize=(size[0]/dpi, size[1]/dpi), dpi=dpi)
+
+ #LOSS
+ ax1.set_title(prefix+"Loss")
+
+ #fit
+ loss_xmin=0
+ loss_xmax=samples
+ loss_ymin=0
+ loss_ymax=1
+ for l in loss:
+ loss_xmax=max(loss_xmax,len(l))
+# if(len(l)>0):
+# loss_ymax=max(loss_ymax,max(l))
+
+# #shift
+# render_size=(loss_xmax-loss_xmin,loss_ymax-loss_ymin)
+# loss_xmin-=loss_shift[0]/loss_scale[0]*render_size[0]
+# loss_xmax-=loss_shift[0]/loss_scale[0]*render_size[0]
+# loss_ymin-=loss_shift[1]/loss_scale[1]*render_size[1]
+# loss_ymax-=loss_shift[1]/loss_scale[1]*render_size[1]
+
+# #scale
+# render_center=(loss_xmin+render_size[0]/2,loss_ymin+render_size[1]/2)
+# loss_xmin=render_center[0]-(render_size[0]/loss_scale[0]/2)
+# loss_xmax=render_center[0]+(render_size[0]/loss_scale[0]/2)
+# loss_ymin=render_center[1]-(render_size[1]/loss_scale[1]/2)
+# loss_ymax=render_center[1]+(render_size[1]/loss_scale[1]/2)
+
+# loss_xmin=max(0,loss_xmin)
+# loss_ymin=max(0,loss_ymin)
+
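+    #apply the vertical pan, then zoom the loss range: the top is rescaled while the bottom stays anchored near 0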
+ loss_ymin-=loss_shift[1]/loss_scale[1]
+ loss_ymax-=loss_shift[1]/loss_scale[1]
+ loss_ymax=loss_ymax/loss_scale[1]
+
+ ax1.axis(xmin=loss_xmin,xmax=loss_xmax,ymin=loss_ymin,ymax=loss_ymax)
+ ax1.spines['top'].set_visible(False)
+ ax1.spines['right'].set_visible(False)
+ ax1.spines['left'].set_color('#707070')
+ ax1.spines['bottom'].set_color('#707070')
+ for i in range(len(loss)):
+ ax1.plot(loss[i],label=str(i) if i>=len(labels) else labels[i],color=loss_colors[i%len(loss_colors)])
+ if labels and loss and loss[0]:
+ ax1.legend(bbox_to_anchor=(1, 1), loc='upper right', ncol=1, prop={'size': 8})
+ ax1.grid(color = '#303030')
+ ax1.xaxis.set_major_locator(MaxNLocator(nbins='auto', integer=True))
+ ax1.ticklabel_format(axis="y", style="sci", scilimits=(0,0))
+
+ #METRIC
+ ax2.set_title(prefix+"Metric")
+
+ #fit
+ metric_xmin=0
+ metric_xmax=samples
+ metric_ymin=0
+ metric_ymax=1
+ for m in metric:
+ metric_xmax=max(metric_xmax,len(m))
+
+# #shift
+# render_size=(metric_xmax-metric_xmin,metric_ymax-metric_ymin)
+# metric_xmin-=metric_shift[0]/metric_scale[0]*render_size[0]
+# metric_xmax-=metric_shift[0]/metric_scale[0]*render_size[0]
+# metric_ymin-=metric_shift[1]/metric_scale[1]*render_size[1]
+# metric_ymax-=metric_shift[1]/metric_scale[1]*render_size[1]
+
+# #scale
+# render_center=(metric_xmin+render_size[0]/2,metric_ymin+render_size[1]/2)
+# metric_xmin=render_center[0]-(render_size[0]/metric_scale[0]/2)
+# metric_xmax=render_center[0]+(render_size[0]/metric_scale[0]/2)
+# metric_ymin=render_center[1]-(render_size[1]/metric_scale[1]/2)
+# metric_ymax=render_center[1]+(render_size[1]/metric_scale[1]/2)
+
+# metric_xmin=max(0,metric_xmin)
+
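+    #apply the vertical pan, then zoom the metric range about its top value (metrics typically max out at 1)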
+ metric_ymin-=metric_shift[1]/metric_scale[1]
+ metric_ymax-=metric_shift[1]/metric_scale[1]
+ metric_ymin=(metric_ymin-metric_ymax)/metric_scale[1]+metric_ymax
+
+ ax2.axis(xmin=metric_xmin,xmax=metric_xmax,ymin=metric_ymin,ymax=metric_ymax)
+ ax2.spines['top'].set_visible(False)
+ ax2.spines['right'].set_visible(False)
+ ax2.spines['left'].set_color('#707070')
+ ax2.spines['bottom'].set_color('#707070')
+ for i in range(len(metric)):
+ ax2.plot(metric[i],color=metric_colors[i%len(metric_colors)])
+ ax2.grid(color = '#303030')
+ ax2.xaxis.set_major_locator(MaxNLocator(nbins='auto', integer=True))
+
+ plt.tight_layout(pad=0)
+
+ canvas = plt.get_current_fig_manager().canvas
+ canvas.draw()
+ img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+ plt.close()
+ return img
+
+
+prefix = ''
+resolution = (256,256, 96)
+samples=100
+labels = []
+
+loss=[]
+loss_colors=[]
+loss_shift = (0,0)
+loss_scale = (1,1)
+
+metric=[]
+metric_colors=[]
+metric_labels = []
+metric_shift = (0,0)
+metric_scale = (1,1)
+
+app_socket = tc.connect()
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+
+ if msg_type == 'RequestDocumentation':
+ tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_metrics.__doc__)))
+ if msg_type == 'SetPrefix':
+ prefix=tc.decode_strings(msg_data)[0]
+
+ if msg_type == 'SetResolution':
+ resolution = tc.decode_ints(msg_data)
+
+ if msg_type == 'NumSamples':
+ samples = tc.decode_ints(msg_data)[0]
+ if msg_type == 'SetLabels':
+ labels=tc.decode_strings(msg_data)
+
+ if msg_type == 'ClearLoss':
+ loss=[]
+ if msg_type == 'AppendLoss':
+ loss.append(tc.decode_floats(msg_data))
+ if msg_type == 'SetLossColors':
+ loss_colors=tc.decode_strings(msg_data)
+ if msg_type == 'SetLossShift':
+ loss_shift = tc.decode_floats(msg_data)
+ if msg_type == 'SetLossScale':
+ loss_scale = tc.decode_floats(msg_data)
+
+ if msg_type == 'ClearMetric':
+ metric=[]
+ if msg_type == 'AppendMetric':
+ metric.append(tc.decode_floats(msg_data))
+ if msg_type == 'SetMetricColors':
+ metric_colors=tc.decode_strings(msg_data)
+ if msg_type == 'SetMetricShift':
+ metric_shift = tc.decode_floats(msg_data)
+ if msg_type == 'SetMetricScale':
+ metric_scale = tc.decode_floats(msg_data)
+
+ if msg_type == 'Render':
+ if resolution[0]>0 and resolution[1]>0:
+ img=plot_metrics(prefix,resolution[0:2],resolution[2],samples,labels,loss,loss_colors,loss_shift,loss_scale,metric,metric_colors,metric_shift,metric_scale)
+ tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
+
+ if msg_type == 'Exit':
+ break
diff --git a/torchstudio/modelbuild.py b/torchstudio/modelbuild.py
index ad9a079..5f04c3e 100644
--- a/torchstudio/modelbuild.py
+++ b/torchstudio/modelbuild.py
@@ -1,246 +1,246 @@
-#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490
-import os
-os.environ['KMP_DUPLICATE_LIB_OK']='True'
-
-import sys
-print("Loading PyTorch...\n", file=sys.stderr)
-
-import torch
-import torch.fx
-from torch.fx.passes.shape_prop import ShapeProp
-from torch.fx.graph_module import GraphModule
-import torchstudio.tcpcodec as tc
-from torchstudio.modules import safe_exec
-import sys
-import os
-import io
-import re
-import graphviz
-import linecache
-import inspect
-
-#monkey patch ssl to fix ssl certificate fail when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error
-import ssl
-ssl._create_default_https_context = ssl._create_unverified_context
-
-original_path=sys.path
-original_dir=os.getcwd()
-
-level=0
-max_depth=0
-
-app_socket = tc.connect()
-print("Build script connected\n", file=sys.stderr)
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
-
- if msg_type == 'SetCurrentDir':
- new_dir=tc.decode_strings(msg_data)[0]
- sys.path=original_path
- os.chdir(original_dir)
- if new_dir:
- sys.path.append(new_dir)
- os.chdir(new_dir)
-
- if msg_type == 'SetDataDir':
- data_dir=tc.decode_strings(msg_data)[0]
- torch.hub.set_dir(data_dir)
-
- if msg_type == 'SetModelCode':
- model_code=tc.decode_strings(msg_data)[0]
-
- #create a module space for the model definition
- #see https://stackoverflow.com/questions/5122465/can-i-fake-a-package-or-at-least-a-module-in-python-for-testing-purposes/27476659#27476659
- from types import ModuleType
- modelmodule = ModuleType("modelmodule")
- modelmodule.__file__ = modelmodule.__name__ + ".py"
- sys.modules[modelmodule.__name__] = modelmodule
-
- error_msg, model_env = safe_exec(model_code, context=vars(modelmodule), output=vars(modelmodule), description='model definition')
- if error_msg is not None or 'model' not in model_env:
- print("Unknown model definition error" if error_msg is None else error_msg, file=sys.stderr)
-
- if msg_type == 'InputTensorsID':
- input_tensors = tc.decode_torch_tensors(msg_data)
- for i, tensor in enumerate(input_tensors):
- input_tensors[i]=torch.unsqueeze(tensor, 0) #add batch dimension
-
- if msg_type == 'OutputTensorsID':
- output_tensors = tc.decode_torch_tensors(msg_data)
- for i, tensor in enumerate(output_tensors):
- output_tensors[i]=torch.unsqueeze(tensor, 0) #add batch dimension
-
- if msg_type == 'SetLabels':
- labels=tc.decode_strings(msg_data)
-
- if msg_type == 'Build': #generate the torchscript, graph, and suggest hyperparameters
- if 'model' in model_env and input_tensors and output_tensors:
- print("Building model...\n", file=sys.stderr)
-
- build_mode=tc.decode_strings(msg_data)[0]
-
- buffer=io.BytesIO()
- torchscript_model=None
- if build_mode=='package': #packaging
- with torch.package.PackageExporter(buffer) as exp:
- intern_list=[]
- for path in os.listdir():
- if path.endswith(".py") and os.path.isfile(path):
- intern_list.append(path[:-3]+".**")
- if os.path.isdir(path):
- intern_list.append(path+".**")
- exp.extern('**',exclude=intern_list)
- exp.intern(intern_list)
- exp.save_source_string(modelmodule.__name__, model_code)
- exp.save_pickle('model', 'model.pkl', modelmodule.model)
- elif build_mode=='script': #scripting
- #monkey patch linecache.getlines so that inspect.getsource called by torch.jit.script can work with a module coming from a string and not a file
- def monkey_patch(filename, module_globals=None):
- if filename == '':
- return model_code.splitlines(keepends=True)
- else:
- return getlines(filename, module_globals)
- getlines = linecache.getlines
- linecache.getlines = monkey_patch
- error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':modelmodule.model}, description='model scripting')
- linecache.getlines = getlines
- else: #tracing
- error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':modelmodule.model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing')
-
- if error_msg:
- print(error_msg, file=sys.stderr)
- else:
- if torchscript_model:
- torch.jit.save(torchscript_model,buffer)
- tc.send_msg(app_socket, 'TorchScriptData', buffer.getvalue())
- else:
- tc.send_msg(app_socket, 'PackageData', buffer.getvalue())
-
- print("Building graph...\n", file=sys.stderr)
-
- level=0
- max_depth=0
-
- while level<=max_depth:
- class LevelTracer(torch.fx.Tracer):
- def is_leaf_module(self, m, qualname):
- depth=re.sub(r'.[0-9]+', '', qualname).count('.')
- if super().is_leaf_module(m, qualname)==False:
- depth=depth+1
- global max_depth
- max_depth=max(max_depth,depth)
- if depth>max_depth-level:
- return True
- else:
- return super().is_leaf_module(m, qualname)
-
- def level_trace(root):
- tracer = LevelTracer()
- graph = tracer.trace(root)
- name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
- return GraphModule(tracer.root, graph, name)
-
- error_msg, gm = safe_exec(level_trace,(model_env['model'],), description='model graph')
- if error_msg or gm is None:
- print("Unknown model graph error" if error_msg is None else error_msg, file=sys.stderr)
- else:
- modules = dict(gm.named_modules())
- ShapeProp(gm).propagate(*input_tensors)
-
- parsed_nodes={}
- for rank, node in enumerate(gm.graph.nodes):
- id=node.name
- name=node.name
- inputs=[str(i) for i in list(node.all_input_nodes)]
- output_dtype=''
- output_shape=''
- if 'tensor_meta' in node.meta:
- if type(node.meta['tensor_meta']) is tuple or type(node.meta['tensor_meta']) is list:
- for tensor_meta in node.meta['tensor_meta']:
- output_dtype+=str(tensor_meta.dtype)+' '
- output_shape+=','.join([str(i) for i in list(tensor_meta.shape)[1:]])+' '
- output_dtype=output_dtype[:-1]
- output_shape=output_shape[:-1]
- else:
- output_dtype=str(node.meta['tensor_meta'].dtype)
- output_shape=','.join([str(i) for i in list(node.meta['tensor_meta'].shape)[1:]])
-
- if node.op == 'placeholder':
- node_type='input'
- op_module=''
- op=''
- params=''
- elif node.op == 'call_module':
- node_type='module'
- name=re.sub('\.([0-9]+)', r'[\1]', node.target)
- op_module=modules[node.target].__module__
- op_module='torch.nn' #prefer this shortcut for all modules
- op=modules[node.target].__class__.__name__
- params=modules[node.target].extra_repr()
- elif node.op == 'call_function':
- node_type='function'
- op_module=node.target.__module__ if node.target.__module__ is not None else "torch"
- op_module='operator' if op_module=='_operator' else op_module
- op=node.target.__name__
- params_list=[str(x) for x in node.args]
- params_list.extend([f'{key}={value}' for key, value in node._kwargs.items()])
- params=', '.join(params_list)
- elif node.op == 'call_method':
- node_type='function'
- op_module="torch"
- op=node.target
- params_list=[str(x) for x in node.args]
- params_list.extend([f'{key}={value}' for key, value in node._kwargs.items()])
- params=', '.join(params_list)
- elif node.op == 'output':
- node_type='output'
- op_module=''
- op=''
- params=''
- for input_node in node.all_input_nodes:
- input_op=""
- if input_node.op == 'call_module':
- input_op=modules[input_node.target].__class__.__name__
- elif input_node.op == 'call_function':
- input_op=input_node.target.__name__
- else:
- node_type='unknown'
- op_module=''
- op=''
- params=''
-
- if node_type=='output' and len(inputs)>1:
- for i, input in enumerate(inputs):
- parsed_nodes[id+"_"+str(i)]={'name':name+"["+str(i)+"]", 'type':node_type, 'op_module':op_module, 'op':op, 'params':params, 'output_dtype':output_dtype, 'output_shape':output_shape, 'inputs':inputs}
- else:
- parsed_nodes[id]={'name':name, 'type':node_type, 'op_module':op_module, 'op':op, 'params':params, 'output_dtype':output_dtype, 'output_shape':output_shape, 'inputs':inputs}
- tc.send_msg(app_socket, 'GraphData', bytes(str(parsed_nodes),'utf-8'))
- level+=1
- tc.send_msg(app_socket, 'GraphDataEnd')
-
- print("Model built ("+format(sum(p.numel() for p in model_env['model'].parameters() if p.requires_grad), ',d')+" parameters)") #from https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=robin_lobel
-
- #suggest loss names
- loss=[]
- for i, tensor in enumerate(output_tensors):
- if len(tensor.shape)==1 and "int" in str(tensor.dtype):
- #multiclass crossentropy classification
- loss.append("CrossEntropy")
- elif len(tensor.shape)==2 and tensor.shape[1]==len(labels):
- #multiclass multilabel classification
- loss.append("BinaryCrossEntropy")
- else:
- #default back to MSE for everything else
- loss.append("MeanSquareError")
-
- #suggest metric names
- metric=[]
- for tensor in output_tensors:
- metric.append("Accuracy")
-
- tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([128,0,100,1,1,1]))
- tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['Adam','Step']))
-
- if msg_type == 'Exit':
- break
-
+#workaround until PyTorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490
+import os
+os.environ['KMP_DUPLICATE_LIB_OK']='True'
+
+import sys
+print("Loading PyTorch...\n", file=sys.stderr)
+
+import torch
+import torch.fx
+from torch.fx.passes.shape_prop import ShapeProp
+from torch.fx.graph_module import GraphModule
+import torchstudio.tcpcodec as tc
+from torchstudio.modules import safe_exec
+import sys
+import os
+import io
+import re
+import graphviz
+import linecache
+import inspect
+
+#monkey patch ssl to fix ssl certificate failures when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+
+original_path=sys.path
+original_dir=os.getcwd()
+
+level=0
+max_depth=0
+
+app_socket = tc.connect()
+print("Build script connected\n", file=sys.stderr)
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+
+ if msg_type == 'SetCurrentDir':
+ new_dir=tc.decode_strings(msg_data)[0]
+        sys.path=list(original_path) #copy, so the appends below don't accumulate in the saved original path
+ os.chdir(original_dir)
+ if new_dir:
+ sys.path.append(new_dir)
+ os.chdir(new_dir)
+
+ if msg_type == 'SetDataDir':
+ data_dir=tc.decode_strings(msg_data)[0]
+ torch.hub.set_dir(data_dir)
+
+ if msg_type == 'SetModelCode':
+ model_code=tc.decode_strings(msg_data)[0]
+
+ #create a module space for the model definition
+ #see https://stackoverflow.com/questions/5122465/can-i-fake-a-package-or-at-least-a-module-in-python-for-testing-purposes/27476659#27476659
+ from types import ModuleType
+ modelmodule = ModuleType("modelmodule")
+ modelmodule.__file__ = modelmodule.__name__ + ".py"
+ sys.modules[modelmodule.__name__] = modelmodule
+
+ error_msg, model_env = safe_exec(model_code, context=vars(modelmodule), output=vars(modelmodule), description='model definition')
+ if error_msg is not None or 'model' not in model_env:
+ print("Unknown model definition error" if error_msg is None else error_msg, file=sys.stderr)
+
+ if msg_type == 'InputTensorsID':
+ input_tensors = tc.decode_torch_tensors(msg_data)
+ for i, tensor in enumerate(input_tensors):
+ input_tensors[i]=torch.unsqueeze(tensor, 0) #add batch dimension
+
+ if msg_type == 'OutputTensorsID':
+ output_tensors = tc.decode_torch_tensors(msg_data)
+ for i, tensor in enumerate(output_tensors):
+ output_tensors[i]=torch.unsqueeze(tensor, 0) #add batch dimension
+
+ if msg_type == 'SetLabels':
+ labels=tc.decode_strings(msg_data)
+
+ if msg_type == 'Build': #generate the torchscript, graph, and suggest hyperparameters
+ if 'model' in model_env and input_tensors and output_tensors:
+ print("Building model...\n", file=sys.stderr)
+
+ build_mode=tc.decode_strings(msg_data)[0]
+
+ buffer=io.BytesIO()
+ torchscript_model=None
+ if build_mode=='package': #packaging
+ with torch.package.PackageExporter(buffer) as exp:
+ intern_list=[]
+ for path in os.listdir():
+ if path.endswith(".py") and os.path.isfile(path):
+ intern_list.append(path[:-3]+".**")
+ if os.path.isdir(path):
+ intern_list.append(path+".**")
+ exp.extern('**',exclude=intern_list)
+ exp.intern(intern_list)
+ exp.save_source_string(modelmodule.__name__, model_code)
+ exp.save_pickle('model', 'model.pkl', modelmodule.model)
+ elif build_mode=='script': #scripting
+ #monkey patch linecache.getlines so that inspect.getsource called by torch.jit.script can work with a module coming from a string and not a file
+ def monkey_patch(filename, module_globals=None):
+ if filename == '':
+ return model_code.splitlines(keepends=True)
+ else:
+ return getlines(filename, module_globals)
+ getlines = linecache.getlines
+ linecache.getlines = monkey_patch
+ error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':modelmodule.model}, description='model scripting')
+ linecache.getlines = getlines
+ else: #tracing
+ error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':modelmodule.model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing')
+
+ if error_msg:
+ print(error_msg, file=sys.stderr)
+ else:
+ if torchscript_model:
+ torch.jit.save(torchscript_model,buffer)
+ tc.send_msg(app_socket, 'TorchScriptData', buffer.getvalue())
+ else:
+ tc.send_msg(app_socket, 'PackageData', buffer.getvalue())
+
+ print("Building graph...\n", file=sys.stderr)
+
+ level=0
+ max_depth=0
+
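+            #trace the model once per hierarchy level: each pass flattens one more level of submodules, sending graphs from coarsest to most detailed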
+ while level<=max_depth:
+ class LevelTracer(torch.fx.Tracer):
+ def is_leaf_module(self, m, qualname):
+ depth=re.sub(r'.[0-9]+', '', qualname).count('.')
+ if super().is_leaf_module(m, qualname)==False:
+ depth=depth+1
+ global max_depth
+ max_depth=max(max_depth,depth)
+ if depth>max_depth-level:
+ return True
+ else:
+ return super().is_leaf_module(m, qualname)
+
+ def level_trace(root):
+ tracer = LevelTracer()
+ graph = tracer.trace(root)
+ name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+ return GraphModule(tracer.root, graph, name)
+
+ error_msg, gm = safe_exec(level_trace,(model_env['model'],), description='model graph')
+ if error_msg or gm is None:
+ print("Unknown model graph error" if error_msg is None else error_msg, file=sys.stderr)
+ else:
+ modules = dict(gm.named_modules())
+ ShapeProp(gm).propagate(*input_tensors)
+
+ parsed_nodes={}
+ for rank, node in enumerate(gm.graph.nodes):
+ id=node.name
+ name=node.name
+ inputs=[str(i) for i in list(node.all_input_nodes)]
+ output_dtype=''
+ output_shape=''
+ if 'tensor_meta' in node.meta:
+ if type(node.meta['tensor_meta']) is tuple or type(node.meta['tensor_meta']) is list:
+ for tensor_meta in node.meta['tensor_meta']:
+ output_dtype+=str(tensor_meta.dtype)+' '
+ output_shape+=','.join([str(i) for i in list(tensor_meta.shape)[1:]])+' '
+ output_dtype=output_dtype[:-1]
+ output_shape=output_shape[:-1]
+ else:
+ output_dtype=str(node.meta['tensor_meta'].dtype)
+ output_shape=','.join([str(i) for i in list(node.meta['tensor_meta'].shape)[1:]])
+
+ if node.op == 'placeholder':
+ node_type='input'
+ op_module=''
+ op=''
+ params=''
+ elif node.op == 'call_module':
+ node_type='module'
+                            name=re.sub(r'\.([0-9]+)', r'[\1]', node.target)
+ op_module=modules[node.target].__module__
+ op_module='torch.nn' #prefer this shortcut for all modules
+ op=modules[node.target].__class__.__name__
+ params=modules[node.target].extra_repr()
+ elif node.op == 'call_function':
+ node_type='function'
+ op_module=node.target.__module__ if node.target.__module__ is not None else "torch"
+ op_module='operator' if op_module=='_operator' else op_module
+ op=node.target.__name__
+ params_list=[str(x) for x in node.args]
+ params_list.extend([f'{key}={value}' for key, value in node._kwargs.items()])
+ params=', '.join(params_list)
+ elif node.op == 'call_method':
+ node_type='function'
+ op_module="torch"
+ op=node.target
+ params_list=[str(x) for x in node.args]
+ params_list.extend([f'{key}={value}' for key, value in node._kwargs.items()])
+ params=', '.join(params_list)
+ elif node.op == 'output':
+ node_type='output'
+ op_module=''
+ op=''
+ params=''
+ for input_node in node.all_input_nodes:
+ input_op=""
+ if input_node.op == 'call_module':
+ input_op=modules[input_node.target].__class__.__name__
+ elif input_node.op == 'call_function':
+ input_op=input_node.target.__name__
+ else:
+ node_type='unknown'
+ op_module=''
+ op=''
+ params=''
+
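+                        #an output node with several inputs is split into one graph node per returned tensor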
+ if node_type=='output' and len(inputs)>1:
+ for i, input in enumerate(inputs):
+ parsed_nodes[id+"_"+str(i)]={'name':name+"["+str(i)+"]", 'type':node_type, 'op_module':op_module, 'op':op, 'params':params, 'output_dtype':output_dtype, 'output_shape':output_shape, 'inputs':inputs}
+ else:
+ parsed_nodes[id]={'name':name, 'type':node_type, 'op_module':op_module, 'op':op, 'params':params, 'output_dtype':output_dtype, 'output_shape':output_shape, 'inputs':inputs}
+ tc.send_msg(app_socket, 'GraphData', bytes(str(parsed_nodes),'utf-8'))
+ level+=1
+ tc.send_msg(app_socket, 'GraphDataEnd')
+
+ print("Model built ("+format(sum(p.numel() for p in model_env['model'].parameters() if p.requires_grad), ',d')+" parameters)") #from https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=robin_lobel
+
+ #suggest loss names
+ loss=[]
+ for i, tensor in enumerate(output_tensors):
+ if len(tensor.shape)==1 and "int" in str(tensor.dtype):
+ #multiclass crossentropy classification
+ loss.append("CrossEntropy")
+ elif len(tensor.shape)==2 and tensor.shape[1]==len(labels):
+ #multiclass multilabel classification
+ loss.append("BinaryCrossEntropy")
+ else:
+ #default back to MSE for everything else
+ loss.append("MeanSquareError")
+
+ #suggest metric names
+ metric=[]
+ for tensor in output_tensors:
+ metric.append("Accuracy")
+
+ tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([128,0,100,1,1]))
+ tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['Adam','Step']))
+
+ if msg_type == 'Exit':
+ break
+
diff --git a/torchstudio/models/unet1d.py b/torchstudio/models/unet1d.py
index ba778c0..78a0baa 100644
--- a/torchstudio/models/unet1d.py
+++ b/torchstudio/models/unet1d.py
@@ -1,143 +1,143 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py
-def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False):
- sequence = []
- for i in range(conv_per_block):
- sequence.append(nn.Conv1d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2))
- sequence.append(nn.ReLU(inplace=True))
- if batch_norm:
- #BatchNorm best after ReLU:
- #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/
- #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999
- #https://github.com/cvjena/cnn-models/issues/3
- sequence.append(nn.BatchNorm1d(out_channels))
- return nn.Sequential(*sequence)
-
-class DownConv(nn.Module):
- def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True):
- super().__init__()
-
- self.pooling = pooling
-
- self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm)
-
- if self.pooling:
- if not conv_downscaling:
- self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
- else:
- self.pool = nn.Conv1d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2)
-
- def forward(self, x):
- x = self.block(x)
- before_pool = x
- if self.pooling:
- x = self.pool(x)
- return x, before_pool
-
-
-class UpConv(nn.Module):
- def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
- add_merging, conv_upscaling):
- super().__init__()
-
- self.add_merging = add_merging
-
- if not conv_upscaling:
- self.upconv = nn.ConvTranspose1d(in_channels,out_channels,kernel_size=2,stride=2)
- else:
- self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2),
- nn.Conv1d(in_channels, out_channels,kernel_size=1,groups=1,stride=1))
-
-
- self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm)
-
- def forward(self, from_down, from_up):
- from_up = self.upconv(from_up)
- if not self.add_merging:
- x = torch.cat((from_up, from_down), 1)
- else:
- x = from_up + from_down
- x = self.block(x)
- return x
-
-
-class UNet1D(nn.Module):
- """ `UNet` class is based on https://arxiv.org/abs/1505.04597
- UNet is a convolutional encoder-decoder neural network.
-
- This 1D variant is inspired by 1D Unet are inspired by the
- Wave UNet ( https://arxiv.org/pdf/1806.03185.pdf )
- Default parameters correspond to the Wave UNet.
- Convolutions use padding to preserve the original size.
-
- Args:
- in_channels: number of channels in the input tensor.
- out_channels: number of channels in the output tensor.
- feature_channels: number of channels in the first and last hidden feature layer.
- depth: number of levels
- conv_per_block: number of convolutions per level block
- kernel_size: kernel size for all block convolutions
- batch_norm: add a batch norm after ReLU
- conv_upscaling: use a nearest upsize+conv instead of transposed convolution
- conv_downscaling: use a strided convolution instead of maxpooling
- add_merging: merge layers from different levels using a add instead of a concat
- """
-
- def __init__(self, in_channels=1, out_channels=1, feature_channels=24,
- depth=12, conv_per_block=1, kernel_size=5, batch_norm=False,
- conv_upscaling=False, conv_downscaling=False, add_merging=False):
- super().__init__()
-
- self.out_channels = out_channels
- self.in_channels = in_channels
- self.feature_channels = feature_channels
- self.depth = depth
-
- self.down_convs = []
- self.up_convs = []
-
- # create the encoder pathway and add to a list
- for i in range(depth):
- ins = self.in_channels if i == 0 else outs
- outs = self.feature_channels*(i+1)
- pooling = True if i < depth-1 else False
-
- down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
- conv_downscaling, pooling=pooling)
- self.down_convs.append(down_conv)
-
- # create the decoder pathway and add to a list
- # - careful! decoding only requires depth-1 blocks
- for i in range(depth-1):
- ins = outs
- outs = ins - self.feature_channels
- up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
- conv_upscaling=conv_upscaling, add_merging=add_merging)
- self.up_convs.append(up_conv)
-
- self.conv_final = nn.Conv1d(outs, self.out_channels,kernel_size=1,groups=1,stride=1)
-
- # add the list of modules to current module
- self.down_convs = nn.ModuleList(self.down_convs)
- self.up_convs = nn.ModuleList(self.up_convs)
-
- def forward(self, x):
- encoder_outs = []
-
- # encoder pathway, save outputs for merging
- for i, module in enumerate(self.down_convs):
- x, before_pool = module(x)
- encoder_outs.append(before_pool)
-
- for i, module in enumerate(self.up_convs):
- before_pool = encoder_outs[-(i+2)]
- x = module(before_pool, x)
-
- # No softmax is used. This means you need to use
- # nn.CrossEntropyLoss is your training script,
- # as this module includes a softmax already.
- x = self.conv_final(x)
- return x
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py
+def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False):
+ sequence = []
+ for i in range(conv_per_block):
+ sequence.append(nn.Conv1d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2))
+ sequence.append(nn.ReLU(inplace=True))
+ if batch_norm:
+ #BatchNorm best after ReLU:
+ #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/
+ #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999
+ #https://github.com/cvjena/cnn-models/issues/3
+ sequence.append(nn.BatchNorm1d(out_channels))
+ return nn.Sequential(*sequence)
+
+class DownConv(nn.Module):
+ def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True):
+ super().__init__()
+
+ self.pooling = pooling
+
+ self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm)
+
+ if self.pooling:
+ if not conv_downscaling:
+ self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
+ else:
+ self.pool = nn.Conv1d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2)
+
+ def forward(self, x):
+ x = self.block(x)
+ before_pool = x
+ if self.pooling:
+ x = self.pool(x)
+ return x, before_pool
+
+
+class UpConv(nn.Module):
+ def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
+ add_merging, conv_upscaling):
+ super().__init__()
+
+ self.add_merging = add_merging
+
+ if not conv_upscaling:
+ self.upconv = nn.ConvTranspose1d(in_channels,out_channels,kernel_size=2,stride=2)
+ else:
+ self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2),
+ nn.Conv1d(in_channels, out_channels,kernel_size=1,groups=1,stride=1))
+
+
+ self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm)
+
+ def forward(self, from_down, from_up):
+ from_up = self.upconv(from_up)
+ if not self.add_merging:
+ x = torch.cat((from_up, from_down), 1)
+ else:
+ x = from_up + from_down
+ x = self.block(x)
+ return x
+
+
+class UNet1D(nn.Module):
+ """ `UNet` class is based on https://arxiv.org/abs/1505.04597
+ UNet is a convolutional encoder-decoder neural network.
+
+    This 1D variant is inspired by the
+ Wave UNet ( https://arxiv.org/pdf/1806.03185.pdf )
+ Default parameters correspond to the Wave UNet.
+ Convolutions use padding to preserve the original size.
+
+ Args:
+ in_channels: number of channels in the input tensor.
+ out_channels: number of channels in the output tensor.
+ feature_channels: number of channels in the first and last hidden feature layer.
+ depth: number of levels
+ conv_per_block: number of convolutions per level block
+ kernel_size: kernel size for all block convolutions
+ batch_norm: add a batch norm after ReLU
+ conv_upscaling: use a nearest upsize+conv instead of transposed convolution
+ conv_downscaling: use a strided convolution instead of maxpooling
+        add_merging: merge layers from different levels using an add instead of a concat
+ """
+
+ def __init__(self, in_channels=1, out_channels=1, feature_channels=24,
+ depth=12, conv_per_block=1, kernel_size=5, batch_norm=False,
+ conv_upscaling=False, conv_downscaling=False, add_merging=False):
+ super().__init__()
+
+ self.out_channels = out_channels
+ self.in_channels = in_channels
+ self.feature_channels = feature_channels
+ self.depth = depth
+
+ self.down_convs = []
+ self.up_convs = []
+
+ # create the encoder pathway and add to a list
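+        # (feature channels grow linearly with depth, as in the Wave UNet)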
+ for i in range(depth):
+ ins = self.in_channels if i == 0 else outs
+ outs = self.feature_channels*(i+1)
+ pooling = True if i < depth-1 else False
+
+ down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
+ conv_downscaling, pooling=pooling)
+ self.down_convs.append(down_conv)
+
+ # create the decoder pathway and add to a list
+ # - careful! decoding only requires depth-1 blocks
+ for i in range(depth-1):
+ ins = outs
+ outs = ins - self.feature_channels
+ up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
+ conv_upscaling=conv_upscaling, add_merging=add_merging)
+ self.up_convs.append(up_conv)
+
+ self.conv_final = nn.Conv1d(outs, self.out_channels,kernel_size=1,groups=1,stride=1)
+
+ # add the list of modules to current module
+ self.down_convs = nn.ModuleList(self.down_convs)
+ self.up_convs = nn.ModuleList(self.up_convs)
+
+ def forward(self, x):
+ encoder_outs = []
+
+ # encoder pathway, save outputs for merging
+ for i, module in enumerate(self.down_convs):
+ x, before_pool = module(x)
+ encoder_outs.append(before_pool)
+
+ for i, module in enumerate(self.up_convs):
+ before_pool = encoder_outs[-(i+2)]
+ x = module(before_pool, x)
+
+        # No softmax is applied here, so use
+        # nn.CrossEntropyLoss in your training script,
+        # as that loss already includes a softmax.
+ x = self.conv_final(x)
+ return x
diff --git a/torchstudio/models/unet2d.py b/torchstudio/models/unet2d.py
index ca1c91e..1b58ae0 100644
--- a/torchstudio/models/unet2d.py
+++ b/torchstudio/models/unet2d.py
@@ -1,166 +1,166 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py
-def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False):
- sequence = []
- for i in range(conv_per_block):
- sequence.append(nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2))
- sequence.append(nn.ReLU(inplace=True))
- if batch_norm:
- #BatchNorm best after ReLU:
- #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/
- #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999
- #https://github.com/cvjena/cnn-models/issues/3
- sequence.append(nn.BatchNorm2d(out_channels))
- return nn.Sequential(*sequence)
-
-class DownConv(nn.Module):
- def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True):
- super().__init__()
-
- self.in_channels=in_channels
- self.out_channels=out_channels
- self.conv_per_block=conv_per_block
- self.kernel_size=kernel_size
- self.batch_norm=batch_norm
- self.conv_downscaling=conv_downscaling
- self.pooling = pooling
-
- self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm)
-
- if self.pooling:
- if not conv_downscaling:
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
- else:
- self.pool = nn.Conv2d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2)
-
- def forward(self, x):
- x = self.block(x)
- before_pool = x
- if self.pooling:
- x = self.pool(x)
- return x, before_pool
-
- def extra_repr(self):
- # (Optional)Set the extra information about this module. You can test
- # it by printing an object of this class.
- return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, conv_downscaling={}, pooling={}'.format(
- self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.conv_downscaling, self.pooling
- )
-
-
-class UpConv(nn.Module):
- def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
- add_merging, conv_upscaling):
- super().__init__()
-
- self.in_channels=in_channels
- self.out_channels=out_channels
- self.conv_per_block=conv_per_block
- self.kernel_size=kernel_size
- self.batch_norm=batch_norm
- self.add_merging = add_merging
- self.conv_upscaling = conv_upscaling
-
- if not conv_upscaling:
- self.upconv = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2)
- else:
- self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2),
- nn.Conv2d(in_channels, out_channels,kernel_size=1,groups=1,stride=1))
-
-
- self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm)
-
- def forward(self, from_down, from_up):
- from_up = self.upconv(from_up)
- if not self.add_merging:
- x = torch.cat((from_up, from_down), 1)
- else:
- x = from_up + from_down
- x = self.block(x)
- return x
-
- def extra_repr(self):
- # (Optional)Set the extra information about this module. You can test
- # it by printing an object of this class.
- return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, add_merging={}, conv_upscaling={}'.format(
- self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.add_merging, self.conv_upscaling
- )
-
-class UNet2D(nn.Module):
- """ `UNet` class is based on https://arxiv.org/abs/1505.04597
- UNet is a convolutional encoder-decoder neural network.
-
- Default parameters correspond to the original UNet, except
- convolutions use padding to preserve the original size.
-
- Args:
- in_channels: number of channels in the input tensor.
- out_channels: number of channels in the output tensor.
- feature_channels: number of channels in the first and last hidden feature layer.
- depth: number of levels
- conv_per_block: number of convolutions per level block
- kernel_size: kernel size for all block convolutions
- batch_norm: add a batch norm after ReLU
- conv_upscaling: use a nearest upscale+conv instead of transposed convolution
- conv_downscaling: use a strided convolution instead of maxpooling
- add_merging: merge layers from different levels using a add instead of a concat
- """
-
- def __init__(self, in_channels=1, out_channels=2, feature_channels=64,
- depth=5, conv_per_block=2, kernel_size=3, batch_norm=False,
- conv_upscaling=False, conv_downscaling=False, add_merging=False):
- super().__init__()
-
- self.out_channels = out_channels
- self.in_channels = in_channels
- self.feature_channels = feature_channels
- self.depth = depth
-
- self.down_convs = []
- self.up_convs = []
-
- # create the encoder pathway and add to a list
- for i in range(depth):
- ins = self.in_channels if i == 0 else outs
- outs = self.feature_channels*(2**i)
- pooling = True if i < depth-1 else False
-
- down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
- conv_downscaling, pooling=pooling)
- self.down_convs.append(down_conv)
-
- # create the decoder pathway and add to a list
- # - careful! decoding only requires depth-1 blocks
- for i in range(depth-1):
- ins = outs
- outs = ins // 2
- up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
- conv_upscaling=conv_upscaling, add_merging=add_merging)
- self.up_convs.append(up_conv)
-
- self.conv_final = nn.Conv2d(outs, self.out_channels,kernel_size=1,groups=1,stride=1)
-
- # add the list of modules to current module
- self.down_convs = nn.ModuleList(self.down_convs)
- self.up_convs = nn.ModuleList(self.up_convs)
-
- def forward(self, x):
- encoder_outs = []
-
- # encoder pathway, save outputs for merging
- for i, module in enumerate(self.down_convs):
- x, before_pool = module(x)
- encoder_outs.append(before_pool)
-
- for i, module in enumerate(self.up_convs):
- before_pool = encoder_outs[-(i+2)]
- x = module(before_pool, x)
-
- # No softmax is used. This means you need to use
- # nn.CrossEntropyLoss is your training script,
- # as this module includes a softmax already.
- x = self.conv_final(x)
- return x
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py
+def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False):
+ sequence = []
+ for i in range(conv_per_block):
+ sequence.append(nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2))
+ sequence.append(nn.ReLU(inplace=True))
+ if batch_norm:
+ #BatchNorm best after ReLU:
+ #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/
+ #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999
+ #https://github.com/cvjena/cnn-models/issues/3
+ sequence.append(nn.BatchNorm2d(out_channels))
+ return nn.Sequential(*sequence)
+
+class DownConv(nn.Module):
+ def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True):
+ super().__init__()
+
+ self.in_channels=in_channels
+ self.out_channels=out_channels
+ self.conv_per_block=conv_per_block
+ self.kernel_size=kernel_size
+ self.batch_norm=batch_norm
+ self.conv_downscaling=conv_downscaling
+ self.pooling = pooling
+
+ self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm)
+
+ if self.pooling:
+ if not conv_downscaling:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ else:
+ self.pool = nn.Conv2d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2)
+
+ def forward(self, x):
+ x = self.block(x)
+ before_pool = x
+ if self.pooling:
+ x = self.pool(x)
+ return x, before_pool
+
+ def extra_repr(self):
+        # (Optional) Set the extra information about this module. You can test
+ # it by printing an object of this class.
+ return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, conv_downscaling={}, pooling={}'.format(
+ self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.conv_downscaling, self.pooling
+ )
+
+
+class UpConv(nn.Module):
+ def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
+ add_merging, conv_upscaling):
+ super().__init__()
+
+ self.in_channels=in_channels
+ self.out_channels=out_channels
+ self.conv_per_block=conv_per_block
+ self.kernel_size=kernel_size
+ self.batch_norm=batch_norm
+ self.add_merging = add_merging
+ self.conv_upscaling = conv_upscaling
+
+ if not conv_upscaling:
+ self.upconv = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2)
+ else:
+ self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2),
+ nn.Conv2d(in_channels, out_channels,kernel_size=1,groups=1,stride=1))
+
+
+ self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm)
+
+ def forward(self, from_down, from_up):
+ from_up = self.upconv(from_up)
+ if not self.add_merging:
+ x = torch.cat((from_up, from_down), 1)
+ else:
+ x = from_up + from_down
+ x = self.block(x)
+ return x
+
+ def extra_repr(self):
+        # (Optional) Set the extra information about this module. You can test
+ # it by printing an object of this class.
+ return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, add_merging={}, conv_upscaling={}'.format(
+ self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.add_merging, self.conv_upscaling
+ )
+
+class UNet2D(nn.Module):
+ """ `UNet` class is based on https://arxiv.org/abs/1505.04597
+ UNet is a convolutional encoder-decoder neural network.
+
+ Default parameters correspond to the original UNet, except
+ convolutions use padding to preserve the original size.
+
+ Args:
+ in_channels: number of channels in the input tensor.
+ out_channels: number of channels in the output tensor.
+ feature_channels: number of channels in the first and last hidden feature layer.
+ depth: number of levels
+ conv_per_block: number of convolutions per level block
+ kernel_size: kernel size for all block convolutions
+ batch_norm: add a batch norm after ReLU
+ conv_upscaling: use a nearest upscale+conv instead of transposed convolution
+ conv_downscaling: use a strided convolution instead of maxpooling
+        add_merging: merge layers from different levels using an add instead of a concat
+ """
+
+ def __init__(self, in_channels=1, out_channels=2, feature_channels=64,
+ depth=5, conv_per_block=2, kernel_size=3, batch_norm=False,
+ conv_upscaling=False, conv_downscaling=False, add_merging=False):
+ super().__init__()
+
+ self.out_channels = out_channels
+ self.in_channels = in_channels
+ self.feature_channels = feature_channels
+ self.depth = depth
+
+ self.down_convs = []
+ self.up_convs = []
+
+ # create the encoder pathway and add to a list
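+        # (feature channels double at each level, as in the original UNet)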
+ for i in range(depth):
+ ins = self.in_channels if i == 0 else outs
+ outs = self.feature_channels*(2**i)
+ pooling = True if i < depth-1 else False
+
+ down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
+ conv_downscaling, pooling=pooling)
+ self.down_convs.append(down_conv)
+
+ # create the decoder pathway and add to a list
+ # - careful! decoding only requires depth-1 blocks
+ for i in range(depth-1):
+ ins = outs
+ outs = ins // 2
+ up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
+ conv_upscaling=conv_upscaling, add_merging=add_merging)
+ self.up_convs.append(up_conv)
+
+ self.conv_final = nn.Conv2d(outs, self.out_channels,kernel_size=1,groups=1,stride=1)
+
+ # add the list of modules to current module
+ self.down_convs = nn.ModuleList(self.down_convs)
+ self.up_convs = nn.ModuleList(self.up_convs)
+
+ def forward(self, x):
+ encoder_outs = []
+
+ # encoder pathway, save outputs for merging
+ for i, module in enumerate(self.down_convs):
+ x, before_pool = module(x)
+ encoder_outs.append(before_pool)
+
+ for i, module in enumerate(self.up_convs):
+ before_pool = encoder_outs[-(i+2)]
+ x = module(before_pool, x)
+
+        # No softmax is applied here, so use
+        # nn.CrossEntropyLoss in your training script,
+        # as that loss already includes a softmax.
+ x = self.conv_final(x)
+ return x
diff --git a/torchstudio/modeltrain.py b/torchstudio/modeltrain.py
index 3b720b8..2bf6558 100644
--- a/torchstudio/modeltrain.py
+++ b/torchstudio/modeltrain.py
@@ -1,360 +1,364 @@
-#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490
-import os
-os.environ['KMP_DUPLICATE_LIB_OK']='True'
-
-import sys
-
-print("Loading PyTorch...\n", file=sys.stderr)
-
-import torch
-from torch.utils.data import Dataset
-import torchstudio.tcpcodec as tc
-from torchstudio.modules import safe_exec
-import os
-import sys
-import io
-import tempfile
-from tqdm.auto import tqdm
-from collections.abc import Iterable
-
-
-class CachedDataset(Dataset):
- def __init__(self, disk_cache=False):
- self.reset(disk_cache)
-
- def add_sample(self, data):
- if self.disk_cache:
- file=tempfile.TemporaryFile(prefix='torchstudio.'+str(len(self.index))+'.') #guaranteed to be deleted on win/mac/linux: https://bugs.python.org/issue4928
- file.write(data)
- self.index.append(file)
- else:
- self.index.append(tc.decode_torch_tensors(data))
-
- def reset(self, disk_cache=False):
- self.index = []
- self.disk_cache=disk_cache
-
- def __len__(self):
- return len(self.index)
-
- def __getitem__(self, id):
- if id<0 or id>=len(self):
- raise IndexError
-
- if self.disk_cache:
- file=self.index[id]
- file.seek(0)
- sample=tc.decode_torch_tensors(file.read())
- else:
- sample=self.index[id]
- return sample
-
-def deepcopy_cpu(value):
- if isinstance(value, torch.Tensor):
- value = value.to("cpu")
- return value
- elif isinstance(value, dict):
- return {k: deepcopy_cpu(v) for k, v in value.items()}
- elif isinstance(value, Iterable):
- return type(value)(deepcopy_cpu(v) for v in value)
- else:
- return value
-
-modules_valid=True
-
-train_dataset = CachedDataset()
-valid_dataset = CachedDataset()
-train_bar = None
-
-model = None
-
-app_socket = tc.connect()
-print("Training script connected\n", file=sys.stderr)
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
-
- if msg_type == 'SetDevice':
- print("Setting device...\n", file=sys.stderr)
- device_id=tc.decode_strings(msg_data)[0]
- device = torch.device(device_id)
- pin_memory = True if 'cuda' in device_id else False
-
- if msg_type == 'SetTorchScriptModel' and modules_valid:
- if msg_data:
- print("Setting torchscript model...\n", file=sys.stderr)
- buffer=io.BytesIO(msg_data)
- model = torch.jit.load(buffer)
-
- if msg_type == 'SetPackageModel' and modules_valid:
- if msg_data:
- print("Setting package model...\n", file=sys.stderr)
- buffer=io.BytesIO(msg_data)
- model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl')
-
- if msg_type == 'SetModelState' and modules_valid:
- if model is not None:
- if msg_data:
- buffer=io.BytesIO(msg_data)
- model.load_state_dict(torch.load(buffer))
- model.to(device)
-
- if msg_type == 'SetLossCodes' and modules_valid:
- print("Setting loss code...\n", file=sys.stderr)
- loss_definitions=tc.decode_strings(msg_data)
- criteria = []
- for definition in loss_definitions:
- error_msg, loss_env = safe_exec(definition, description='loss definition')
- if error_msg is not None or 'loss' not in loss_env:
- print("Unknown loss definition error" if error_msg is None else error_msg, file=sys.stderr)
- modules_valid=False
- tc.send_msg(app_socket, 'TrainingError')
- break
- else:
- criteria.append(loss_env['loss'])
-
- if msg_type == 'SetMetricCodes' and modules_valid:
- print("Setting metrics code...\n", file=sys.stderr)
- metric_definitions=tc.decode_strings(msg_data)
- metrics = []
- for definition in metric_definitions:
- error_msg, metric_env = safe_exec(definition, description='metric definition')
- if error_msg is not None or 'metric' not in metric_env:
- print("Unknown metric definition error" if error_msg is None else error_msg, file=sys.stderr)
- modules_valid=False
- tc.send_msg(app_socket, 'TrainingError')
- break
- else:
- metrics.append(metric_env['metric'])
-
- if msg_type == 'SetOptimizerCode' and modules_valid:
- print("Setting optimizer code...\n", file=sys.stderr)
- error_msg, optimizer_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='optimizer definition')
- if error_msg is not None or 'optimizer' not in optimizer_env:
- print("Unknown optimizer definition error" if error_msg is None else error_msg, file=sys.stderr)
- modules_valid=False
- tc.send_msg(app_socket, 'TrainingError')
- else:
- optimizer = optimizer_env['optimizer']
-
- if msg_type == 'SetOptimizerState' and modules_valid:
- if msg_data:
- buffer=io.BytesIO(msg_data)
- optimizer.load_state_dict(torch.load(buffer))
-
- if msg_type == 'SetSchedulerCode' and modules_valid:
- print("Setting scheduler code...\n", file=sys.stderr)
- error_msg, scheduler_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='scheduler definition')
- if error_msg is not None or 'scheduler' not in scheduler_env:
- print("Unknown scheduler definition error" if error_msg is None else error_msg, file=sys.stderr)
- modules_valid=False
- tc.send_msg(app_socket, 'TrainingError')
- else:
- scheduler = scheduler_env['scheduler']
-
- if msg_type == 'SetHyperParametersValues' and modules_valid: #set other hyperparameters values
- batch_size, shuffle, epochs, early_stop, monitor_metric, restore_best = tc.decode_ints(msg_data)
- shuffle=True if shuffle==1 else False
- early_stop=True if early_stop==1 else False
- monitor_metric=True if monitor_metric==1 else False
- restore_best=True if restore_best==1 else False
-
- if msg_type == 'StartTrainingServer' and modules_valid:
- print("Caching...\n", file=sys.stderr)
-
- sshaddress, sshport, username, password, keydata = tc.decode_strings(msg_data)
-
- training_server, address = tc.generate_server()
-
- if sshaddress and sshport and username:
- import socket
- import paramiko
- import torchstudio.sshtunnel as sshtunnel
-
- if not password:
- password=None
- if not keydata:
- pkey=None
- else:
- import io
- keybuffer=io.StringIO(keydata)
- pkey=paramiko.RSAKey.from_private_key(keybuffer)
-
- sshclient = paramiko.SSHClient()
- sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5)
-
- reverse_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ReverseTunnel, 'localhost', 0, 'localhost', int(address[1]))
- address[1]=str(reverse_tunnel.lport)
-
- tc.send_msg(app_socket, 'TrainingServerRequestingAllSamples', tc.encode_strings(address))
- dataset_socket=tc.start_server(training_server)
- train_dataset.reset()
- valid_dataset.reset()
-
- while True:
- dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket)
-
- if dataset_msg_type == 'NumSamples':
- num_samples=tc.decode_ints(dataset_msg_data)[0]
- pbar=tqdm(total=num_samples, desc='Caching...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters
-
- if dataset_msg_type == 'InputTensorsID' and modules_valid:
- input_tensors_id = tc.decode_ints(dataset_msg_data)
-
- if dataset_msg_type == 'OutputTensorsID' and modules_valid:
- output_tensors_id = tc.decode_ints(dataset_msg_data)
-
- if dataset_msg_type == 'TrainingSample':
- train_dataset.add_sample(dataset_msg_data)
- pbar.update(1)
-
- if dataset_msg_type == 'ValidationSample':
- valid_dataset.add_sample(dataset_msg_data)
- pbar.update(1)
-
- if dataset_msg_type == 'DoneSending':
- pbar.close()
- tc.send_msg(dataset_socket, 'DoneReceiving')
- dataset_socket.close()
- training_server.close()
- if sshaddress and sshport and username:
- sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong
- break
-
- train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)
- valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size, shuffle=False, pin_memory=pin_memory)
-
- if msg_type == 'StartTraining' and modules_valid:
- print("Training... epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr)
-
- if msg_type == 'TrainOneEpoch' and modules_valid:
- #training
- model.train()
- train_loss = 0
- train_metrics = []
- for metric in metrics:
- metric.reset()
- for batch_id, tensors in enumerate(train_loader):
- inputs = [tensors[i].to(device) for i in input_tensors_id]
- targets = [tensors[i].to(device) for i in output_tensors_id]
- optimizer.zero_grad()
- outputs = model(*inputs)
- outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
- loss = 0
- for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
- loss = loss + criterion(output, target)
- loss.backward()
- optimizer.step()
- train_loss += loss.item() * inputs[0].size(0)
-
- with torch.set_grad_enabled(False):
- for output, target, metric in zip(outputs, targets, metrics):
- metric.update(output, target)
-
- train_loss = train_loss/len(train_dataset)
- train_metrics = 0
- for metric in metrics:
- train_metrics = train_metrics+metric.compute().item()
- train_metrics/=len(metrics)
- scheduler.step()
-
- #validation
- model.eval()
- valid_loss = 0
- valid_metrics = []
- for metric in metrics:
- metric.reset()
- with torch.set_grad_enabled(False):
- for batch_id, tensors in enumerate(valid_loader):
- inputs = [tensors[i].to(device) for i in input_tensors_id]
- targets = [tensors[i].to(device) for i in output_tensors_id]
- outputs = model(*inputs)
- outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
- loss = 0
- for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
- loss = loss + criterion(output, target)
- valid_loss += loss.item() * inputs[0].size(0)
-
- for output, target, metric in zip(outputs, targets, metrics):
- metric.update(output, target)
-
- valid_loss = valid_loss/len(valid_dataset)
- valid_metrics = 0
- for metric in metrics:
- valid_metrics = valid_metrics+metric.compute().item()
- valid_metrics/=len(metrics)
-
- tc.send_msg(app_socket, 'TrainingLoss', tc.encode_floats(train_loss))
- tc.send_msg(app_socket, 'ValidationLoss', tc.encode_floats(valid_loss))
- tc.send_msg(app_socket, 'TrainingMetric', tc.encode_floats(train_metrics))
- tc.send_msg(app_socket, 'ValidationMetric', tc.encode_floats(valid_metrics))
-
- buffer=io.BytesIO()
- torch.save(deepcopy_cpu(model.state_dict()), buffer)
- tc.send_msg(app_socket, 'ModelState', buffer.getvalue())
-
- buffer=io.BytesIO()
- torch.save(deepcopy_cpu(optimizer.state_dict()), buffer)
- tc.send_msg(app_socket, 'OptimizerState', buffer.getvalue())
-
- tc.send_msg(app_socket, 'Trained')
-
- #create train_bar only after first successful training to avoid ghost progress message after an error
- if train_bar is not None:
- train_bar.bar_format='{desc} epoch {n_fmt} | {remaining} left |{rate_fmt}\n\n'
- else:
- train_bar = tqdm(total=epochs, desc='Training...', bar_format='{desc} epoch '+str(scheduler.last_epoch)+'\n\n', initial=scheduler.last_epoch-1)
- train_bar.update(1)
-
- if msg_type == 'StopTraining' and modules_valid:
- if train_bar is not None:
- train_bar.close()
- train_bar=None
- print("Training stopped at epoch "+str(scheduler.last_epoch-1), file=sys.stderr)
-
- if msg_type == 'SetInputTensors' or msg_type == 'InferTensors':
- input_tensors = tc.decode_torch_tensors(msg_data)
- for i, tensor in enumerate(input_tensors):
- input_tensors[i]=torch.unsqueeze(tensor, 0).to(device) #add batch dimension
-
- if msg_type == 'InferTensors':
- if model is not None:
- with torch.set_grad_enabled(False):
- model.eval()
- output_tensors=model(*input_tensors)
- output_tensors=[output.cpu() for output in output_tensors]
- tc.send_msg(app_socket, 'InferedTensors', tc.encode_torch_tensors(output_tensors))
-
- if msg_type == 'SaveTorchScript':
- path, mode = tc.decode_strings(msg_data)
- if "torch.jit" in str(type(model)):
- torch.jit.save(model, path) #already a torchscript, save as is
- print("Export complete")
- else:
- if mode=="trace":
- error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing')
- else:
- error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':model}, description='model scripting')
- if error_msg:
- print("Error exporting:", error_msg, file=sys.stderr)
- else:
- torch.jit.save(torchscript_model, path)
- print("Export complete")
-
- if msg_type == 'SaveONNX':
- error_msg=None
- torchscript_model=model
- if not "torch.jit" in str(type(model)):
- error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing')
- if error_msg:
- print("Error exporting:", error_msg, file=sys.stderr)
- else:
- error_msg, torchscript_model = safe_exec(torch.onnx.export,{'model':torchscript_model, 'args':input_tensors, 'f':tc.decode_strings(msg_data)[0], 'opset_version':12})
- if error_msg:
- print("Error exporting:", error_msg, file=sys.stderr)
- else:
- print("Export complete")
-
- if msg_type == 'Exit':
- break
-
+#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490
+import os
+os.environ['KMP_DUPLICATE_LIB_OK']='True'
+
+import sys
+
+print("Loading PyTorch...\n", file=sys.stderr)
+
+import torch
+from torch.utils.data import Dataset
+import torchstudio.tcpcodec as tc
+from torchstudio.modules import safe_exec
+import os
+import sys
+import io
+import tempfile
+from tqdm.auto import tqdm
+from collections.abc import Iterable
+
+
+class CachedDataset(Dataset):
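+ #caches samples received from the app, either decoded in memory or as raw encoded bytes in per-sample temporary files when disk_cache is enabled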
+ def __init__(self, disk_cache=False):
+ self.reset(disk_cache)
+
+ def add_sample(self, data):
+ if self.disk_cache:
+ file=tempfile.TemporaryFile(prefix='torchstudio.'+str(len(self.index))+'.') #guaranteed to be deleted on win/mac/linux: https://bugs.python.org/issue4928
+ file.write(data)
+ self.index.append(file)
+ else:
+ self.index.append(tc.decode_torch_tensors(data))
+
+ def reset(self, disk_cache=False):
+ self.index = []
+ self.disk_cache=disk_cache
+
+ def __len__(self):
+ return len(self.index)
+
+ def __getitem__(self, id):
+ if id<0 or id>=len(self):
+ raise IndexError
+
+ if self.disk_cache:
+ file=self.index[id]
+ file.seek(0)
+ sample=tc.decode_torch_tensors(file.read())
+ else:
+ sample=self.index[id]
+ return sample
+
+def deepcopy_cpu(value):
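+ #recursively move tensors to cpu, rebuilding dicts and iterables along the way, so state dicts can be serialized independently of the training device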
+ if isinstance(value, torch.Tensor):
+ value = value.to("cpu")
+ return value
+ elif isinstance(value, dict):
+ return {k: deepcopy_cpu(v) for k, v in value.items()}
+ elif isinstance(value, Iterable):
+ return type(value)(deepcopy_cpu(v) for v in value)
+ else:
+ return value
+
+modules_valid=True
+
+train_dataset = CachedDataset()
+valid_dataset = CachedDataset()
+train_bar = None
+
+model = None
+
+app_socket = tc.connect()
+print("Training script connected\n", file=sys.stderr)
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+
+ if msg_type == 'SetDevice':
+ print("Setting device...\n", file=sys.stderr)
+ device_id=tc.decode_strings(msg_data)[0]
+ device = torch.device(device_id)
+ pin_memory = True if 'cuda' in device_id else False
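+ #pinned host memory only speeds up host-to-device transfers, so it is enabled only when a cuda device is selected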
+
+ if msg_type == 'SetTorchScriptModel' and modules_valid:
+ if msg_data:
+ print("Setting torchscript model...\n", file=sys.stderr)
+ buffer=io.BytesIO(msg_data)
+ model = torch.jit.load(buffer)
+
+ if msg_type == 'SetPackageModel' and modules_valid:
+ if msg_data:
+ print("Setting package model...\n", file=sys.stderr)
+ buffer=io.BytesIO(msg_data)
+ model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl')
+
+ if msg_type == 'SetModelState' and modules_valid:
+ if model is not None:
+ if msg_data:
+ buffer=io.BytesIO(msg_data)
+ model.load_state_dict(torch.load(buffer))
+ model.to(device)
+
+ if msg_type == 'SetLossCodes' and modules_valid:
+ print("Setting loss code...\n", file=sys.stderr)
+ loss_definitions=tc.decode_strings(msg_data)
+ criteria = []
+ for definition in loss_definitions:
+ error_msg, loss_env = safe_exec(definition, description='loss definition')
+ if error_msg is not None or 'loss' not in loss_env:
+ print("Unknown loss definition error" if error_msg is None else error_msg, file=sys.stderr)
+ modules_valid=False
+ tc.send_msg(app_socket, 'TrainingError')
+ break
+ else:
+ criteria.append(loss_env['loss'])
+
+ if msg_type == 'SetMetricCodes' and modules_valid:
+ print("Setting metrics code...\n", file=sys.stderr)
+ metric_definitions=tc.decode_strings(msg_data)
+ metrics = []
+ for definition in metric_definitions:
+ error_msg, metric_env = safe_exec(definition, description='metric definition')
+ if error_msg is not None or 'metric' not in metric_env:
+ print("Unknown metric definition error" if error_msg is None else error_msg, file=sys.stderr)
+ modules_valid=False
+ tc.send_msg(app_socket, 'TrainingError')
+ break
+ else:
+ metrics.append(metric_env['metric'])
+
+ if msg_type == 'SetOptimizerCode' and modules_valid:
+ print("Setting optimizer code...\n", file=sys.stderr)
+ error_msg, optimizer_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='optimizer definition')
+ if error_msg is not None or 'optimizer' not in optimizer_env:
+ print("Unknown optimizer definition error" if error_msg is None else error_msg, file=sys.stderr)
+ modules_valid=False
+ tc.send_msg(app_socket, 'TrainingError')
+ else:
+ optimizer = optimizer_env['optimizer']
+
+ if msg_type == 'SetOptimizerState' and modules_valid:
+ if msg_data:
+ buffer=io.BytesIO(msg_data)
+ optimizer.load_state_dict(torch.load(buffer))
+
+ if msg_type == 'SetSchedulerCode' and modules_valid:
+ print("Setting scheduler code...\n", file=sys.stderr)
+ error_msg, scheduler_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='scheduler definition')
+ if error_msg is not None or 'scheduler' not in scheduler_env:
+ print("Unknown scheduler definition error" if error_msg is None else error_msg, file=sys.stderr)
+ modules_valid=False
+ tc.send_msg(app_socket, 'TrainingError')
+ else:
+ scheduler = scheduler_env['scheduler']
+
+ if msg_type == 'SetHyperParametersValues' and modules_valid: #set other hyperparameters values
+ batch_size, shuffle, epochs, early_stop, restore_best = tc.decode_ints(msg_data)
+ shuffle=True if shuffle==1 else False
+ early_stop=True if early_stop==1 else False
+ restore_best=True if restore_best==1 else False
+
+ if msg_type == 'StartTrainingServer' and modules_valid:
+ print("Caching...\n", file=sys.stderr)
+
+ sshaddress, sshport, username, password, keydata = tc.decode_strings(msg_data)
+
+ training_server, address = tc.generate_server()
+
+ if sshaddress and sshport and username:
+ import socket
+ import paramiko
+ import torchstudio.sshtunnel as sshtunnel
+
+ if not password:
+ password=None
+ if not keydata:
+ pkey=None
+ else:
+ import io
+ keybuffer=io.StringIO(keydata)
+ pkey=paramiko.RSAKey.from_private_key(keybuffer)
+
+ sshclient = paramiko.SSHClient()
+ sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5)
+
+ reverse_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ReverseTunnel, 'localhost', 0, 'localhost', int(address[1]))
+ address[1]=str(reverse_tunnel.lport)
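+ #the reverse tunnel makes the freshly created training server reachable through the ssh connection; the tunneled port replaces the local one in the address sent to the app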
+
+ tc.send_msg(app_socket, 'TrainingServerRequestingAllSamples', tc.encode_strings(address))
+ dataset_socket=tc.start_server(training_server)
+ train_dataset.reset()
+ valid_dataset.reset()
+
+ while True:
+ dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket)
+
+ if dataset_msg_type == 'NumSamples':
+ num_samples=tc.decode_ints(dataset_msg_data)[0]
+ pbar=tqdm(total=num_samples, desc='Caching...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters
+
+ if dataset_msg_type == 'InputTensorsID' and modules_valid:
+ input_tensors_id = tc.decode_ints(dataset_msg_data)
+
+ if dataset_msg_type == 'OutputTensorsID' and modules_valid:
+ output_tensors_id = tc.decode_ints(dataset_msg_data)
+
+ if dataset_msg_type == 'TrainingSample':
+ train_dataset.add_sample(dataset_msg_data)
+ pbar.update(1)
+
+ if dataset_msg_type == 'ValidationSample':
+ valid_dataset.add_sample(dataset_msg_data)
+ pbar.update(1)
+
+ if dataset_msg_type == 'DoneSending':
+ pbar.close()
+ tc.send_msg(dataset_socket, 'DoneReceiving')
+ dataset_socket.close()
+ training_server.close()
+ if sshaddress and sshport and username:
+ sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong
+ break
+
+ train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)
+ valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size, shuffle=False, pin_memory=pin_memory)
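+ #validation batches are never shuffled; both loaders draw from the datasets cached above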
+
+ if msg_type == 'StartTraining' and modules_valid:
+ print("Training... epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr)
+
+ if msg_type == 'TrainOneEpoch' and modules_valid:
+ #training
+ model.train()
+ train_loss = 0
+ train_metrics = []
+ for metric in metrics:
+ metric.reset()
+ for batch_id, tensors in enumerate(train_loader):
+ inputs = [tensors[i].to(device) for i in input_tensors_id]
+ targets = [tensors[i].to(device) for i in output_tensors_id]
+ optimizer.zero_grad()
+ outputs = model(*inputs)
+ outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
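+ #wrap a single tensor output in a list so single-output and multi-output models share the same loss/metric loop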
+ loss = 0
+ for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
+ loss = loss + criterion(output, target)
+ loss.backward()
+ optimizer.step()
+ train_loss += loss.item() * inputs[0].size(0)
+
+ with torch.set_grad_enabled(False):
+ for output, target, metric in zip(outputs, targets, metrics):
+ metric.update(output, target)
+
+ train_loss = train_loss/len(train_dataset)
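+ #batch losses were accumulated weighted by batch size, so dividing by the dataset length gives a per-sample average (assuming mean-reducing criteria)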
+ train_metrics = 0
+ for metric in metrics:
+ train_metrics = train_metrics+metric.compute().item()
+ train_metrics/=len(metrics)
+ scheduler.step()
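+ #the scheduler is stepped once per epoch, so scheduler.last_epoch doubles as the epoch counter used in the progress messages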
+
+ #validation
+ model.eval()
+ valid_loss = 0
+ valid_metrics = []
+ for metric in metrics:
+ metric.reset()
+ with torch.set_grad_enabled(False):
+ for batch_id, tensors in enumerate(valid_loader):
+ inputs = [tensors[i].to(device) for i in input_tensors_id]
+ targets = [tensors[i].to(device) for i in output_tensors_id]
+ outputs = model(*inputs)
+ outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
+ loss = 0
+ for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
+ loss = loss + criterion(output, target)
+ valid_loss += loss.item() * inputs[0].size(0)
+
+ for output, target, metric in zip(outputs, targets, metrics):
+ metric.update(output, target)
+
+ valid_loss = valid_loss/len(valid_dataset)
+ valid_metrics = 0
+ for metric in metrics:
+ valid_metrics = valid_metrics+metric.compute().item()
+ valid_metrics/=len(metrics)
+
+ tc.send_msg(app_socket, 'TrainingLoss', tc.encode_floats(train_loss))
+ tc.send_msg(app_socket, 'ValidationLoss', tc.encode_floats(valid_loss))
+ tc.send_msg(app_socket, 'TrainingMetric', tc.encode_floats(train_metrics))
+ tc.send_msg(app_socket, 'ValidationMetric', tc.encode_floats(valid_metrics))
+
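+ #state dicts are copied to cpu before serialization so the app can reload them regardless of the training device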
+ buffer=io.BytesIO()
+ torch.save(deepcopy_cpu(model.state_dict()), buffer)
+ tc.send_msg(app_socket, 'ModelState', buffer.getvalue())
+
+ buffer=io.BytesIO()
+ torch.save(deepcopy_cpu(optimizer.state_dict()), buffer)
+ tc.send_msg(app_socket, 'OptimizerState', buffer.getvalue())
+
+ tc.send_msg(app_socket, 'Trained')
+
+ #create train_bar only after first successful training to avoid ghost progress message after an error
+ if train_bar is not None:
+ train_bar.bar_format='{desc} epoch {n_fmt} |{rate_fmt}\n\n'
+ else:
+ train_bar = tqdm(total=epochs, desc='Training...', bar_format='{desc} epoch '+str(scheduler.last_epoch)+'\n\n', initial=scheduler.last_epoch-1)
+ train_bar.update(1)
+
+ if msg_type == 'StopTraining' and modules_valid:
+ if train_bar is not None:
+ train_bar.close()
+ train_bar=None
+ print("Training stopped at epoch "+str(scheduler.last_epoch-1), file=sys.stderr)
+
+ if msg_type == 'SetInputTensors' or msg_type == 'InferTensors':
+ input_tensors = tc.decode_torch_tensors(msg_data)
+ for i, tensor in enumerate(input_tensors):
+ input_tensors[i]=torch.unsqueeze(tensor, 0).to(device) #add batch dimension
+
+ if msg_type == 'InferTensors':
+ if model is not None:
+ with torch.set_grad_enabled(False):
+ model.eval()
+ output_tensors=model(*input_tensors)
+ output_tensors=[output.cpu() for output in output_tensors]
+ tc.send_msg(app_socket, 'InferedTensors', tc.encode_torch_tensors(output_tensors))
+
+ if msg_type == 'SaveWeights':
+ path = tc.decode_strings(msg_data)[0]
+ torch.save(deepcopy_cpu(model.state_dict()), path)
+ print("Export complete")
+
+ if msg_type == 'SaveTorchScript':
+ path, mode = tc.decode_strings(msg_data)
+ if "torch.jit" in str(type(model)):
+ torch.jit.save(model, path) #already a torchscript, save as is
+ print("Export complete")
+ else:
+ if mode=="trace":
+ error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing')
+ else:
+ error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':model}, description='model scripting')
+ if error_msg:
+ print("Error exporting:", error_msg, file=sys.stderr)
+ else:
+ torch.jit.save(torchscript_model, path)
+ print("Export complete")
+
+ if msg_type == 'SaveONNX':
+ error_msg=None
+ torchscript_model=model
+ if not "torch.jit" in str(type(model)):
+ error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing')
+ if error_msg:
+ print("Error exporting:", error_msg, file=sys.stderr)
+ else:
+ error_msg, torchscript_model = safe_exec(torch.onnx.export,{'model':torchscript_model, 'args':input_tensors, 'f':tc.decode_strings(msg_data)[0], 'opset_version':12})
+ if error_msg:
+ print("Error exporting:", error_msg, file=sys.stderr)
+ else:
+ print("Export complete")
+
+ if msg_type == 'Exit':
+ break
+
diff --git a/torchstudio/parametersplot.py b/torchstudio/parametersplot.py
index 70366ca..dd8930e 100644
--- a/torchstudio/parametersplot.py
+++ b/torchstudio/parametersplot.py
@@ -1,170 +1,170 @@
-import torchstudio.tcpcodec as tc
-import inspect
-import sys
-import os
-
-import matplotlib as mpl
-import matplotlib.pyplot as plt
-from matplotlib.ticker import MaxNLocator
-import PIL
-
-def sorted(l,reverse=False):
- floats=True
- for x in l:
- try:
- float(x)
- except:
- floats=False
- break
- l.sort(key=float if floats else None,reverse=reverse)
- return l
-
-#inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
-def plot_parameters(size, dpi,
- parameters=[], #list of parameter names
- values=[], #list of lists of string values, one list per model
- order=[]): #sorting order for each parameter (1 or -1)
- """Parameters Plot
-
- Usage:
- Click: invert parameter sorting order
- """
- #set up matplotlib renderer, style, figure and axis
- mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/
- plt.style.use('dark_background')
- plt.rcParams.update({'font.size': 7})
-
- if len(parameters)<2:
- parameters=['Name', 'Validation\nMetric']
-
-# parameters=['Name', 'feature_channels', 'depth', 'Metric Value']
-# values=[['Model 1','32','3','.95'],['Model 2','24','4','.9'],['Model 3','16','3','.98'],['Model 4','16','3']]
-
- if len(order)1 else .5 for j in range(len(param_values[i]))])
- ax.set_yticklabels(param_values[i])
- #first parameter is the model name, keep the set_ticks
- axes[0].yaxis.set_tick_params(width=1)
- #last parameter is the metric, let the colorbar do the metric
- axes[-1].yaxis.set_ticks_position('none')
- axes[-1].set_yticklabels([])
- axes[-1].spines['left'].set_visible(False)
-
- #set the colorbar for the metric
- if param_values[-1]:
- max_metric=min_metric=float(param_values[-1][0])
- for metric_value in param_values[-1]:
- min_metric=min(min_metric,float(metric_value))
- max_metric=max(max_metric,float(metric_value))
- else:
- max_metric=min_metric=0
-
- cmap = plt.get_cmap('viridis') # 'viridis' or 'rainbow'
- sc = host.scatter([0,0], [0,0], s=[0,0], c=[min_metric, max_metric], cmap=cmap)
- cbar = fig.colorbar(sc, ax=axes[-1], pad=0)
- cbar.outline.set_visible(False)
-# cbar.set_ticks([])
-
- #set horizontal axis settings
- host.set_xlim(0, len(parameters) - 1)
- host.set_xticks(range(len(parameters)))
- host.set_xticklabels(parameters)
- host.tick_params(axis='x', which='major', pad=7)
- host.spines['right'].set_visible(False)
- host.xaxis.tick_top()
-
-
-
- from matplotlib.path import Path
- import matplotlib.patches as patches
- import numpy as np
- for tokens in values:
- values_num=[]
- for i, token in enumerate(tokens):
- if i1 else .5)
- else:
- values_num.append((float(token)-min_metric)/(max_metric-min_metric) if len(param_values[i])>1 and max_metric>min_metric else .5)
-
- # create bezier curves
- # for each axis, there will be a control vertex at the point itself, one at 1/3rd towards the previous
- # and one at 1/3rd towards the next axis; the first and last axes have one fewer control vertex
- # x-coordinates of the control vertices: at each integer (for the axes) and two in between
- # y-coordinates: repeat every point three times, except the first and last, which appear only twice
- verts = list(zip([x for x in np.linspace(0, len(values_num) - 1, len(values_num) * 3 - 2, endpoint=True)],
- np.repeat(values_num, 3)[1:-1]))
- # for x,y in verts: host.plot(x, y, 'go') # to show the control points of the beziers
- codes = [Path.MOVETO] + [Path.CURVE4 for _ in range(len(verts) - 1)]
- path = Path(verts, codes)
- patch = patches.PathPatch(path, facecolor='none', lw=1, edgecolor=cmap(values_num[-1]) if len(values_num)==len(parameters) else (0.33, 0.33, 0.33), zorder=values_num[-1] if len(values_num)==len(parameters) else -1)
- host.add_patch(patch)
-
- plt.tight_layout(pad=0)
-
- canvas = plt.get_current_fig_manager().canvas
- canvas.draw()
- img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
- plt.close()
- return img
-
-
-resolution = (256,256, 96)
-
-parameters=[]
-values=[]
-order=[]
-
-
-app_socket = tc.connect()
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
-
- if msg_type == 'RequestDocumentation':
- tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_parameters.__doc__)))
-
- if msg_type == 'SetResolution':
- resolution = tc.decode_ints(msg_data)
-
- if msg_type == 'SetParameters':
- parameters=tc.decode_strings(msg_data)
-
- if msg_type == 'ClearValues':
- values = []
- if msg_type == 'AppendValues':
- values.append(tc.decode_strings(msg_data))
-
- if msg_type == 'SetOrder':
- order=tc.decode_ints(msg_data)
-
- if msg_type == 'Render':
- if resolution[0]>0 and resolution[1]>0:
- img=plot_parameters(resolution[0:2],resolution[2],parameters,values,order)
- tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
-
- if msg_type == 'Exit':
- break
+import torchstudio.tcpcodec as tc
+import inspect
+import sys
+import os
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+from matplotlib.ticker import MaxNLocator
+import PIL
+
+def sorted(l,reverse=False):
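+ #shadows the builtin sorted: sorts l in place (numerically when every value parses as a float, lexicographically otherwise) and returns it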
+ floats=True
+ for x in l:
+ try:
+ float(x)
+ except:
+ floats=False
+ break
+ l.sort(key=float if floats else None,reverse=reverse)
+ return l
+
+#inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
+def plot_parameters(size, dpi,
+ parameters=[], #list of parameter names
+ values=[], #list of lists of string values, one list per model
+ order=[]): #sorting order for each parameter (1 or -1)
+ """Parameters Plot
+
+ Usage:
+ Click: invert parameter sorting order
+ """
+ #set up matplotlib renderer, style, figure and axis
+ mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/
+ plt.style.use('dark_background')
+ plt.rcParams.update({'font.size': 7})
+
+ if len(parameters)<2:
+ parameters=['Name', 'Validation\nMetric']
+
+# parameters=['Name', 'feature_channels', 'depth', 'Metric Value']
+# values=[['Model 1','32','3','.95'],['Model 2','24','4','.9'],['Model 3','16','3','.98'],['Model 4','16','3']]
+
+ if len(order)1 else .5 for j in range(len(param_values[i]))])
+ ax.set_yticklabels(param_values[i])
+ #first parameter is the model name, keep the set_ticks
+ axes[0].yaxis.set_tick_params(width=1)
+ #last parameter is the metric, let the colorbar do the metric
+ axes[-1].yaxis.set_ticks_position('none')
+ axes[-1].set_yticklabels([])
+ axes[-1].spines['left'].set_visible(False)
+
+ #set the colorbar for the metric
+ if param_values[-1]:
+ max_metric=min_metric=float(param_values[-1][0])
+ for metric_value in param_values[-1]:
+ min_metric=min(min_metric,float(metric_value))
+ max_metric=max(max_metric,float(metric_value))
+ else:
+ max_metric=min_metric=0
+
+ cmap = plt.get_cmap('viridis') # 'viridis' or 'rainbow'
+ sc = host.scatter([0,0], [0,0], s=[0,0], c=[min_metric, max_metric], cmap=cmap)
+ cbar = fig.colorbar(sc, ax=axes[-1], pad=0)
+ cbar.outline.set_visible(False)
+# cbar.set_ticks([])
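+ #the zero-size scatter above only provides a mappable spanning [min_metric, max_metric] so the colorbar can be drawn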
+
+ #set horizontal axis settings
+ host.set_xlim(0, len(parameters) - 1)
+ host.set_xticks(range(len(parameters)))
+ host.set_xticklabels(parameters)
+ host.tick_params(axis='x', which='major', pad=7)
+ host.spines['right'].set_visible(False)
+ host.xaxis.tick_top()
+
+
+
+ from matplotlib.path import Path
+ import matplotlib.patches as patches
+ import numpy as np
+ for tokens in values:
+ values_num=[]
+ for i, token in enumerate(tokens):
+ if i1 else .5)
+ else:
+ values_num.append((float(token)-min_metric)/(max_metric-min_metric) if len(param_values[i])>1 and max_metric>min_metric else .5)
+
+ # create bezier curves
+ # for each axis, there will be a control vertex at the point itself, one at 1/3rd towards the previous
+ # and one at 1/3rd towards the next axis; the first and last axes have one fewer control vertex
+ # x-coordinates of the control vertices: at each integer (for the axes) and two in between
+ # y-coordinates: repeat every point three times, except the first and last, which appear only twice
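+ # e.g. with 3 axes and normalized values [0.2, 0.8, 0.5]:
+ #   x = [0, 1/3, 2/3, 1, 4/3, 5/3, 2]  and  y = [0.2, 0.2, 0.8, 0.8, 0.8, 0.5, 0.5]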
+ verts = list(zip([x for x in np.linspace(0, len(values_num) - 1, len(values_num) * 3 - 2, endpoint=True)],
+ np.repeat(values_num, 3)[1:-1]))
+ # for x,y in verts: host.plot(x, y, 'go') # to show the control points of the beziers
+ codes = [Path.MOVETO] + [Path.CURVE4 for _ in range(len(verts) - 1)]
+ path = Path(verts, codes)
+ patch = patches.PathPatch(path, facecolor='none', lw=1, edgecolor=cmap(values_num[-1]) if len(values_num)==len(parameters) else (0.33, 0.33, 0.33), zorder=values_num[-1] if len(values_num)==len(parameters) else -1)
+ host.add_patch(patch)
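+ #runs with a metric value are colored and layered by that value; incomplete runs are drawn gray behind the others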
+
+ plt.tight_layout(pad=0)
+
+ canvas = plt.get_current_fig_manager().canvas
+ canvas.draw()
+ img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+ plt.close()
+ return img
+
+
+resolution = (256,256, 96)
+
+parameters=[]
+values=[]
+order=[]
+
+
+app_socket = tc.connect()
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+
+ if msg_type == 'RequestDocumentation':
+ tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_parameters.__doc__)))
+
+ if msg_type == 'SetResolution':
+ resolution = tc.decode_ints(msg_data)
+
+ if msg_type == 'SetParameters':
+ parameters=tc.decode_strings(msg_data)
+
+ if msg_type == 'ClearValues':
+ values = []
+ if msg_type == 'AppendValues':
+ values.append(tc.decode_strings(msg_data))
+
+ if msg_type == 'SetOrder':
+ order=tc.decode_ints(msg_data)
+
+ if msg_type == 'Render':
+ if resolution[0]>0 and resolution[1]>0:
+ img=plot_parameters(resolution[0:2],resolution[2],parameters,values,order)
+ tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
+
+ if msg_type == 'Exit':
+ break
diff --git a/torchstudio/pythoninstall.cmd b/torchstudio/pythoninstall.cmd
index 2267b56..3ba66ef 100644
--- a/torchstudio/pythoninstall.cmd
+++ b/torchstudio/pythoninstall.cmd
@@ -1,239 +1,239 @@
-# 2>nul&@goto :BATCH
-
-#BASH
-
-SCRIPTDIR=$(cd "$(dirname "$0")"; pwd)
-(
-cd "$SCRIPTDIR"
-cd ..
-
-pythonpath="$(pwd)/python"
-channel="pytorch"
-cuda=""
-packages=""
-uninstall=""
-while [ "$1" != "" ]; do
- if [ $1 == "--path" ]; then
- shift; pythonpath=$1
- elif [ $1 == "--channel" ]; then
- shift; channel=$1
- elif [ $1 == "--cuda" ]; then
- cuda="--cuda"
- elif [ $1 == "--package" ]; then
- shift; packages+=" --package $1"
- elif [ $1 == "--uninstall" ]; then
- uninstall="1"
- fi
- shift
-done
-
-if [ ! -z "$uninstall" ]; then
- echo "Uninstalling python environment..."
- rm -f *.sh
- rm -f *.tmp
- rm -f -r "$pythonpath"
- if [ $? != 0 ]; then
- echo "" 1>&2
- echo "Error while uninstalling. Check write permissions." 1>&2
- exit 1
- else
- echo ""
- echo "Uninstall complete."
- exit 0
- fi
-fi
-
-if [[ $OSTYPE == "linux"* ]]; then
- echo "Downloading, installing and setting up a linux python environment"
-elif [[ $OSTYPE == "darwin"* ]]; then
- echo "Downloading, installing and setting up a macOS python environment"
-else
- echo "Error: unsupported OS ($OSTYPE)" 1>&2
- exit 1
-fi
-
-if [ ! -z "$cuda" ]; then
- echo "This may take up to 16 minutes depending on your download speed, and up to 16 GB."
-else
- echo "This may take up to 5 minutes depending on your download speed, and up to 5 GB."
-fi
-
-echo ""
-if [[ $OSTYPE == "linux"* ]]; then
- file=Miniconda3-latest-Linux-x86_64.sh
- rm -f "$file.tmp"
- if [ -f "$file" ]; then
- echo "Python installer ($file) already downloaded"
- else
- echo "Downloading Python installer ($file)..."
- wget --show-progress --progress=bar:force:noscroll --no-check-certificate https://repo.anaconda.com/miniconda/$file -O "$file.tmp"
- if [ $? != 0 ]; then
- rm -f "$file.tmp"
- echo "" 1>&2
- echo "Error while downloading. Make sure port 80 is open." 1>&2
- exit 1
- else
- mv "$file.tmp" "$file"
- fi
- fi
-elif [[ $OSTYPE == "darwin"* ]]; then
- if [ "$(uname -m)" == "arm64" ]; then
- file=Miniconda3-latest-MacOSX-arm64.sh
- else
- file=Miniconda3-latest-MacOSX-x86_64.sh
- fi
- rm -f "$file.tmp"
- if [ -f "$file" ]; then
- echo "Python installer $file already downloaded"
- else
- echo "Downloading Python installer $file..."
- curl --insecure https://repo.anaconda.com/miniconda/$file -o "$file.tmp"
- if [ $? != 0 ]; then
- rm -f "$file.tmp"
- echo "" 1>&2
- echo "Error while downloading. Make sure port 80 is open." 1>&2
- exit 1
- else
- mv "$file.tmp" "$file"
- fi
- fi
-fi
-
-echo ""
-if [ -f "python.tmp" ]; then
- rm -f python.tmp
- rm -f -r "$pythonpath"
-fi
-if [ -d "$pythonpath" ]; then
- echo "Python already installed in $pythonpath"
-else
- echo "Installing Python in $pythonpath..."
- echo "" > python.tmp
- bash "$(pwd)/$file" -b -f -p "$pythonpath"
- rm -f python.tmp
- if [ $? != 0 ]; then
- rm -f -r "$pythonpath"
- echo "" 1>&2
- echo "Error while installing. Make sure you have write permissions." 1>&2
- exit 1
- fi
-fi
-
-PATH="$PATH;$pythonpath/bin"
-"$pythonpath/bin/python" -u -B -X utf8 -m torchstudio.pythoninstall --channel $channel $cuda $packages
-if [ $? != 0 ]; then
- echo "" 1>&2
- echo "Error while installing packages" 1>&2
- exit 1
-fi
-
-echo ""
-echo "Installation complete."
-)
-exit
-
-
-:BATCH
-
-@echo off
-setlocal
-cd /D "%~dp0"
-cd ..
-
-set pythonpath=%cd%\python
-set channel=pytorch
-set cuda=
-set packages=
-set uninstall=
-:args
-if "%~1" == "--path" (
- set pythonpath=%~2
- shift
-) else if "%~1" == "--channel" (
- set channel=%~2
- shift
-) else if "%~1" == "--cuda" (
- set cuda=--cuda
-) else if "%~1" == "--package" (
- set packages=%packages% --package %~2
- shift
-) else if "%~1" == "--uninstall" (
- set uninstall=1
-) else if "%~1" == "" (
- goto endargs
-)
-shift
-goto args
-:endargs
-
-if DEFINED uninstall (
- echo Uninstalling python environment...
- del *.exe 2>nul
- del *.tmp 2>nul
- rmdir /s /q "%pythonpath%" 2>nul
- if ERRORLEVEL 1 (
- echo. 1>&2
- echo Error while uninstalling. Check write permissions. 1>&2
- exit /B 1
- ) else (
- echo.
- echo Uninstall complete.
- exit /B 0
- )
-)
-
-echo Downloading, installing and setting up a windows python environment
-if DEFINED cuda (
- echo This may take up to 16 minutes depending on your download speed, and up to 16 GB.
-) else (
- echo This may take up to 5 minutes depending on your download speed, and up to 5 GB.
-)
-
-echo.
-set file=Miniconda3-latest-Windows-x86_64.exe
-del %file%.tmp 2>nul
-if EXIST "%file%" (
- echo Python installer %file% already downloaded
-) else (
- echo Downloading Python installer %file%...
- curl --insecure https://repo.anaconda.com/miniconda/%file% -o %file%.tmp
- if ERRORLEVEL 1 (
- del %file%.tmp 2>nul
- echo. 1>&2
- echo Error while downloading. Make sure port 443 is open. 1>&2
- exit /B 1
- ) else (
- ren %file%.tmp %file%
- )
-)
-
-echo.
-if EXIST "python.tmp" (
- del python.tmp 2>nul
- rmdir /s /q "%pythonpath%" 2>nul
-)
-if EXIST "%pythonpath%" (
- echo Python already installed in %pythonpath%
-) else (
- echo Installing Python in %pythonpath%...
- echo. > python.tmp
- %file% /S /D=%pythonpath%
- del python.tmp 2>nul
- if ERRORLEVEL 1 (
- rmdir /s /q "%pythonpath%" 2>nul
- echo. 1>&2
- echo Error while installing. Make sure you have write permissions. 1>&2
- exit /B 1
- )
-)
-
-set PATH=%PATH%;%pythonpath%;%pythonpath%\Library\mingw-w64\bin;%pythonpath%\Library\bin
-"%pythonpath%\python" -u -B -X utf8 -m torchstudio.pythoninstall --channel %channel% %cuda% %packages%
-if ERRORLEVEL 1 (
- echo. 1>&2
- echo Error while installing packages 1>&2
- exit /B 1
-)
-
-echo.
-echo Installation complete.
+# 2>nul&@goto :BATCH
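+# polyglot header: cmd.exe tries to run "#" (error hidden by 2>nul) and then jumps to the :BATCH section, while bash treats that line as a comment and runs the shell section below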
+
+#BASH
+
+SCRIPTDIR=$(cd "$(dirname "$0")"; pwd)
+(
+cd "$SCRIPTDIR"
+cd ..
+
+pythonpath="$(pwd)/python"
+channel="pytorch"
+cuda=""
+packages=""
+uninstall=""
+while [ "$1" != "" ]; do
+ if [ $1 == "--path" ]; then
+ shift; pythonpath=$1
+ elif [ $1 == "--channel" ]; then
+ shift; channel=$1
+ elif [ $1 == "--cuda" ]; then
+ cuda="--cuda"
+ elif [ $1 == "--package" ]; then
+ shift; packages+=" --package $1"
+ elif [ $1 == "--uninstall" ]; then
+ uninstall="1"
+ fi
+ shift
+done
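+# the shift at the bottom of the loop drops the current argument; options that take a value shift once more inside their branch so both the flag and its value are consumed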
+
+if [ ! -z "$uninstall" ]; then
+ echo "Uninstalling python environment..."
+ rm -f *.sh
+ rm -f *.tmp
+ rm -f -r "$pythonpath"
+ if [ $? != 0 ]; then
+ echo "" 1>&2
+ echo "Error while uninstalling. Check write permissions." 1>&2
+ exit 1
+ else
+ echo ""
+ echo "Uninstall complete."
+ exit 0
+ fi
+fi
+
+if [[ $OSTYPE == "linux"* ]]; then
+ echo "Downloading, installing and setting up a linux python environment"
+elif [[ $OSTYPE == "darwin"* ]]; then
+ echo "Downloading, installing and setting up a macOS python environment"
+else
+ echo "Error: unsupported OS ($OSTYPE)" 1>&2
+ exit 1
+fi
+
+if [ ! -z "$cuda" ]; then
+ echo "This may take up to 16 minutes depending on your download speed, and up to 16 GB."
+else
+ echo "This may take up to 5 minutes depending on your download speed, and up to 5 GB."
+fi
+
+echo ""
+if [[ $OSTYPE == "linux"* ]]; then
+ file=Miniconda3-latest-Linux-x86_64.sh
+ rm -f "$file.tmp"
+ if [ -f "$file" ]; then
+ echo "Python installer ($file) already downloaded"
+ else
+ echo "Downloading Python installer ($file)..."
+ wget --show-progress --progress=bar:force:noscroll --no-check-certificate https://repo.anaconda.com/miniconda/$file -O "$file.tmp"
+ if [ $? != 0 ]; then
+ rm -f "$file.tmp"
+ echo "" 1>&2
+ echo "Error while downloading. Make sure port 80 is open." 1>&2
+ exit 1
+ else
+ mv "$file.tmp" "$file"
+ fi
+ fi
+elif [[ $OSTYPE == "darwin"* ]]; then
+ if [ "$(uname -m)" == "arm64" ]; then
+ file=Miniconda3-latest-MacOSX-arm64.sh
+ else
+ file=Miniconda3-latest-MacOSX-x86_64.sh
+ fi
+ rm -f "$file.tmp"
+ if [ -f "$file" ]; then
+ echo "Python installer $file already downloaded"
+ else
+ echo "Downloading Python installer $file..."
+ curl --insecure https://repo.anaconda.com/miniconda/$file -o "$file.tmp"
+ if [ $? != 0 ]; then
+ rm -f "$file.tmp"
+ echo "" 1>&2
+ echo "Error while downloading. Make sure port 80 is open." 1>&2
+ exit 1
+ else
+ mv "$file.tmp" "$file"
+ fi
+ fi
+fi
+
+echo ""
+if [ -f "python.tmp" ]; then
+ rm -f python.tmp
+ rm -f -r "$pythonpath"
+fi
+if [ -d "$pythonpath" ]; then
+ echo "Python already installed in $pythonpath"
+else
+ echo "Installing Python in $pythonpath..."
+ echo "" > python.tmp
+ bash "$(pwd)/$file" -b -f -p "$pythonpath"
+ rm -f python.tmp
+ if [ $? != 0 ]; then
+ rm -f -r "$pythonpath"
+ echo "" 1>&2
+ echo "Error while installing. Make sure you have write permissions." 1>&2
+ exit 1
+ fi
+fi
+
+PATH="$PATH;$pythonpath/bin"
+"$pythonpath/bin/python" -u -B -X utf8 -m torchstudio.pythoninstall --channel $channel $cuda $packages
+if [ $? != 0 ]; then
+ echo "" 1>&2
+ echo "Error while installing packages" 1>&2
+ exit 1
+fi
+
+echo ""
+echo "Installation complete."
+)
+exit
+
+
+:BATCH
+
+@echo off
+setlocal
+cd /D "%~dp0"
+cd ..
+
+set pythonpath=%cd%\python
+set channel=pytorch
+set cuda=
+set packages=
+set uninstall=
+:args
+if "%~1" == "--path" (
+ set pythonpath=%~2
+ shift
+) else if "%~1" == "--channel" (
+ set channel=%~2
+ shift
+) else if "%~1" == "--cuda" (
+ set cuda=--cuda
+) else if "%~1" == "--package" (
+ set packages=%packages% --package %~2
+ shift
+) else if "%~1" == "--uninstall" (
+ set uninstall=1
+) else if "%~1" == "" (
+ goto endargs
+)
+shift
+goto args
+:endargs
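+rem options with a value are shifted twice above (once in their branch and once at the bottom of the loop), so both tokens are consumed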
+
+if DEFINED uninstall (
+ echo Uninstalling python environment...
+ del *.exe 2>nul
+ del *.tmp 2>nul
+ rmdir /s /q "%pythonpath%" 2>nul
+ if ERRORLEVEL 1 (
+ echo. 1>&2
+ echo Error while uninstalling. Check write permissions. 1>&2
+ exit /B 1
+ ) else (
+ echo.
+ echo Uninstall complete.
+ exit /B 0
+ )
+)
+
+echo Downloading, installing and setting up a windows python environment
+if DEFINED cuda (
+ echo This may take up to 16 minutes depending on your download speed, and up to 16 GB.
+) else (
+ echo This may take up to 5 minutes depending on your download speed, and up to 5 GB.
+)
+
+echo.
+set file=Miniconda3-latest-Windows-x86_64.exe
+del %file%.tmp 2>nul
+if EXIST "%file%" (
+ echo Python installer %file% already downloaded
+) else (
+ echo Downloading Python installer %file%...
+ curl --insecure https://repo.anaconda.com/miniconda/%file% -o %file%.tmp
+ if ERRORLEVEL 1 (
+ del %file%.tmp 2>nul
+ echo. 1>&2
+ echo Error while downloading. Make sure port 443 is open. 1>&2
+ exit /B 1
+ ) else (
+ ren %file%.tmp %file%
+ )
+)
+
+echo.
+if EXIST "python.tmp" (
+ del python.tmp 2>nul
+ rmdir /s /q "%pythonpath%" 2>nul
+)
+if EXIST "%pythonpath%" (
+ echo Python already installed in %pythonpath%
+) else (
+ echo Installing Python in %pythonpath%...
+ echo. > python.tmp
+ %file% /S /D=%pythonpath%
+ del python.tmp 2>nul
+ if ERRORLEVEL 1 (
+ rmdir /s /q "%pythonpath%" 2>nul
+ echo. 1>&2
+ echo Error while installing. Make sure you have write permissions. 1>&2
+ exit /B 1
+ )
+)
+
+set PATH=%PATH%;%pythonpath%;%pythonpath%\Library\mingw-w64\bin;%pythonpath%\Library\bin
+"%pythonpath%\python" -u -B -X utf8 -m torchstudio.pythoninstall --channel %channel% %cuda% %packages%
+if ERRORLEVEL 1 (
+ echo. 1>&2
+ echo Error while installing packages 1>&2
+ exit /B 1
+)
+
+echo.
+echo Installation complete.
diff --git a/torchstudio/pythonparse.py b/torchstudio/pythonparse.py
index e369276..9a64fe2 100644
--- a/torchstudio/pythonparse.py
+++ b/torchstudio/pythonparse.py
@@ -1,415 +1,414 @@
-#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490
-import os
-os.environ['KMP_DUPLICATE_LIB_OK']='True'
-
-import importlib
-import inspect, sys
-import ast
-import re
-from typing import Dict, List
-from os import listdir
-from os.path import isfile, join
-import torchstudio.tcpcodec as tc
-from torchstudio.modules import safe_exec
-
-def gather_parameters(node):
- params=[]
- for param in inspect.signature(node).parameters.values():
- #name
- if param.kind==param.VAR_POSITIONAL:
- params.append("*"+param.name)
- elif param.kind==param.VAR_KEYWORD:
- params.append("**"+param.name)
- else:
- params.append(param.name)
- #annotation
- if param.annotation == param.empty:
- params.append('')
- else:
- params.append(param.annotation.__name__ if isinstance(param.annotation, type) else repr(param.annotation))
- #default value
- if param.default == param.empty:
- params.append('')
- elif inspect.isclass(param.default) or inspect.isfunction(param.default):
- params.append(param.default.__module__+'.'+param.default.__name__)
- else:
- value=repr(param.default)
- if "","")
- params.append(value)
- return params
-
-def gather_objects(module):
- objects=[]
- for name, obj in inspect.getmembers(module):
-# print("INSPECTING ", name)
- if ((inspect.isclass(obj) and hasattr(obj, '__mro__') and ("torch.nn.modules.module.Module" in str(obj.__mro__) or "torch.utils.data.dataset.Dataset" in str(obj.__mro__))) or inspect.isfunction(obj)): #filter unwanted torch objects
- object={}
- object['type']='class' if inspect.isclass(obj) else 'function'
- object['name']=name
- if obj.__doc__ is not None:
- object['doc']=inspect.cleandoc(obj.__doc__)
-
- # autofill class members when requested
- doc=object['doc']
- newstring = ''
- start = 0
- for m in re.finditer(".. autoclass:: [a-zA-Z0-9_.]+\n :members:", doc):
- end, newstart = m.span()
- newstring += doc[start:end]
- class_name = re.findall(".. autoclass:: ([a-zA-Z0-9_.]+)\n :members:", m.group(0))
- rep = class_name[0]+":\n"
- sub_error_message, submodule = safe_exec(importlib.import_module,(class_name[0].rpartition('.')[0],))
- if submodule is not None:
- for member in dir(vars(submodule)[class_name[0].rpartition('.')[-1]]):
- if not member.startswith('_'):
- rep+=' '+member+'\n'
- newstring += rep
- start = newstart
- newstring += doc[start:]
- object['doc']=newstring.replace(" :noindex:","")
- else:
- object['doc']=name+(' class' if inspect.isclass(obj) else ' function')
- if hasattr(obj,'__getitem__') and obj.__getitem__.__doc__ is not None:
- itemdoc=inspect.cleandoc(obj.__getitem__.__doc__)
- if 'Returns:' in itemdoc:
- object['doc']+='\n\n'+itemdoc[itemdoc.find('Returns:'):]
- object['params']=gather_parameters(obj.__init__ if inspect.isclass(obj) else obj)
- if inspect.isclass(obj):
- object['params']=object['params'][3:] #remove self parameter
- object['code']=''
- objects.append(object)
- return objects
-
-
-def parse_parameters(node):
- #prepare defaults to be in sync with arguments
- defaults=[]
- for d in node.args.defaults:
- defaults.append(ast.get_source_segment(code,d))
- for d in range(len(node.args.args)-len(node.args.defaults)):
- defaults.insert(0,"")
- #scan through arguments
- params=[]
- for i,a in enumerate(node.args.args):
- params.append(a.arg)
- if a.annotation:
- params.append(ast.get_source_segment(code,a.annotation))
- else:
- params.append("")
- params.append(defaults[i])
- #add *args, if applicable
- if node.args.vararg:
- params.append("*"+node.args.vararg.arg)
- if node.args.vararg.annotation:
- params.append(ast.get_source_segment(code,node.args.vararg.annotation))
- else:
- params.append("")
- params.append("") #no default value
- #add **kwargs, if applicable
- if node.args.kwarg:
- params.append("**"+node.args.kwarg.arg)
- if node.args.kwarg.annotation:
- params.append(ast.get_source_segment(code,node.args.kwarg.annotation))
- else:
- params.append("")
- params.append("") #no default value
- return params
-
-def parse_objects(module:ast.Module):
- objects=[]
- for node in module.body:
- if isinstance(node, ast.FunctionDef):
- object={}
- object['code']=ast.get_source_segment(code,node)
- object['type']='function'
- object['name']=node.name
- object['doc']=ast.get_docstring(node) if ast.get_docstring(node) else ""
- object['params']=parse_parameters(node)
- objects.append(object)
- if isinstance(node, ast.ClassDef):
- object={}
- object['code']=ast.get_source_segment(code,node)
- object['type']='class'
- object['name']=node.name
- object['doc']=ast.get_docstring(node) if ast.get_docstring(node) else ""
- object['params']=[]
- for subnode in node.body:
- if isinstance(subnode, ast.FunctionDef) and subnode.name=="__init__":
- object['params']=parse_parameters(subnode)
- object['params']=object['params'][3:] #remove self parameter
- objects.append(object)
- return objects
-
-def filter_parent_objects(objects:List[Dict]) -> List[Dict]:
- parent_objects=[]
- for object in objects:
- unique=True
- for subobject in objects:
- name=object['name']
- if subobject['name']!=name:
- if re.search('[ =+]'+name+'[ ]*\(', subobject['code']):
- unique=False
- if unique:
- parent_objects.append(object)
- return parent_objects
-
-
-generated_class="""\
-import typing
-import pathlib
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import {0}
-from {0} import transforms
-
-class {1}({2}):
- \"\"\"{3}\"\"\"
- def __init__({4}):
- super().__init__({5})
-"""
-
-generated_function="""\
-import typing
-import pathlib
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import {0}
-from {0} import transforms
-
-def {1}({2}):
- \"\"\"{3}\"\"\"
- {4}={5}({6})
- return {4}
-"""
-
-def generate_code(path,object):
- #write inherited code for each object
- name=object['name']
- params=object['params']
- if object['type']=='class':
- return generated_class.format(
- path.split('.')[0],
- name, path+'.'+name,
- object['doc'],
- ', '.join(['self']+[params[i]+(': ' if params[i+1] else'')+params[i+1]+(' = ' if params[i+2] else'')+params[i+2] for i in range(0,len(params),3)]),
- ', '.join([params[i] for i in range(0,len(params),3)])
- )
- else:
- return generated_function.format(
- path.split('.')[0],
- name, ', '.join([params[i]+(': ' if params[i+1] else'')+params[i+1]+(' = ' if params[i+2] else'')+params[i+2] for i in range(0,len(params),3)]),
- object['doc'],
- path.split('.')[1][:-1], path+'.'+name, ', '.join([params[i] for i in range(0,len(params),3)])
- )
-
-def patch_parameters(path, name, params):
- patched_params=[]
- if 'datasets' in path:
- for state in ['train','valid']:
- for i in range(0,len(params)-1,3):
- #patch root for all modules
- if params[i]=='root' and not params[i+2]:
- params[i+2]="'"+data_path+"'"
-
- #patch download for all modules
- if params[i]=="download" and params[i+2]=="False":
- params[i+2]="True"
-
- #patch transform for vision modules
- if 'torchvision' in path and (params[i]=="transform" or params[i]=="target_transform"):
- params[i+2]="transforms.Compose([])"
-
- #patch train/val for specific modules
- if state=='valid':
- if params[i]=="train" and params[i+2]=="True":
- params[i+2]="False"
- if 'torchvision' in path:
- if (name=="Cityscapes" or name=="ImageNet") and params[i]=="split":
- params[i+2]="'val'"
- if (name=="STL10" or name=="SVHN") and params[i]=="split":
- params[i+2]="'test'"
- if (name=="CelebA") and params[i]=="split":
- params[i+2]="'valid'"
- if (name=="Places365") and params[i]=="split":
- params[i+2]="'val'"
- if (name=="VOCDetection" or name=="VOCSegmentation") and params[i]=="image_set":
- params[i+2]="'val'"
-
- patched_params.append(params.copy())
- else:
- patched_params.append(params)
- return patched_params
-
-custom_classes={}
-custom_classes['Custom Dataset']="""\
-import torch
-from torch.utils.data import Dataset
-
-class MyDataset(Dataset):
- def __init__(self):
- super().__init__()
-
- def __len__(self):
- pass
-
- def __getitem__(self, idx):
- pass
-"""
-custom_classes['Custom Renderer']="""\
-from torchstudio.modules import Renderer
-import numpy as np
-import matplotlib as mpl
-import matplotlib.pyplot as plt
-import PIL
-
-class MyRenderer(Renderer):
- def __init__(self):
- super().__init__()
-
- def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]):
- pass
-"""
-custom_classes['Custom Analyzer']="""\
-from torchstudio.modules import Analyzer
-from typing import List, Tuple
-import numpy as np
-import matplotlib as mpl
-import matplotlib.pyplot as plt
-import PIL
-
-class MyAnalyzer(Analyzer):
- def __init__(self):
- super().__init__()
-
- def start_analysis(self, num_training_samples: int, num_validation_samples: int, input_tensors_id: List[int], output_tensors_id: List[int], labels: List[str]):
- pass
-
- def analyze_sample(self, sample: List[np.array], training_sample: bool):
- pass
-
- def finish_analysis(self):
- pass
-
- def generate_report(self, size: Tuple[int, int], dpi: int):
- pass
-"""
-custom_classes['Custom Model']="""\
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class MyModel(nn.Module):
- def __init__(self):
- super().__init__()
-
- def forward(self, x):
- pass
-"""
-custom_classes['Custom Loss']="""\
-import torch.nn as nn
-
-class MyLoss(nn.modules.loss._Loss):
- def __init__(self):
- super().__init__()
-
- def forward(self, x):
- pass
-"""
-custom_classes['Custom Metric']="""\
-from torchstudio.modules import Metric
-import torch.nn.functional as F
-
-class MyMetric(Metric):
- def __init__(self):
- pass
-
- def update(self, preds, target):
- pass
-
- def compute(self):
- pass
-
- def reset(self):
- pass
-"""
-custom_classes['Custom Optimizer']="""\
-import torch.optim as optim
-
-class MyOptimizer(optim.Optimizer):
- def __init__(self, params):
- super().__init__(params)
-"""
-custom_classes['Custom Scheduler']="""\
-import torch.optim as optim
-
-class MyScheduler(optim._LRScheduler):
- def __init__(self, optimizer):
- super().__init__(optimizer)
-"""
-
-def scan_folder(path):
- path=path.replace('.','/')
- codes=[]
- for filename in listdir(path):
- if isfile(join(path, filename)):
- with open(join(path, filename), "r") as file:
- codes.append(file.read())
- return codes
-
-
-app_socket = tc.connect()
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
- objects=[]
- if msg_type == 'SetDataDir':
- data_path=tc.decode_strings(msg_data)[0]
-
- if msg_type == 'Parse': #parse code or a module path, return a list of objects (classes and functions) with their name, type, doc, parameters and code
- decoded=tc.decode_strings(msg_data)
- path=decoded[0]
- if path in custom_classes:
- decoded.append(custom_classes[path])
- if 'torchstudio' in path:
- decoded+=scan_folder(path)
- if len(decoded)>1:
- #parse code chunks
- for code in decoded[1:]:
- error_msg, module = safe_exec(ast.parse,(code,))
- if error_msg is None and module is not None:
- objects_batch=parse_objects(module)
- objects_batch=filter_parent_objects(objects_batch) #only keep parent objets
- for i in range(len(objects_batch)):
- objects_batch[i]['code']=code #set whole source code for each object, as we don't know the dependencies
- objects.extend(objects_batch)
- else:
- print("Error parsing code:", error_msg, "\n", file=sys.stderr)
- else:
- #parse module
- error_msg, module = safe_exec(importlib.import_module,(path,))
- if error_msg is None and module is not None:
- objects=gather_objects(module)
- for i, object in enumerate(objects):
- objects[i]['code']=generate_code(path,object) #generate inherited source code
- else:
- print("Error parsing module:", error_msg, "\n", file=sys.stderr)
-
- tc.send_msg(app_socket, 'ObjectsBegin', tc.encode_strings(path))
- for object in objects:
- patched_params = patch_parameters(path,object['name'],object['params'])
- for params in patched_params:
- tc.send_msg(app_socket, 'Object', tc.encode_strings([path,object['code'],object['type'],object['name'],object['doc']]+params))
- tc.send_msg(app_socket, 'ObjectsEnd', tc.encode_strings(path))
-
- if msg_type == 'RequestDefinitionName': #return default definition name
- tab=tc.decode_strings(msg_data)[0]
- if tab=='dataset':
- tc.send_msg(app_socket, 'SetDefinitionName', tc.encode_strings(['torchvision.datasets','MNIST']))
- if tab=='model':
- tc.send_msg(app_socket, 'SetDefinitionName', tc.encode_strings(['torchstudio.models','MNISTClassifier']))
-
- if msg_type == 'Exit':
- break
+#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490
+import os
+os.environ['KMP_DUPLICATE_LIB_OK']='True'
+
+import importlib
+import inspect, sys
+import ast
+import re
+from typing import Dict, List
+from os import listdir
+from os.path import isfile, join
+import torchstudio.tcpcodec as tc
+from torchstudio.modules import safe_exec
+
+def gather_parameters(node):
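+ #returns a flat list of (name, annotation, default) string triples, one triple per parameter of the given callable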
+ params=[]
+ for param in inspect.signature(node).parameters.values():
+ #name
+ if param.kind==param.VAR_POSITIONAL:
+ params.append("*"+param.name)
+ elif param.kind==param.VAR_KEYWORD:
+ params.append("**"+param.name)
+ else:
+ params.append(param.name)
+ #annotation
+ if param.annotation == param.empty:
+ params.append('')
+ else:
+ params.append(param.annotation.__name__ if isinstance(param.annotation, type) else repr(param.annotation))
+ #default value
+ if param.default == param.empty:
+ params.append('')
+ elif inspect.isclass(param.default) or inspect.isfunction(param.default):
+ params.append(param.default.__module__+'.'+param.default.__name__)
+ else:
+ value=repr(param.default)
+ if "","")
+ params.append(value)
+ return params
+
+def gather_objects(module):
+ objects=[]
+ for name, obj in inspect.getmembers(module):
+ if (inspect.isclass(obj) and hasattr(obj, '__mro__') and ("torch.nn.modules.module.Module" in str(obj.__mro__) or "torch.utils.data.dataset.Dataset" in str(obj.__mro__))) or (inspect.isfunction(obj) and "return" in obj.__annotations__ and inspect.isclass(obj.__annotations__["return"]) and "torch.nn.modules.module.Module" in str(obj.__annotations__["return"].__mro__)): #filter unwanted torch objects
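+ #keep nn.Module/Dataset subclasses and functions annotated as returning an nn.Module subclass; skip everything else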
+ object={}
+ object['type']='class' if inspect.isclass(obj) else 'function'
+ object['name']=name
+ if obj.__doc__ is not None:
+ object['doc']=inspect.cleandoc(obj.__doc__)
+
+ # autofill class members when requested
+ doc=object['doc']
+ newstring = ''
+ start = 0
+ for m in re.finditer(".. autoclass:: [a-zA-Z0-9_.]+\n :members:", doc):
+ end, newstart = m.span()
+ newstring += doc[start:end]
+ class_name = re.findall(".. autoclass:: ([a-zA-Z0-9_.]+)\n :members:", m.group(0))
+ rep = class_name[0]+":\n"
+ sub_error_message, submodule = safe_exec(importlib.import_module,(class_name[0].rpartition('.')[0],))
+ if submodule is not None:
+ for member in dir(vars(submodule)[class_name[0].rpartition('.')[-1]]):
+ if not member.startswith('_'):
+ rep+=' '+member+'\n'
+ newstring += rep
+ start = newstart
+ newstring += doc[start:]
+ object['doc']=newstring.replace(" :noindex:","")
+ else:
+ object['doc']=name+(' class' if inspect.isclass(obj) else ' function')
+ if hasattr(obj,'__getitem__') and obj.__getitem__.__doc__ is not None:
+ itemdoc=inspect.cleandoc(obj.__getitem__.__doc__)
+ if 'Returns:' in itemdoc:
+ object['doc']+='\n\n'+itemdoc[itemdoc.find('Returns:'):]
+ object['params']=gather_parameters(obj.__init__ if inspect.isclass(obj) else obj)
+ if inspect.isclass(obj):
+ object['params']=object['params'][3:] #remove self parameter
+ object['code']=''
+ objects.append(object)
+ return objects
+
+
+def parse_parameters(node):
+ #prepare defaults to be in sync with arguments
+ defaults=[]
+ for d in node.args.defaults:
+ defaults.append(ast.get_source_segment(code,d))
+ for d in range(len(node.args.args)-len(node.args.defaults)):
+ defaults.insert(0,"")
+ #scan through arguments
+ params=[]
+ for i,a in enumerate(node.args.args):
+ params.append(a.arg)
+ if a.annotation:
+ params.append(ast.get_source_segment(code,a.annotation))
+ else:
+ params.append("")
+ params.append(defaults[i])
+ #add *args, if applicable
+ if node.args.vararg:
+ params.append("*"+node.args.vararg.arg)
+ if node.args.vararg.annotation:
+ params.append(ast.get_source_segment(code,node.args.vararg.annotation))
+ else:
+ params.append("")
+ params.append("") #no default value
+ #add **kwargs, if applicable
+ if node.args.kwarg:
+ params.append("**"+node.args.kwarg.arg)
+ if node.args.kwarg.annotation:
+ params.append(ast.get_source_segment(code,node.args.kwarg.annotation))
+ else:
+ params.append("")
+ params.append("") #no default value
+ return params
+
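+#collect the top-level classes and functions defined in parsed source code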
+def parse_objects(module:ast.Module):
+ objects=[]
+ for node in module.body:
+ if isinstance(node, ast.FunctionDef):
+ object={}
+ object['code']=ast.get_source_segment(code,node)
+ object['type']='function'
+ object['name']=node.name
+ object['doc']=ast.get_docstring(node) if ast.get_docstring(node) else ""
+ object['params']=parse_parameters(node)
+ objects.append(object)
+ if isinstance(node, ast.ClassDef):
+ object={}
+ object['code']=ast.get_source_segment(code,node)
+ object['type']='class'
+ object['name']=node.name
+ object['doc']=ast.get_docstring(node) if ast.get_docstring(node) else ""
+ object['params']=[]
+ for subnode in node.body:
+ if isinstance(subnode, ast.FunctionDef) and subnode.name=="__init__":
+ object['params']=parse_parameters(subnode)
+ object['params']=object['params'][3:] #remove self parameter
+ objects.append(object)
+ return objects
+
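+#keep only objects that are not instantiated or called inside another object of the same code, so only top-level definitions are reported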
+def filter_parent_objects(objects:List[Dict]) -> List[Dict]:
+ parent_objects=[]
+ for object in objects:
+ unique=True
+ for subobject in objects:
+ name=object['name']
+ if subobject['name']!=name:
+ if re.search(r'[ =+]'+name+r'[ ]*\(', subobject['code']):
+ unique=False
+ if unique:
+ parent_objects.append(object)
+ return parent_objects
+
+
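+#templates filled by generate_code() to produce an editable subclass or wrapper function for objects gathered from an installed module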
+generated_class="""\
+import typing
+import pathlib
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import {0}
+from {0} import transforms
+
+class {1}({2}):
+ \"\"\"{3}\"\"\"
+ def __init__({4}):
+ super().__init__({5})
+"""
+
+generated_function="""\
+import typing
+import pathlib
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import {0}
+from {0} import transforms
+
+def {1}({2}):
+ \"\"\"{3}\"\"\"
+ {4}={5}({6})
+ return {4}
+"""
+
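+#fill the class or function template with the object's origin package, name, documentation and parameter list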
+def generate_code(path,object):
+ #generate inherited code for each object
+ name=object['name']
+ params=object['params']
+ if object['type']=='class':
+ return generated_class.format(
+ path.split('.')[0],
+ name, path+'.'+name,
+ object['doc'],
+ ', '.join(['self']+[params[i]+(': ' if params[i+1] else'')+params[i+1]+(' = ' if params[i+2] else'')+params[i+2] for i in range(0,len(params),3)]),
+ ', '.join([params[i] for i in range(0,len(params),3)])
+ )
+ else:
+ return generated_function.format(
+ path.split('.')[0],
+ name, ', '.join([params[i]+(': ' if params[i+1] else'')+params[i+1]+(' = ' if params[i+2] else'')+params[i+2] for i in range(0,len(params),3)]),
+ object['doc'],
+ path.split('.')[1][:-1], path+'.'+name, ', '.join([params[i] for i in range(0,len(params),3)])
+ )
+
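+#return one parameter set per dataset state (training then validation) with convenient defaults for root, download, transform and split, or a single unmodified set for other modules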
+def patch_parameters(path, name, params):
+ patched_params=[]
+ if 'datasets' in path:
+ for state in ['train','valid']:
+ for i in range(0,len(params)-1,3):
+ #patch root for all modules
+ if params[i]=='root' and not params[i+2]:
+ params[i+2]="'"+data_path+"'"
+
+ #patch download for all modules
+ if params[i]=="download" and params[i+2]=="False":
+ params[i+2]="True"
+
+ #patch transform for vision modules
+ if 'torchvision' in path and (params[i]=="transform" or params[i]=="target_transform"):
+ params[i+2]="transforms.Compose([])"
+
+ #patch train/val for specific modules
+ if state=='valid':
+ if params[i]=="train" and params[i+2]=="True":
+ params[i+2]="False"
+ if 'torchvision' in path:
+ if (name=="Cityscapes" or name=="ImageNet") and params[i]=="split":
+ params[i+2]="'val'"
+ if (name=="STL10" or name=="SVHN") and params[i]=="split":
+ params[i+2]="'test'"
+ if (name=="CelebA") and params[i]=="split":
+ params[i+2]="'valid'"
+ if (name=="Places365") and params[i]=="split":
+ params[i+2]="'val'"
+ if (name=="VOCDetection" or name=="VOCSegmentation") and params[i]=="image_set":
+ params[i+2]="'val'"
+
+ patched_params.append(params.copy())
+ else:
+ patched_params.append(params)
+ return patched_params
+
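+#skeleton code returned when the app requests one of the 'Custom ...' entries instead of an installed module or a source file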
+custom_classes={}
+custom_classes['Custom Dataset']="""\
+import torch
+from torch.utils.data import Dataset
+
+class MyDataset(Dataset):
+ def __init__(self):
+ super().__init__()
+
+ def __len__(self):
+ pass
+
+ def __getitem__(self, idx):
+ pass
+"""
+custom_classes['Custom Renderer']="""\
+from torchstudio.modules import Renderer
+import numpy as np
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import PIL
+
+class MyRenderer(Renderer):
+ def __init__(self):
+ super().__init__()
+
+ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]):
+ pass
+"""
+custom_classes['Custom Analyzer']="""\
+from torchstudio.modules import Analyzer
+from typing import List, Tuple
+import numpy as np
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import PIL
+
+class MyAnalyzer(Analyzer):
+ def __init__(self):
+ super().__init__()
+
+ def start_analysis(self, num_training_samples: int, num_validation_samples: int, input_tensors_id: List[int], output_tensors_id: List[int], labels: List[str]):
+ pass
+
+ def analyze_sample(self, sample: List[np.ndarray], training_sample: bool):
+ pass
+
+ def finish_analysis(self):
+ pass
+
+ def generate_report(self, size: Tuple[int, int], dpi: int):
+ pass
+"""
+custom_classes['Custom Model']="""\
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class MyModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ pass
+"""
+custom_classes['Custom Loss']="""\
+import torch.nn as nn
+
+class MyLoss(nn.modules.loss._Loss):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ pass
+"""
+custom_classes['Custom Metric']="""\
+from torchstudio.modules import Metric
+import torch.nn.functional as F
+
+class MyMetric(Metric):
+ def __init__(self):
+ pass
+
+ def update(self, preds, target):
+ pass
+
+ def compute(self):
+ pass
+
+ def reset(self):
+ pass
+"""
+custom_classes['Custom Optimizer']="""\
+import torch.optim as optim
+
+class MyOptimizer(optim.Optimizer):
+ def __init__(self, params, defaults={}):
+ super().__init__(params, defaults)
+"""
+custom_classes['Custom Scheduler']="""\
+import torch.optim as optim
+
+class MyScheduler(optim.lr_scheduler._LRScheduler):
+ def __init__(self, optimizer):
+ super().__init__(optimizer)
+"""
+
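+#return the source code of every file in a torchstudio subfolder (dots in the module path are treated as folder separators)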
+def scan_folder(path):
+ path=path.replace('.','/')
+ codes=[]
+ for filename in listdir(path):
+ if isfile(join(path, filename)):
+ with open(join(path, filename), "r") as file:
+ codes.append(file.read())
+ return codes
+
+
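+#message loop: the app first sends the data directory, then Parse requests containing a module path, raw source code or a custom class name; each object found is sent back as an Object message (datasets are sent twice, once with training and once with validation defaults)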
+app_socket = tc.connect()
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+ objects=[]
+ if msg_type == 'SetDataDir':
+ data_path=tc.decode_strings(msg_data)[0]
+
+ if msg_type == 'Parse': #parse source code or a module path; return a list of objects (classes and functions) with their name, doc, parameters and code
+ decoded=tc.decode_strings(msg_data)
+ path=decoded[0]
+ if path in custom_classes:
+ decoded.append(custom_classes[path])
+ if 'torchstudio' in path:
+ decoded+=scan_folder(path)
+ if len(decoded)>1:
+ #parse code chunks
+ for code in decoded[1:]:
+ error_msg, module = safe_exec(ast.parse,(code,))
+ if error_msg is None and module is not None:
+ objects_batch=parse_objects(module)
+ objects_batch=filter_parent_objects(objects_batch) #only keep parent objects
+ for i in range(len(objects_batch)):
+ objects_batch[i]['code']=code #set whole source code for each object, as we don't know the dependencies
+ objects.extend(objects_batch)
+ else:
+ print("Error parsing code:", error_msg, "\n", file=sys.stderr)
+ else:
+ #parse module
+ error_msg, module = safe_exec(importlib.import_module,(path,))
+ if error_msg is None and module is not None:
+ objects=gather_objects(module)
+ for i, object in enumerate(objects):
+ objects[i]['code']=generate_code(path,object) #generate inherited source code
+ else:
+ print("Error parsing module:", error_msg, "\n", file=sys.stderr)
+
+ tc.send_msg(app_socket, 'ObjectsBegin', tc.encode_strings(path))
+ for object in objects:
+ patched_params = patch_parameters(path,object['name'],object['params'])
+ for params in patched_params:
+ tc.send_msg(app_socket, 'Object', tc.encode_strings([path,object['code'],object['type'],object['name'],object['doc']]+params))
+ tc.send_msg(app_socket, 'ObjectsEnd', tc.encode_strings(path))
+
+ if msg_type == 'RequestDefinitionName': #return default definition name
+ tab=tc.decode_strings(msg_data)[0]
+ if tab=='dataset':
+ tc.send_msg(app_socket, 'SetDefinitionName', tc.encode_strings(['torchvision.datasets','MNIST']))
+ if tab=='model':
+ tc.send_msg(app_socket, 'SetDefinitionName', tc.encode_strings(['torchstudio.models','MNISTClassifier']))
+
+ if msg_type == 'Exit':
+ break
diff --git a/torchstudio/sshtunnel.py b/torchstudio/sshtunnel.py
index 59bc2a6..5a9c551 100644
--- a/torchstudio/sshtunnel.py
+++ b/torchstudio/sshtunnel.py
@@ -1,337 +1,340 @@
-import time
-import sys
-import os
-import io
-
-# Port forwarding from https://github.com/skyleronken/sshrat/blob/master/tunnels.py
-# improved with dynamic local port allocation feedback for reverse tunnel with a null local port
-import threading
-import socket
-import selectors
-import time
-import socketserver
-import paramiko
-import hashlib
-
-class Tunnel():
-
- def __init__(self, ssh_session, tun_type, lhost, lport, dhost, dport):
- self.tun_type = tun_type
- self.lhost = lhost
- self.lport = lport
- self.dhost = dhost
- self.dport = dport
-
- # create tunnel here
- if self.tun_type == ForwardTunnel:
- self.tunnel = ForwardTunnel(ssh_session, self.lhost, self.lport, self.dhost, self.dport)
- elif self.tun_type == ReverseTunnel:
- self.tunnel = ReverseTunnel(ssh_session, self.lhost, self.lport, self.dhost, self.dport)
- self.lport = self.tunnel.lport #in case of dynamic allocation (lport=0)
-
- def to_str(self):
- if self.tun_type == ForwardTunnel:
- return f"{self.lhost}:{self.lport} --> {self.dhost}:{self.dport}"
- else:
- return f"{self.dhost}:{self.dport} <-- {self.lhost}:{self.lport}"
-
- def stop(self):
- self.tunnel.stop()
-
-class ReverseTunnel():
-
- def __init__(self, ssh_session, lhost, lport, dhost, dport):
- self.session = ssh_session
- self.lhost = lhost
- self.lport = lport
- self.dhost = dhost
- self.dport = dport
-
- self.transport = ssh_session.get_transport()
-
- self.reverse_forward_tunnel(lhost, lport, dhost, dport, self.transport)
- self.handlers = []
-
- def stop(self):
- self.transport.cancel_port_forward(self.lhost, self.lport)
- for thr in self.handlers:
- thr.stop()
-
- def handler(self, rev_socket, origin, laddress):
- rev_handler = ReverseTunnelHandler(rev_socket, self.dhost, self.dport, self.lhost, self.lport)
- rev_handler.setDaemon(True)
- rev_handler.start()
- self.handlers.append(rev_handler)
-
- def reverse_forward_tunnel(self, lhost, lport, dhost, dport, transport):
- try:
- self.lport=transport.request_port_forward(lhost, lport, handler=self.handler)
- except Exception as e:
- raise e
-
-class ReverseTunnelHandler(threading.Thread):
-
- def __init__(self, rev_socket, dhost, dport, lhost, lport):
-
- threading.Thread.__init__(self)
-
- self.rev_socket = rev_socket
- self.dhost = dhost
- self.dport = dport
- self.lhost = lhost
- self.lport = lport
-
- self.dst_socket = socket.socket()
- try:
- self.dst_socket.connect((self.dhost, self.dport))
- except Exception as e:
- raise e
-
- self.keepalive = True
-
- def _read_from_rev(self, dst, rev):
- self._transfer_data(src_socket=rev,dest_socket=dst)
-
- def _read_from_dest(self, dst, rev):
- self._transfer_data(src_socket=dst,dest_socket=rev)
-
- def _transfer_data(self,src_socket,dest_socket):
- dest_socket.setblocking(True)
- data = src_socket.recv(1024)
-
- if len(data):
- try:
- dest_socket.send(data)
- except Exception as e:
- print(f"ssh error: {type(e).__name__}", file=sys.stderr)
-
- def stop(self):
- self.rev_socket.shutdown(2)
- self.dst_socket.shutdown(2)
- self.rev_socket.close()
- self.dst_socket.close()
- self.keepalive = False
-
- def run(self):
- selector = selectors.DefaultSelector()
-
- selector.register(fileobj=self.rev_socket,events=selectors.EVENT_READ,data=self._read_from_rev)
- selector.register(fileobj=self.dst_socket,events=selectors.EVENT_READ,data=self._read_from_dest)
-
- while self.keepalive:
- events = selector.select(5)
- if len(events) > 0:
- for key, _ in events:
- callback = key.data
- try:
- callback(dst=self.dst_socket,rev=self.rev_socket)
- except Exception as e:
- print(f"ssh error: {type(e).__name__}", file=sys.stderr)
- time.sleep(0)
-
-
-
-# credits to paramiko-tunnel
-class ForwardTunnel(socketserver.ThreadingTCPServer):
- daemon_threads = True
- allow_reuse_address = True
-
- def __init__(self, ssh_session, lhost, lport, dhost, dport):
- self.session = ssh_session
- self.lhost = lhost
- self.lport = lport
- self.dhost = dhost
- self.dport = dport
-
- super().__init__(
- server_address=(lhost, lport),
- RequestHandlerClass=ForwardTunnelHandler,
- bind_and_activate=True,
- )
-
- self.baddr, self.bport = self.server_address
- self.thread = threading.Thread(
- target=self.serve_forever,
- daemon=True,
- )
-
- self.start()
-
- def start(self):
- self.thread.start()
-
- def stop(self):
- self.shutdown()
- self.server_close()
-
- def __enter__(self):
- self.start()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.shutdown()
-
-class ForwardTunnelHandler(socketserver.BaseRequestHandler):
- sz_buf = 1024
-
- def __init__(self, request, cli_addr, server):
- self.selector = selectors.DefaultSelector()
- self.channel = None
- super().__init__(request, cli_addr, server)
-
- def _read_from_client(self, sock, mask):
- self._transfer_data(src_socket=sock, dest_socket=self.channel)
-
- def _read_from_channel(self, sock, mask):
- self._transfer_data(src_socket=sock,dest_socket=self.request)
-
- def _transfer_data(self,src_socket,dest_socket):
- src_socket.setblocking(True)
- data = src_socket.recv(self.sz_buf)
-
- if len(data):
- try:
- dest_socket.send(data)
- except BrokenPipeError:
- self.finish()
-
- def handle(self):
- peer_name = self.request.getpeername()
- try:
- self.channel = self.server.session.get_transport().open_channel(
- kind='direct-tcpip',
- dest_addr=(self.server.dhost,self.server.dport,),
- src_addr=peer_name,
- )
- except Exception as error:
- msg = f'Connection failed to {self.server.dhost}:{self.server.dport}'
- raise Exception(msg)
-
- else:
- self.selector.register(fileobj=self.channel,events=selectors.EVENT_READ,data=self._read_from_channel)
- self.selector.register(fileobj=self.request,events=selectors.EVENT_READ,data=self._read_from_client)
-
- if self.channel is None:
- self.finish()
- raise Exception(f'SSH Server rejected request to {self.server.dhost}:{self.server.dport}')
-
- while True:
- events = self.selector.select()
- for key, mask in events:
- callback = key.data
- callback(sock=key.fileobj,mask=mask)
- if self.server._BaseServer__is_shut_down.is_set():
- self.finish()
- time.sleep(0)
-
- def finish(self):
- if self.channel is not None:
- self.channel.shutdown(how=2)
- self.channel.close()
- self.request.shutdown(2)
- self.request.close()
-
-###
-
-
-
-if __name__ == "__main__":
- import argparse
- parser = argparse.ArgumentParser()
- parser.add_argument("--sshaddress", help="server address", type=str, default=None)
- parser.add_argument("--sshport", help="ssh server port", type=int, default=22)
- parser.add_argument("--username", help="ssh server username", type=str, default=None)
- parser.add_argument("--password", help="ssh server password", type=str, default=None)
- parser.add_argument("--keyfile", help="ssh server key file", type=str, default=None)
- parser.add_argument('--clean', help="clean all files", action='store_true', default=False)
- parser.add_argument("--command", help="command to execute or run python scripts", type=str, default=None)
- parser.add_argument("--script", help="python script to be launched on the server", type=str, default=None)
- parser.add_argument("--address", help="address to which the script must connect", type=str, default=None)
- parser.add_argument("--port", help="port to which the script must connect", type=int, default=None)
- args, other_args = parser.parse_known_args()
-
- sshclient = paramiko.SSHClient()
- sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-
- print("Connecting to remote server...", file=sys.stderr)
- try:
- sshclient.connect(hostname=args.sshaddress, port=args.sshport, username=args.username, password=args.password, pkey=paramiko.RSAKey.from_private_key_file(args.keyfile) if args.keyfile else None, timeout=5)
- except:
- print("Error: could not connect to remote server", file=sys.stderr)
- exit(1)
-
- if args.clean:
- print("Cleaning...", file=sys.stderr)
- stdin, stdout, stderr = sshclient.exec_command('rm -r -f TorchStudio')
- exit_status = stdout.channel.recv_exit_status()
- stdin, stdout, stderr = sshclient.exec_command('rmdir /s /q TorchStudio')
- exit_status = stdout.channel.recv_exit_status()
- sshclient.close()
- print("Cleaning complete")
- exit(0)
-
- #copy root scripts to the remote server if necessary
- print("Updating remote scripts...", file=sys.stderr)
- local_scripts_hash = hashlib.md5()
- for entry in os.scandir('torchstudio'):
- if entry.is_file():
- with open('torchstudio/'+entry.name, 'rb') as f:
- local_scripts_hash.update(f.read())
- local_scripts_hash = local_scripts_hash.digest()
- sftp = paramiko.SFTPClient.from_transport(sshclient.get_transport())
- remote_scripts_hash=io.BytesIO()
- try:
- sftp.getfo('TorchStudio/torchstudio/.md5', remote_scripts_hash)
- except:
- pass
- if remote_scripts_hash.getvalue()!=local_scripts_hash:
- try:
- sftp.mkdir('TorchStudio')
- except:
- pass
- try:
- sftp.mkdir('TorchStudio/torchstudio')
- except:
- pass
- try:
- for entry in os.scandir('torchstudio'):
- if entry.is_file():
- sftp.put('torchstudio/'+entry.name, 'TorchStudio/torchstudio/'+entry.name)
- if entry.name.endswith('.cmd'):
- sftp.chmod('TorchStudio/torchstudio/'+entry.name, 0o0777)
- except:
- print("Error: could not update remote scripts", file=sys.stderr)
- exit(1)
- new_scripts_hash=io.BytesIO(local_scripts_hash)
- sftp.putfo(new_scripts_hash, 'TorchStudio/torchstudio/.md5')
- sftp.close()
-
- if args.address:
- other_args=["--address", args.address]+other_args
-
- if args.port:
- #setup remote port forwarding
- print("Forwarding ports...", file=sys.stderr)
- reverse_tunnel = Tunnel(sshclient, ReverseTunnel, 'localhost', 0, args.address if args.address else 'localhost', args.port) #remote address, remote port, local address, local port
- other_args=["--port", str(reverse_tunnel.lport)]+other_args
-
- if args.command:
- if args.script:
- print("Launching remote script...", file=sys.stderr)
- stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+args.command+" -u -X utf8 -m "+' '.join([args.script]+other_args))
- else:
- print("Executing remote command...", file=sys.stderr)
- stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+' '.join([args.command]+other_args))
- while True:
- time.sleep(.1)
- if stdout.channel.recv_ready():
- sys.stdout.write(str(stdout.channel.recv(1024),'utf-8'))
- if stdout.channel.recv_stderr_ready():
- sys.stderr.write(str(stdout.channel.recv_stderr(1024),'utf-8'))
- if stdout.channel.exit_status_ready():
- break
- else:
- print("Error: no python command set. Define a command or refresh to install a python environment.", file=sys.stderr)
-
- sshclient.close()
-
+import time
+import sys
+import os
+import io
+
+# Port forwarding from https://github.com/skyleronken/sshrat/blob/master/tunnels.py
+# improved with dynamic local port allocation feedback for reverse tunnels when the requested local port is 0
+import threading
+import socket
+import selectors
+import time
+import socketserver
+import paramiko
+import hashlib
+
+class Tunnel():
+
+ def __init__(self, ssh_session, tun_type, lhost, lport, dhost, dport):
+ self.tun_type = tun_type
+ self.lhost = lhost
+ self.lport = lport
+ self.dhost = dhost
+ self.dport = dport
+
+ # create tunnel here
+ if self.tun_type == ForwardTunnel:
+ self.tunnel = ForwardTunnel(ssh_session, self.lhost, self.lport, self.dhost, self.dport)
+ elif self.tun_type == ReverseTunnel:
+ self.tunnel = ReverseTunnel(ssh_session, self.lhost, self.lport, self.dhost, self.dport)
+ self.lport = self.tunnel.lport #in case of dynamic allocation (lport=0)
+
+ def to_str(self):
+ if self.tun_type == ForwardTunnel:
+ return f"{self.lhost}:{self.lport} --> {self.dhost}:{self.dport}"
+ else:
+ return f"{self.dhost}:{self.dport} <-- {self.lhost}:{self.lport}"
+
+ def stop(self):
+ self.tunnel.stop()
+
+class ReverseTunnel():
+
+ def __init__(self, ssh_session, lhost, lport, dhost, dport):
+ self.session = ssh_session
+ self.lhost = lhost
+ self.lport = lport
+ self.dhost = dhost
+ self.dport = dport
+
+ self.transport = ssh_session.get_transport()
+
+ self.reverse_forward_tunnel(lhost, lport, dhost, dport, self.transport)
+ self.handlers = []
+
+ def stop(self):
+ self.transport.cancel_port_forward(self.lhost, self.lport)
+ for thr in self.handlers:
+ thr.stop()
+
+ def handler(self, rev_socket, origin, laddress):
+ rev_handler = ReverseTunnelHandler(rev_socket, self.dhost, self.dport, self.lhost, self.lport)
+ rev_handler.daemon = True
+ rev_handler.start()
+ self.handlers.append(rev_handler)
+
+ def reverse_forward_tunnel(self, lhost, lport, dhost, dport, transport):
+ try:
+ self.lport=transport.request_port_forward(lhost, lport, handler=self.handler)
+ except Exception as e:
+ raise e
+
+class ReverseTunnelHandler(threading.Thread):
+
+ def __init__(self, rev_socket, dhost, dport, lhost, lport):
+
+ threading.Thread.__init__(self)
+
+ self.rev_socket = rev_socket
+ self.dhost = dhost
+ self.dport = dport
+ self.lhost = lhost
+ self.lport = lport
+
+ self.dst_socket = socket.socket()
+ try:
+ self.dst_socket.connect((self.dhost, self.dport))
+ except Exception as e:
+ raise e
+
+ self.keepalive = True
+
+ def _read_from_rev(self, dst, rev):
+ self._transfer_data(src_socket=rev,dest_socket=dst)
+
+ def _read_from_dest(self, dst, rev):
+ self._transfer_data(src_socket=dst,dest_socket=rev)
+
+ def _transfer_data(self,src_socket,dest_socket):
+ dest_socket.setblocking(True)
+ data = src_socket.recv(1024)
+
+ if len(data):
+ try:
+ dest_socket.send(data)
+ except Exception as e:
+ print(f"ssh error: {type(e).__name__}", file=sys.stderr)
+
+ def stop(self):
+ self.rev_socket.shutdown(2)
+ self.dst_socket.shutdown(2)
+ self.rev_socket.close()
+ self.dst_socket.close()
+ self.keepalive = False
+
+ def run(self):
+ selector = selectors.DefaultSelector()
+
+ selector.register(fileobj=self.rev_socket,events=selectors.EVENT_READ,data=self._read_from_rev)
+ selector.register(fileobj=self.dst_socket,events=selectors.EVENT_READ,data=self._read_from_dest)
+
+ while self.keepalive:
+ events = selector.select(5)
+ if len(events) > 0:
+ for key, _ in events:
+ callback = key.data
+ try:
+ callback(dst=self.dst_socket,rev=self.rev_socket)
+ except Exception as e:
+ print(f"ssh error: {type(e).__name__}", file=sys.stderr)
+ time.sleep(0)
+
+
+
+# credits to paramiko-tunnel
+class ForwardTunnel(socketserver.ThreadingTCPServer):
+ daemon_threads = True
+ allow_reuse_address = True
+
+ def __init__(self, ssh_session, lhost, lport, dhost, dport):
+ self.session = ssh_session
+ self.lhost = lhost
+ self.lport = lport
+ self.dhost = dhost
+ self.dport = dport
+
+ super().__init__(
+ server_address=(lhost, lport),
+ RequestHandlerClass=ForwardTunnelHandler,
+ bind_and_activate=True,
+ )
+
+ self.baddr, self.bport = self.server_address
+ self.thread = threading.Thread(
+ target=self.serve_forever,
+ daemon=True,
+ )
+
+ self.start()
+
+ def start(self):
+ self.thread.start()
+
+ def stop(self):
+ self.shutdown()
+ self.server_close()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.shutdown()
+
+class ForwardTunnelHandler(socketserver.BaseRequestHandler):
+ sz_buf = 1024
+
+ def __init__(self, request, cli_addr, server):
+ self.selector = selectors.DefaultSelector()
+ self.channel = None
+ super().__init__(request, cli_addr, server)
+
+ def _read_from_client(self, sock, mask):
+ self._transfer_data(src_socket=sock, dest_socket=self.channel)
+
+ def _read_from_channel(self, sock, mask):
+ self._transfer_data(src_socket=sock,dest_socket=self.request)
+
+ def _transfer_data(self,src_socket,dest_socket):
+ src_socket.setblocking(True)
+ data = src_socket.recv(self.sz_buf)
+
+ if len(data):
+ try:
+ dest_socket.send(data)
+ except BrokenPipeError:
+ self.finish()
+
+ def handle(self):
+ peer_name = self.request.getpeername()
+ try:
+ self.channel = self.server.session.get_transport().open_channel(
+ kind='direct-tcpip',
+ dest_addr=(self.server.dhost,self.server.dport,),
+ src_addr=peer_name,
+ )
+ except Exception as error:
+ msg = f'Connection failed to {self.server.dhost}:{self.server.dport}'
+ raise Exception(msg)
+
+ else:
+ self.selector.register(fileobj=self.channel,events=selectors.EVENT_READ,data=self._read_from_channel)
+ self.selector.register(fileobj=self.request,events=selectors.EVENT_READ,data=self._read_from_client)
+
+ if self.channel is None:
+ self.finish()
+ raise Exception(f'SSH Server rejected request to {self.server.dhost}:{self.server.dport}')
+
+ while True:
+ events = self.selector.select()
+ for key, mask in events:
+ callback = key.data
+ callback(sock=key.fileobj,mask=mask)
+ if self.server._BaseServer__is_shut_down.is_set():
+ self.finish()
+ time.sleep(0)
+
+ def finish(self):
+ if self.channel is not None:
+ self.channel.shutdown(how=2)
+ self.channel.close()
+ self.request.shutdown(2)
+ self.request.close()
+
+###
+
+
+
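+#entry point: connect to the ssh server, clean or update the remote TorchStudio scripts as requested, set up a reverse tunnel if a port is given, then launch the remote command or script and relay its output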
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--sshaddress", help="server address", type=str, default=None)
+ parser.add_argument("--sshport", help="ssh server port", type=int, default=22)
+ parser.add_argument("--username", help="ssh server username", type=str, default=None)
+ parser.add_argument("--password", help="ssh server password", type=str, default=None)
+ parser.add_argument("--keyfile", help="ssh server key file", type=str, default=None)
+ parser.add_argument('--clean', help="clean all files", action='store_true', default=False)
+ parser.add_argument("--command", help="command to execute or run python scripts", type=str, default=None)
+ parser.add_argument("--script", help="python script to be launched on the server", type=str, default=None)
+ parser.add_argument("--address", help="address to which the script must connect", type=str, default=None)
+ parser.add_argument("--port", help="port to which the script must connect", type=int, default=None)
+ args, other_args = parser.parse_known_args()
+
+ sshclient = paramiko.SSHClient()
+ sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+ print("Connecting to remote server...", file=sys.stderr)
+ try:
+ sshclient.connect(hostname=args.sshaddress, port=args.sshport, username=args.username, password=args.password, pkey=paramiko.RSAKey.from_private_key_file(args.keyfile) if args.keyfile else None, timeout=5)
+ except:
+ print("Error: could not connect to remote server", file=sys.stderr)
+ exit(1)
+
+ if args.clean:
+ print("Cleaning...", file=sys.stderr)
+ stdin, stdout, stderr = sshclient.exec_command('rm -r -f TorchStudio')
+ exit_status = stdout.channel.recv_exit_status()
+ stdin, stdout, stderr = sshclient.exec_command('rmdir /s /q TorchStudio')
+ exit_status = stdout.channel.recv_exit_status()
+ sshclient.close()
+ print("Cleaning complete")
+ exit(0)
+
+ #copy root scripts to the remote server if necessary
+ print("Updating remote scripts...", file=sys.stderr)
+ local_scripts_hash = hashlib.md5()
+ for entry in os.scandir('torchstudio'):
+ if entry.is_file():
+ with open('torchstudio/'+entry.name, 'rb') as f:
+ local_scripts_hash.update(f.read())
+ local_scripts_hash = local_scripts_hash.digest()
+ sftp = paramiko.SFTPClient.from_transport(sshclient.get_transport())
+ remote_scripts_hash=io.BytesIO()
+ try:
+ sftp.getfo('TorchStudio/torchstudio/.md5', remote_scripts_hash)
+ except:
+ pass
+ if remote_scripts_hash.getvalue()!=local_scripts_hash:
+ try:
+ sftp.mkdir('TorchStudio')
+ except:
+ pass
+ try:
+ sftp.mkdir('TorchStudio/torchstudio')
+ except:
+ pass
+ try:
+ for entry in os.scandir('torchstudio'):
+ if entry.is_file():
+ sftp.put('torchstudio/'+entry.name, 'TorchStudio/torchstudio/'+entry.name)
+ if entry.name.endswith('.cmd'):
+ sftp.chmod('TorchStudio/torchstudio/'+entry.name, 0o0777)
+ except:
+ print("Error: could not update remote scripts", file=sys.stderr)
+ exit(1)
+ new_scripts_hash=io.BytesIO(local_scripts_hash)
+ sftp.putfo(new_scripts_hash, 'TorchStudio/torchstudio/.md5')
+ sftp.close()
+
+ if args.address:
+ other_args=["--address", args.address]+other_args
+
+ if args.port:
+ #setup remote port forwarding
+ print("Forwarding ports...", file=sys.stderr)
+ reverse_tunnel = Tunnel(sshclient, ReverseTunnel, 'localhost', 0, args.address if args.address else 'localhost', args.port) #remote address, remote port, local address, local port
+ other_args=["--port", str(reverse_tunnel.lport)]+other_args
+
+ if args.command:
+ if args.script:
+ print("Launching remote script...", file=sys.stderr)
+ stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+args.command+" -u -X utf8 -m "+' '.join([args.script]+other_args))
+ else:
+ print("Executing remote command...", file=sys.stderr)
+ stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+' '.join([args.command]+other_args))
+ while True:
+ time.sleep(.1)
+ if stdout.channel.recv_stderr_ready():
+ sys.stderr.write(str(stdout.channel.recv_stderr(1024).replace(b'\r\n',b'\n'),'utf-8'))
+ if stdout.channel.recv_ready():
+ sys.stdout.write(str(stdout.channel.recv(1024).replace(b'\r\n',b'\n'),'utf-8'))
+ if stdout.channel.exit_status_ready():
+ break
+ else:
+ if args.script:
+ print("Error: no python environment set.", file=sys.stderr)
+ else:
+ print("Error: no command set.", file=sys.stderr)
+
+ sshclient.close()
+
diff --git a/torchstudio/tensorrender.py b/torchstudio/tensorrender.py
index 510a9e4..15703e9 100644
--- a/torchstudio/tensorrender.py
+++ b/torchstudio/tensorrender.py
@@ -1,70 +1,70 @@
-import torchstudio.tcpcodec as tc
-from torchstudio.modules import safe_exec
-import inspect
-import sys
-import os
-
-title = ''
-tensor = None
-resolution = (256,256, 96)
-shift = (0,0,0,0)
-scale = (1,1,1,1)
-input_tensors = []
-target_tensor = None
-labels = []
-
-app_socket = tc.connect()
-while True:
- msg_type, msg_data = tc.recv_msg(app_socket)
-
- if msg_type == 'SetRendererCode':
- error_msg, renderer_env = safe_exec(tc.decode_strings(msg_data)[0],description='renderer definition')
- if error_msg is not None or 'renderer' not in renderer_env:
- print("Unknown renderer definition error" if error_msg is None else error_msg, file=sys.stderr)
- else:
- tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(renderer_env['renderer'].__doc__) if renderer_env['renderer'].__doc__ is not None else ""))
-
- if msg_type == 'Clear':
- tensor = None
- input_tensors = []
- target_tensor = None
-
- if msg_type == 'SetTitle':
- title = tc.decode_strings(msg_data)[0]
-
- if msg_type == 'TensorData':
- tensor = tc.decode_numpy_tensors(msg_data)[0]
-
- if msg_type == 'SetResolution':
- resolution = tc.decode_ints(msg_data)
-
- if msg_type == 'SetShift':
- shift = tc.decode_floats(msg_data)
- if msg_type == 'SetScale':
- scale = tc.decode_floats(msg_data)
-
- if msg_type == 'SetInputTensors':
- input_tensors = tc.decode_numpy_tensors(msg_data)
-
- if msg_type == 'SetTargetTensors':
- target_tensors = tc.decode_numpy_tensors(msg_data)
- if target_tensors:
- target_tensor=target_tensors[0]
- else:
- target_tensor=None
-
- if msg_type == 'SetLabels':
- labels = tc.decode_strings(msg_data)
-
- if msg_type == 'Render':
- if 'renderer' in renderer_env and tensor is not None and resolution[0]>0 and resolution[1]>0:
- error_msg, img = safe_exec(renderer_env['renderer'].render, (title, tensor,resolution[0:2],resolution[2],shift,scale,input_tensors,target_tensor,labels), description='renderer definition')
- if error_msg is not None:
- print(error_msg, file=sys.stderr)
- if img is None:
- tc.send_msg(app_socket, 'ImageError')
- else:
- tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
-
- if msg_type == 'Exit':
- break
+import torchstudio.tcpcodec as tc
+from torchstudio.modules import safe_exec
+import inspect
+import sys
+import os
+
+title = ''
+tensor = None
+resolution = (256,256, 96)
+shift = (0,0,0,0)
+scale = (1,1,1,1)
+input_tensors = []
+target_tensor = None
+labels = []
+
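+#message loop: accumulate the rendering state (title, tensor, resolution, shift, scale, context tensors, labels), then produce and send back an image on each Render request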
+app_socket = tc.connect()
+while True:
+ msg_type, msg_data = tc.recv_msg(app_socket)
+
+ if msg_type == 'SetRendererCode':
+ error_msg, renderer_env = safe_exec(tc.decode_strings(msg_data)[0],description='renderer definition')
+ if error_msg is not None or 'renderer' not in renderer_env:
+ print("Unknown renderer definition error" if error_msg is None else error_msg, file=sys.stderr)
+ else:
+ tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(renderer_env['renderer'].__doc__) if renderer_env['renderer'].__doc__ is not None else ""))
+
+ if msg_type == 'Clear':
+ tensor = None
+ input_tensors = []
+ target_tensor = None
+
+ if msg_type == 'SetTitle':
+ title = tc.decode_strings(msg_data)[0]
+
+ if msg_type == 'TensorData':
+ tensor = tc.decode_numpy_tensors(msg_data)[0]
+
+ if msg_type == 'SetResolution':
+ resolution = tc.decode_ints(msg_data)
+
+ if msg_type == 'SetShift':
+ shift = tc.decode_floats(msg_data)
+ if msg_type == 'SetScale':
+ scale = tc.decode_floats(msg_data)
+
+ if msg_type == 'SetInputTensors':
+ input_tensors = tc.decode_numpy_tensors(msg_data)
+
+ if msg_type == 'SetTargetTensors':
+ target_tensors = tc.decode_numpy_tensors(msg_data)
+ if target_tensors:
+ target_tensor=target_tensors[0]
+ else:
+ target_tensor=None
+
+ if msg_type == 'SetLabels':
+ labels = tc.decode_strings(msg_data)
+
+ if msg_type == 'Render':
+ if 'renderer' in renderer_env and tensor is not None and resolution[0]>0 and resolution[1]>0:
+ error_msg, img = safe_exec(renderer_env['renderer'].render, (title, tensor,resolution[0:2],resolution[2],shift,scale,input_tensors,target_tensor,labels), description='renderer definition')
+ if error_msg is not None:
+ print(error_msg, file=sys.stderr)
+ if img is None:
+ tc.send_msg(app_socket, 'ImageError')
+ else:
+ tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
+
+ if msg_type == 'Exit':
+ break