From bde996cc66d410f9777779b23011a85f1046e542 Mon Sep 17 00:00:00 2001 From: Robin Lobel Date: Wed, 14 Sep 2022 19:07:38 +0200 Subject: [PATCH] 0.9.10 --- torchstudio/datasetanalyze.py | 234 ++--- torchstudio/datasetload.py | 572 +++++----- torchstudio/datasets/randomgenerator.py | 100 +- torchstudio/graphdraw.py | 1276 +++++++++++------------ torchstudio/metricsplot.py | 370 +++---- torchstudio/modelbuild.py | 492 ++++----- torchstudio/models/unet1d.py | 286 ++--- torchstudio/models/unet2d.py | 332 +++--- torchstudio/modeltrain.py | 724 ++++++------- torchstudio/parametersplot.py | 340 +++--- torchstudio/pythoninstall.cmd | 478 ++++----- torchstudio/pythonparse.py | 829 ++++++++------- torchstudio/sshtunnel.py | 677 ++++++------ torchstudio/tensorrender.py | 140 +-- 14 files changed, 3428 insertions(+), 3422 deletions(-) diff --git a/torchstudio/datasetanalyze.py b/torchstudio/datasetanalyze.py index 929bf1a..6b9239e 100644 --- a/torchstudio/datasetanalyze.py +++ b/torchstudio/datasetanalyze.py @@ -1,117 +1,117 @@ -import sys - -import torchstudio.tcpcodec as tc -from torchstudio.modules import safe_exec -import os -import io -from collections.abc import Iterable -from tqdm.auto import tqdm -import pickle - -original_path=sys.path - -app_socket = tc.connect() -print("Analyze script connected\n", file=sys.stderr) -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - - if msg_type == 'SetAnalyzerCode': - print("Setting analyzer code...\n", file=sys.stderr) - analyzer = None - analyzer_code = tc.decode_strings(msg_data)[0] - error_msg, analyzer_env = safe_exec(analyzer_code, description='analyzer definition') - if error_msg is not None or 'analyzer' not in analyzer_env: - print("Unknown analyzer definition error" if error_msg is None else error_msg, file=sys.stderr) - - if msg_type == 'StartAnalysisServer' and 'analyzer' in analyzer_env: - print("Analyzing...\n", file=sys.stderr) - - analysis_server, address = tc.generate_server() - - if analyzer_env['analyzer'].train is None: - request_msg='AnalysisServerRequestingAllSamples' - elif analyzer_env['analyzer'].train==True: - request_msg='AnalysisServerRequestingTrainingSamples' - elif analyzer_env['analyzer'].train==False: - request_msg='AnalysisServerRequestingValidationSamples' - tc.send_msg(app_socket, request_msg, tc.encode_strings(address)) - dataset_socket=tc.start_server(analysis_server) - - while True: - dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket) - - if dataset_msg_type == 'NumSamples': - num_samples=tc.decode_ints(dataset_msg_data)[0] - pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters - - if dataset_msg_type == 'InputTensorsID': - input_tensors_id=tc.decode_ints(dataset_msg_data) - - if dataset_msg_type == 'OutputTensorsID': - output_tensors_id=tc.decode_ints(dataset_msg_data) - - if dataset_msg_type == 'Labels': - labels=tc.decode_strings(dataset_msg_data) - - if dataset_msg_type == 'StartSending': - error_msg, return_value = safe_exec(analyzer_env['analyzer'].start_analysis, (num_samples, input_tensors_id, output_tensors_id, labels), description='analyzer definition') - if error_msg is not None: - pbar.close() - print(error_msg, file=sys.stderr) - dataset_socket.close() - analysis_server.close() - break - - if dataset_msg_type == 'TrainingSample': - pbar.update(1) - error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), True), 
description='analyzer definition') - if error_msg is not None: - pbar.close() - print(error_msg, file=sys.stderr) - dataset_socket.close() - analysis_server.close() - break - - if dataset_msg_type == 'ValidationSample': - pbar.update(1) - error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), False), description='analyzer definition') - if error_msg is not None: - pbar.close() - print(error_msg, file=sys.stderr) - dataset_socket.close() - analysis_server.close() - break - - if dataset_msg_type == 'DoneSending': - pbar.close() - error_msg, return_value = safe_exec(analyzer_env['analyzer'].finish_analysis, description='analyzer definition') - tc.send_msg(dataset_socket, 'DoneReceiving') - dataset_socket.close() - analysis_server.close() - if error_msg is not None: - print(error_msg, file=sys.stderr) - else: - buffer=io.BytesIO() - pickle.dump(analyzer_env['analyzer'].state_dict(), buffer) - tc.send_msg(app_socket, 'AnalyzerState',buffer.getvalue()) - tc.send_msg(app_socket, 'AnalysisWeights',tc.encode_floats(analyzer_env['analyzer'].weights)) - print("Analysis complete") - break - - if msg_type == 'LoadAnalyzerState': - if 'analyzer' in analyzer_env: - buffer=io.BytesIO(msg_data) - analyzer_env['analyzer'].load_state_dict(pickle.load(buffer)) - print("Analyzer state loaded") - - if msg_type == 'RequestAnalysisReport': - resolution = tc.decode_ints(msg_data) - if 'analyzer' in analyzer_env: - error_msg, return_value = safe_exec(analyzer_env['analyzer'].generate_report, (resolution[0:2],resolution[2]), description='analyzer definition') - if error_msg is not None: - print(error_msg, file=sys.stderr) - if return_value is not None: - tc.send_msg(app_socket, 'ReportImage', tc.encode_image(return_value)) - - if msg_type == 'Exit': - break +import sys + +import torchstudio.tcpcodec as tc +from torchstudio.modules import safe_exec +import os +import io +from collections.abc import Iterable +from tqdm.auto import tqdm +import pickle + +original_path=sys.path + +app_socket = tc.connect() +print("Analyze script connected\n", file=sys.stderr) +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + + if msg_type == 'SetAnalyzerCode': + print("Setting analyzer code...\n", file=sys.stderr) + analyzer = None + analyzer_code = tc.decode_strings(msg_data)[0] + error_msg, analyzer_env = safe_exec(analyzer_code, description='analyzer definition') + if error_msg is not None or 'analyzer' not in analyzer_env: + print("Unknown analyzer definition error" if error_msg is None else error_msg, file=sys.stderr) + + if msg_type == 'StartAnalysisServer' and 'analyzer' in analyzer_env: + print("Analyzing...\n", file=sys.stderr) + + analysis_server, address = tc.generate_server() + + if analyzer_env['analyzer'].train is None: + request_msg='AnalysisServerRequestingAllSamples' + elif analyzer_env['analyzer'].train==True: + request_msg='AnalysisServerRequestingTrainingSamples' + elif analyzer_env['analyzer'].train==False: + request_msg='AnalysisServerRequestingValidationSamples' + tc.send_msg(app_socket, request_msg, tc.encode_strings(address)) + dataset_socket=tc.start_server(analysis_server) + + while True: + dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket) + + if dataset_msg_type == 'NumSamples': + num_samples=tc.decode_ints(dataset_msg_data)[0] + pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters + + if dataset_msg_type == 'InputTensorsID': + 
input_tensors_id=tc.decode_ints(dataset_msg_data) + + if dataset_msg_type == 'OutputTensorsID': + output_tensors_id=tc.decode_ints(dataset_msg_data) + + if dataset_msg_type == 'Labels': + labels=tc.decode_strings(dataset_msg_data) + + if dataset_msg_type == 'StartSending': + error_msg, return_value = safe_exec(analyzer_env['analyzer'].start_analysis, (num_samples, input_tensors_id, output_tensors_id, labels), description='analyzer definition') + if error_msg is not None: + pbar.close() + print(error_msg, file=sys.stderr) + dataset_socket.close() + analysis_server.close() + break + + if dataset_msg_type == 'TrainingSample': + pbar.update(1) + error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), True), description='analyzer definition') + if error_msg is not None: + pbar.close() + print(error_msg, file=sys.stderr) + dataset_socket.close() + analysis_server.close() + break + + if dataset_msg_type == 'ValidationSample': + pbar.update(1) + error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), False), description='analyzer definition') + if error_msg is not None: + pbar.close() + print(error_msg, file=sys.stderr) + dataset_socket.close() + analysis_server.close() + break + + if dataset_msg_type == 'DoneSending': + pbar.close() + error_msg, return_value = safe_exec(analyzer_env['analyzer'].finish_analysis, description='analyzer definition') + tc.send_msg(dataset_socket, 'DoneReceiving') + dataset_socket.close() + analysis_server.close() + if error_msg is not None: + print(error_msg, file=sys.stderr) + else: + buffer=io.BytesIO() + pickle.dump(analyzer_env['analyzer'].state_dict(), buffer) + tc.send_msg(app_socket, 'AnalyzerState',buffer.getvalue()) + tc.send_msg(app_socket, 'AnalysisWeights',tc.encode_floats(analyzer_env['analyzer'].weights)) + print("Analysis complete") + break + + if msg_type == 'LoadAnalyzerState': + if 'analyzer' in analyzer_env: + buffer=io.BytesIO(msg_data) + analyzer_env['analyzer'].load_state_dict(pickle.load(buffer)) + print("Analyzer state loaded") + + if msg_type == 'RequestAnalysisReport': + resolution = tc.decode_ints(msg_data) + if 'analyzer' in analyzer_env: + error_msg, return_value = safe_exec(analyzer_env['analyzer'].generate_report, (resolution[0:2],resolution[2]), description='analyzer definition') + if error_msg is not None: + print(error_msg, file=sys.stderr) + if return_value is not None: + tc.send_msg(app_socket, 'ReportImage', tc.encode_image(return_value)) + + if msg_type == 'Exit': + break diff --git a/torchstudio/datasetload.py b/torchstudio/datasetload.py index 6e0845a..9b9f9cc 100644 --- a/torchstudio/datasetload.py +++ b/torchstudio/datasetload.py @@ -1,286 +1,286 @@ -#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490 -import os -os.environ['KMP_DUPLICATE_LIB_OK']='True' - -import sys -print("Loading PyTorch...\n", file=sys.stderr) - -import torch -from torch.utils.data import Dataset -from torchvision.transforms.functional import to_tensor -import torchstudio.tcpcodec as tc -from torchstudio.modules import safe_exec -import random -import os -import io -import time -from collections.abc import Iterable -from tqdm.auto import tqdm - -#monkey patch ssl to fix ssl certificate fail when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error -import ssl -ssl._create_default_https_context = 
ssl._create_unverified_context - -meta_dataset = None -input_tensors_id = [] -output_tensors_id = [] - -class MetaDataset(Dataset): - def __init__(self, train, valid=None): - self.train_dataset=train - self.valid_dataset=valid - self.train_count=None - self.shuffle=0 - self.smp_usage=1.0 - self.training=True - self.classes=train.classes if hasattr(train,'classes') else [] - self._gen_index() - - def _gen_index(self): - self.index = [] - for i in range(len(self.train_dataset)): - self.index.append((self.train_dataset,i)) - if self.valid_dataset is not None: - for i in range(len(self.valid_dataset)): - self.index.append((self.valid_dataset,i)) - if self.train_count is None: - self.train_count=len(self.train_dataset) - if self.valid_dataset is None: - self.train_count=round(self.train_count*0.8) - - if self.shuffle>0: - #Fisher–Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle - random.seed(0) - shuffle_count=self.train_count if self.shuffle==1 else len(self.index) - for sample in range(shuffle_count): - target_sample=random.randrange(sample,shuffle_count) - self.index[sample], self.index[target_sample] = self.index[target_sample], self.index[sample] - - def set_num_train(self, num): - self.train_count=min(num,len(self.index)) - self._gen_index() - - def set_smp_usage(self, ratio): - self.smp_usage=min(max(ratio,0.0),1.0) - self._gen_index() - - def set_shuffle(self, mode): - self.shuffle=mode - self._gen_index() - - def train(self, mode=True): - self.training=mode - return self - - def valid(self): - self.training=False - return self - - def __len__(self): - if self.training==True: - return round(self.train_count*self.smp_usage) - else: - return round((len(self.index)-self.train_count)*self.smp_usage) - - def __getitem__(self, id): - if id<0 or id>=len(self): - raise IndexError - if self.training==True: - sample_ref=self.index[id] - else: - sample_ref=self.index[id+self.train_count] - sample=sample_ref[0][sample_ref[1]] - - #convert to list if needed - if isinstance(sample, Iterable): - if type(sample) is dict: - sample=list(sample.values()) - else: - sample=list(sample) - else: - sample=[sample] - - #convert each element of the list to a tensor if needed - sample_tensors=[] - for i in range(len(sample)): - if type(sample[i]) is not torch.Tensor: - if 'PIL' in str(type(sample[i])) or 'numpy' in str(type(sample[i])): - sample_tensors.append(to_tensor(sample[i])) - else: - try: - sample_tensors.append(torch.tensor(sample[i])) - except: - pass - else: - sample_tensors.append(sample[i]) - - #and finally solidify into a tuple - sample_tensors=tuple(sample_tensors) - - return sample_tensors - -original_path=sys.path -original_dir=os.getcwd() - -app_socket = tc.connect() -print("Dataset script connected\n", file=sys.stderr) -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - - if msg_type == 'SetCurrentDir': - new_dir=tc.decode_strings(msg_data)[0] - sys.path=original_path - os.chdir(original_dir) - if new_dir: - sys.path.append(new_dir) - os.chdir(new_dir) - - if msg_type == 'SetDatasetCode': - print("Loading dataset...\n", file=sys.stderr) - - meta_dataset = None - error_msg, dataset_env = safe_exec(tc.decode_strings(msg_data)[0], description='dataset definition') - if error_msg is not None or 'train' not in dataset_env: - print("Unknown dataset definition error" if error_msg is None else error_msg, file=sys.stderr) - else: - meta_dataset=MetaDataset(dataset_env['train'], dataset_env['valid'] if 'valid' in dataset_env else None) - tc.send_msg(app_socket, 
'Labels', tc.encode_strings(meta_dataset.classes)) - tc.send_msg(app_socket, 'NumSamples', tc.encode_ints([len(meta_dataset.train()),len(meta_dataset.valid())])) - sample=meta_dataset.train()[0] - - #suggest default formats - type_id=[1 for i in range(len(sample))] #inputs - if len(sample)==1: - type_id[-1]=3 #input/output - if len(sample)>1: - type_id[-1]=2 #output - tc.send_msg(app_socket, 'SetTypes', tc.encode_ints(type_id)) - - renderer_name=[] - for tensor in sample: - if len(tensor.shape)==4: - renderer_name.append("Volume") - elif len(tensor.shape)==3 and tensor.dtype==torch.complex64: - renderer_name.append("Spectrogram") - elif len(tensor.shape)==3: - renderer_name.append("Bitmap") - elif len(tensor.shape)==2: - renderer_name.append("Signal") - elif len(tensor.shape)<2: - renderer_name.append("Labels") - else: - renderer_name.append("Custom") - tc.send_msg(app_socket, 'SetRendererNames', tc.encode_strings(renderer_name)) - - if sample and len(sample[-1].shape)==0 and "int" in str(sample[-1].dtype): - analyzer_name="Multiclass" - elif sample and len(sample[-1].shape)==1: - analyzer_name="MultiLabel" - else: - analyzer_name="ValuesDistribution" - tc.send_msg(app_socket, 'SetAnalyzerName', tc.encode_strings(analyzer_name)) - - print("Loading complete") - - if msg_type == 'RequestTrainingSamples' or msg_type == 'RequestValidationSamples': - if meta_dataset is not None: - meta_dataset.train(msg_type == 'RequestTrainingSamples') - samples_id = tc.decode_ints(msg_data) - for id in samples_id: - tc.send_msg(app_socket, 'TensorData', tc.encode_torch_tensors(meta_dataset[id])) - - if msg_type == 'SetNumTrainingSamples': - if meta_dataset is not None: - meta_dataset.set_num_train(tc.decode_ints(msg_data)[0]) - - if msg_type == 'SetSampleUsage': - if meta_dataset is not None: - meta_dataset.set_smp_usage(tc.decode_floats(msg_data)[0]) - - if msg_type == 'SetShuffleMode': - if meta_dataset is not None: - meta_dataset.set_shuffle(tc.decode_ints(msg_data)[0]) - - if msg_type == 'InputTensorsID': - input_tensors_id = tc.decode_ints(msg_data) - - if msg_type == 'OutputTensorsID': - output_tensors_id = tc.decode_ints(msg_data) - - if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples': - train_set=True if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendAllSamples' else False - valid_set=True if msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples' else False - name, sshaddress, sshport, username, password, keydata, address, port = tc.decode_strings(msg_data) - port=int(port) - - print('Connecting to '+name+'...\n', file=sys.stderr) - - if sshaddress and sshport and username: - import socket - import paramiko - import torchstudio.sshtunnel as sshtunnel - - if not password: - password=None - if not keydata: - pkey=None - else: - import io - keybuffer=io.StringIO(keydata) - pkey=paramiko.RSAKey.from_private_key(keybuffer) - - sshclient = paramiko.SSHClient() - sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5) - worker_socket = socket.socket() - worker_socket.bind(('localhost', 0)) - freeport=worker_socket.getsockname()[1] - worker_socket.close() - forward_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ForwardTunnel, 'localhost', freeport, address if address else 'localhost', port) - port=freeport - - try: - worker_socket = 
tc.connect((address,port)) - num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0) - tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples)) - tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id)) - tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id)) - tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes)) - - tc.send_msg(worker_socket, 'StartSending') - with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar: - if train_set: - meta_dataset.train() - for i in range(len(meta_dataset)): - tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i])) - pbar.update(1) - if valid_set: - meta_dataset.valid() - for i in range(len(meta_dataset)): - tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i])) - pbar.update(1) - - tc.send_msg(worker_socket, 'DoneSending') - train_msg_type, train_msg_data = tc.recv_msg(worker_socket) - if train_msg_type == 'DoneReceiving': - worker_socket.close() - print('Samples transfer to '+name+' completed') - - except: - if sshaddress and sshport and username: - time.sleep(.5) #let some time for threaded ssh error messages to print first - print('Samples transfer to '+name+' interrupted', file=sys.stderr) - - if sshaddress and sshport and username: - try: - forward_tunnel.stop() - except: - pass - try: - sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong - except: - pass - - if msg_type == 'Exit': - break - +#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490 +import os +os.environ['KMP_DUPLICATE_LIB_OK']='True' + +import sys +print("Loading PyTorch...\n", file=sys.stderr) + +import torch +from torch.utils.data import Dataset +from torchvision.transforms.functional import to_tensor +import torchstudio.tcpcodec as tc +from torchstudio.modules import safe_exec +import random +import os +import io +import time +from collections.abc import Iterable +from tqdm.auto import tqdm + +#monkey patch ssl to fix ssl certificate fail when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error +import ssl +ssl._create_default_https_context = ssl._create_unverified_context + +meta_dataset = None +input_tensors_id = [] +output_tensors_id = [] + +class MetaDataset(Dataset): + def __init__(self, train, valid=None): + self.train_dataset=train + self.valid_dataset=valid + self.train_count=None + self.shuffle=0 + self.smp_usage=1.0 + self.training=True + self.classes=train.classes if hasattr(train,'classes') else [] + self._gen_index() + + def _gen_index(self): + self.index = [] + for i in range(len(self.train_dataset)): + self.index.append((self.train_dataset,i)) + if self.valid_dataset is not None: + for i in range(len(self.valid_dataset)): + self.index.append((self.valid_dataset,i)) + if self.train_count is None: + self.train_count=len(self.train_dataset) + if self.valid_dataset is None: + self.train_count=round(self.train_count*0.8) + + if self.shuffle>0: + #Fisher–Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle + random.seed(0) + shuffle_count=self.train_count if self.shuffle==1 else len(self.index) + for sample in range(shuffle_count): + 
target_sample=random.randrange(sample,shuffle_count) + self.index[sample], self.index[target_sample] = self.index[target_sample], self.index[sample] + + def set_num_train(self, num): + self.train_count=min(num,len(self.index)) + self._gen_index() + + def set_smp_usage(self, ratio): + self.smp_usage=min(max(ratio,0.0),1.0) + self._gen_index() + + def set_shuffle(self, mode): + self.shuffle=mode + self._gen_index() + + def train(self, mode=True): + self.training=mode + return self + + def valid(self): + self.training=False + return self + + def __len__(self): + if self.training==True: + return round(self.train_count*self.smp_usage) + else: + return round((len(self.index)-self.train_count)*self.smp_usage) + + def __getitem__(self, id): + if id<0 or id>=len(self): + raise IndexError + if self.training==True: + sample_ref=self.index[id] + else: + sample_ref=self.index[id+self.train_count] + sample=sample_ref[0][sample_ref[1]] + + #convert to list if needed + if isinstance(sample, Iterable): + if type(sample) is dict: + sample=list(sample.values()) + else: + sample=list(sample) + else: + sample=[sample] + + #convert each element of the list to a tensor if needed + sample_tensors=[] + for i in range(len(sample)): + if type(sample[i]) is not torch.Tensor: + if 'PIL' in str(type(sample[i])) or 'numpy' in str(type(sample[i])): + sample_tensors.append(to_tensor(sample[i])) + else: + try: + sample_tensors.append(torch.tensor(sample[i])) + except: + pass + else: + sample_tensors.append(sample[i]) + + #and finally solidify into a tuple + sample_tensors=tuple(sample_tensors) + + return sample_tensors + +original_path=sys.path +original_dir=os.getcwd() + +app_socket = tc.connect() +print("Dataset script connected\n", file=sys.stderr) +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + + if msg_type == 'SetCurrentDir': + new_dir=tc.decode_strings(msg_data)[0] + sys.path=original_path + os.chdir(original_dir) + if new_dir: + sys.path.append(new_dir) + os.chdir(new_dir) + + if msg_type == 'SetDatasetCode': + print("Loading dataset...\n", file=sys.stderr) + + meta_dataset = None + error_msg, dataset_env = safe_exec(tc.decode_strings(msg_data)[0], description='dataset definition') + if error_msg is not None or 'train' not in dataset_env: + print("Unknown dataset definition error" if error_msg is None else error_msg, file=sys.stderr) + else: + meta_dataset=MetaDataset(dataset_env['train'], dataset_env['valid'] if 'valid' in dataset_env else None) + tc.send_msg(app_socket, 'Labels', tc.encode_strings(meta_dataset.classes)) + tc.send_msg(app_socket, 'NumSamples', tc.encode_ints([len(meta_dataset.train()),len(meta_dataset.valid())])) + sample=meta_dataset.train()[0] + + #suggest default formats + type_id=[1 for i in range(len(sample))] #inputs + if len(sample)==1: + type_id[-1]=3 #input/output + if len(sample)>1: + type_id[-1]=2 #output + tc.send_msg(app_socket, 'SetTypes', tc.encode_ints(type_id)) + + renderer_name=[] + for tensor in sample: + if len(tensor.shape)==4: + renderer_name.append("Volume") + elif len(tensor.shape)==3 and tensor.dtype==torch.complex64: + renderer_name.append("Spectrogram") + elif len(tensor.shape)==3: + renderer_name.append("Bitmap") + elif len(tensor.shape)==2: + renderer_name.append("Signal") + elif len(tensor.shape)<2: + renderer_name.append("Labels") + else: + renderer_name.append("Custom") + tc.send_msg(app_socket, 'SetRendererNames', tc.encode_strings(renderer_name)) + + if sample and len(sample[-1].shape)==0 and "int" in str(sample[-1].dtype): + 
analyzer_name="Multiclass" + elif sample and len(sample[-1].shape)==1: + analyzer_name="MultiLabel" + else: + analyzer_name="ValuesDistribution" + tc.send_msg(app_socket, 'SetAnalyzerName', tc.encode_strings(analyzer_name)) + + print("Loading complete") + + if msg_type == 'RequestTrainingSamples' or msg_type == 'RequestValidationSamples': + if meta_dataset is not None: + meta_dataset.train(msg_type == 'RequestTrainingSamples') + samples_id = tc.decode_ints(msg_data) + for id in samples_id: + tc.send_msg(app_socket, 'TensorData', tc.encode_torch_tensors(meta_dataset[id])) + + if msg_type == 'SetNumTrainingSamples': + if meta_dataset is not None: + meta_dataset.set_num_train(tc.decode_ints(msg_data)[0]) + + if msg_type == 'SetSampleUsage': + if meta_dataset is not None: + meta_dataset.set_smp_usage(tc.decode_floats(msg_data)[0]) + + if msg_type == 'SetShuffleMode': + if meta_dataset is not None: + meta_dataset.set_shuffle(tc.decode_ints(msg_data)[0]) + + if msg_type == 'InputTensorsID': + input_tensors_id = tc.decode_ints(msg_data) + + if msg_type == 'OutputTensorsID': + output_tensors_id = tc.decode_ints(msg_data) + + if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples': + train_set=True if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendAllSamples' else False + valid_set=True if msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples' else False + name, sshaddress, sshport, username, password, keydata, address, port = tc.decode_strings(msg_data) + port=int(port) + + print('Connecting to '+name+'...\n', file=sys.stderr) + + if sshaddress and sshport and username: + import socket + import paramiko + import torchstudio.sshtunnel as sshtunnel + + if not password: + password=None + if not keydata: + pkey=None + else: + import io + keybuffer=io.StringIO(keydata) + pkey=paramiko.RSAKey.from_private_key(keybuffer) + + sshclient = paramiko.SSHClient() + sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5) + worker_socket = socket.socket() + worker_socket.bind(('localhost', 0)) + freeport=worker_socket.getsockname()[1] + worker_socket.close() + forward_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ForwardTunnel, 'localhost', freeport, address if address else 'localhost', port) + port=freeport + + try: + worker_socket = tc.connect((address,port)) + num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0) + tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples)) + tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id)) + tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id)) + tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes)) + + tc.send_msg(worker_socket, 'StartSending') + with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar: + if train_set: + meta_dataset.train() + for i in range(len(meta_dataset)): + tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i])) + pbar.update(1) + if valid_set: + meta_dataset.valid() + for i in range(len(meta_dataset)): + tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i])) + pbar.update(1) + + tc.send_msg(worker_socket, 
'DoneSending') + train_msg_type, train_msg_data = tc.recv_msg(worker_socket) + if train_msg_type == 'DoneReceiving': + worker_socket.close() + print('Samples transfer to '+name+' completed') + + except: + if sshaddress and sshport and username: + time.sleep(.5) #let some time for threaded ssh error messages to print first + print('Samples transfer to '+name+' interrupted', file=sys.stderr) + + if sshaddress and sshport and username: + try: + forward_tunnel.stop() + except: + pass + try: + sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong + except: + pass + + if msg_type == 'Exit': + break + diff --git a/torchstudio/datasets/randomgenerator.py b/torchstudio/datasets/randomgenerator.py index 482d72b..fbb7b9b 100644 --- a/torchstudio/datasets/randomgenerator.py +++ b/torchstudio/datasets/randomgenerator.py @@ -1,50 +1,50 @@ -import torch -from torch.utils.data import Dataset -import inspect - -class RandomGenerator(Dataset): - """A random generator that returns randomly generated tensors - - Args: - size (int): - Size of the dataset (number of samples) - tensors: - A list of tuples defining tensor properties: shape, type, range - All properties are optionals. Defaults are null, float, [0,1] - """ - - def __init__(self, size:int=256, tensors=[(3,64,64), (int,[0,9])]): - torch.manual_seed(0) - self.size = size - self.tensors = tensors - - def __len__(self): - return self.size - - def __getitem__(self, idx): - """ - Returns: - A tuple of tensors. - """ - sample = [] - for properties in self.tensors: - shape=[] - dtype=float - drange=[0,1] - for property in properties: - if type(property)==int: - shape.append(property) - elif inspect.isclass(property): - dtype=property - elif type(property) is list: - drange=property - shape=tuple(shape) - - if 'int' in str(dtype): - tensor=torch.randint(low=drange[0], high=drange[1]+1, size=shape, dtype=dtype) - else: - tensor=torch.rand(size=shape,dtype=dtype)*(drange[1]-drange[0])+drange[0] - - sample.append(tensor) - - return tuple(sample) +import torch +from torch.utils.data import Dataset +import inspect + +class RandomGenerator(Dataset): + """A random generator that returns randomly generated tensors + + Args: + size (int): + Size of the dataset (number of samples) + tensors: + A list of tuples defining tensor properties: shape, type, range + All properties are optionals. Defaults are null, float, [0,1] + """ + + def __init__(self, size:int=256, tensors=[(3,64,64), (int,[0,9])]): + torch.manual_seed(0) + self.size = size + self.tensors = tensors + + def __len__(self): + return self.size + + def __getitem__(self, idx): + """ + Returns: + A tuple of tensors. 
+ """ + sample = [] + for properties in self.tensors: + shape=[] + dtype=float + drange=[0,1] + for property in properties: + if type(property)==int: + shape.append(property) + elif inspect.isclass(property): + dtype=property + elif type(property) is list: + drange=property + shape=tuple(shape) + + if 'int' in str(dtype): + tensor=torch.randint(low=drange[0], high=drange[1]+1, size=shape, dtype=dtype) + else: + tensor=torch.rand(size=shape,dtype=dtype)*(drange[1]-drange[0])+drange[0] + + sample.append(tensor) + + return tuple(sample) diff --git a/torchstudio/graphdraw.py b/torchstudio/graphdraw.py index cc9c34e..fffc353 100644 --- a/torchstudio/graphdraw.py +++ b/torchstudio/graphdraw.py @@ -1,638 +1,638 @@ -import torchstudio.tcpcodec as tc -import os -import graphviz -import copy - -#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/torch.rst -#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/nn.rst -#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/nn.functional.rst - -creation_ops=""" -tensor -sparse_coo_tensor -as_tensor -as_strided -from_numpy -frombuffer -zeros -zeros_like -ones -ones_like -arange -range -linspace -logspace -eye -empty -empty_like -empty_strided -full -full_like -quantize_per_tensor -quantize_per_channel -dequantize -complex -polar -heaviside -""" - -manipulation_ops=""" -cat -concat -conj -chunk -dsplit -column_stack -dstack -gather -hsplit -hstack -index_select -masked_select -movedim -moveaxis -narrow -nonzero -permute -reshape -row_stack -scatter -scatter_add -split -squeeze -stack -swapaxes -swapdims -t -take -take_along_dim -tensor_split -tile -transpose -unbind -unsqueeze -vsplit -vstack -where -""" - -convolution_ops=""" -nn.Conv1d -nn.Conv2d -nn.Conv3d -nn.ConvTranspose1d -nn.ConvTranspose2d -nn.ConvTranspose3d -nn.LazyConv1d -nn.LazyConv2d -nn.LazyConv3d -nn.LazyConvTranspose1d -nn.LazyConvTranspose2d -nn.LazyConvTranspose3d -nn.Unfold -nn.Fold -conv1d -conv2d -conv3d -conv_transpose1d -conv_transpose2d -conv_transpose3d -unfold -fold -""" - -pooling_ops=""" -nn.MaxPool1d -nn.MaxPool2d -nn.MaxPool3d -nn.MaxUnpool1d -nn.MaxUnpool2d -nn.MaxUnpool3d -nn.AvgPool1d -nn.AvgPool2d -nn.AvgPool3d -nn.FractionalMaxPool2d -nn.FractionalMaxPool3d -nn.LPPool1d -nn.LPPool2d -nn.AdaptiveMaxPool1d -nn.AdaptiveMaxPool2d -nn.AdaptiveMaxPool3d -nn.AdaptiveAvgPool1d -nn.AdaptiveAvgPool2d -nn.AdaptiveAvgPool3d -avg_pool1d -avg_pool2d -avg_pool3d -max_pool1d -max_pool2d -max_pool3d -max_unpool1d -max_unpool2d -max_unpool3d -lp_pool1d -lp_pool2d -adaptive_max_pool1d -adaptive_max_pool2d -adaptive_max_pool3d -adaptive_avg_pool1d -adaptive_avg_pool2d -adaptive_avg_pool3d -fractional_max_pool2d -fractional_max_pool3d -""" - -activation_ops=""" -nn.ELU -nn.Hardshrink -nn.Hardsigmoid -nn.Hardtanh -nn.Hardswish -nn.LeakyReLU -nn.LogSigmoid -nn.MultiheadAttention -nn.PReLU -nn.ReLU -nn.ReLU6 -nn.RReLU -nn.SELU -nn.CELU -nn.GELU -nn.Sigmoid -nn.SiLU -nn.Mish -nn.Softplus -nn.Softshrink -nn.Softsign -nn.Tanh -nn.Tanhshrink -nn.Threshold -nn.GLU -nn.Softmin -nn.Softmax -nn.Softmax2d -nn.LogSoftmax -nn.AdaptiveLogSoftmaxWithLoss -threshold -threshold_ -relu -relu_ -hardtanh -hardtanh_ -hardswish -relu6 -elu -elu_ -selu -celu -leaky_relu -leaky_relu_ -prelu -rrelu -rrelu_ -glu -gelu -logsigmoid -hardshrink -tanhshrink -softsign -softplus -softmin -softmax -softshrink -gumbel_softmax -log_softmax -tanh -sigmoid -hardsigmoid -silu -mish -batch_norm -group_norm -instance_norm -layer_norm -local_response_norm -normalize 
-""" - -normalization_ops=""" -nn.BatchNorm1d -nn.BatchNorm2d -nn.BatchNorm3d -nn.LazyBatchNorm1d -nn.LazyBatchNorm2d -nn.LazyBatchNorm3d -nn.GroupNorm -nn.SyncBatchNorm -nn.InstanceNorm1d -nn.InstanceNorm2d -nn.InstanceNorm3d -nn.LazyInstanceNorm1d -nn.LazyInstanceNorm2d -nn.LazyInstanceNorm3d -nn.LayerNorm -nn.LocalResponseNorm -""" - -linear_ops=""" -nn.Identity -nn.Linear -nn.Bilinear -nn.LazyLinear -linear -bilinear -""" - -dropout_ops=""" -nn.Dropout -nn.Dropout2d -nn.Dropout3d -nn.AlphaDropout -nn.FeatureAlphaDropout -dropout -alpha_dropout -feature_alpha_dropout -dropout2d -dropout3d -""" - -vision_ops=""" -nn.PixelShuffle -nn.PixelUnshuffle -nn.Upsample -nn.UpsamplingNearest2d -nn.UpsamplingBilinear2d -pixel_shuffle -pixel_unshuffle -pad -interpolate -upsample -upsample_nearest -upsample_bilinear -grid_sample -affine_grid -""" - -math_ops=""" -abs -absolute -acos -arccos -acosh -arccosh -add -addcdiv -addcmul -angle -asin -arcsin -asinh -arcsinh -atan -arctan -atanh -arctanh -atan2 -bitwise_not -bitwise_and -bitwise_or -bitwise_xor -bitwise_left_shift -bitwise_right_shift -ceil -clamp -clip -conj_physical -copysign -cos -cosh -deg2rad -div -divide -digamma -erf -erfc -erfinv -exp -exp2 -expm1 -fake_quantize_per_channel_affine -fake_quantize_per_tensor_affine -fix -float_power -floor -floor_divide -fmod -frac -frexp -gradient -imag -ldexp -lerp -lgamma -log -log10 -log1p -log2 -logaddexp -logaddexp2 -logical_and -logical_not -logical_or -logical_xor -logit -hypot -i0 -igamma -igammac -mul -multiply -mvlgamma -nan_to_num -neg -negative -nextafter -polygamma -positive -pow -quantized_batch_norm -quantized_max_pool1d -quantized_max_pool2d -rad2deg -real -reciprocal -remainder -round -rsqrt -sigmoid -sign -sgn -signbit -sin -sinc -sinh -sqrt -square -sub -subtract -tan -tanh -true_divide -trunc -xlogy -""" - -reduction_ops=""" -argmax -argmin -amax -amin -aminmax -all -any -max -min -dist -logsumexp -mean -nanmean -median -nanmedian -mode -norm -nansum -prod -quantile -nanquantile -std -std_mean -sum -unique -unique_consecutive -var -var_mean -count_nonzero -""" - -comparison_ops=""" -allclose -argsort -eq -equal -ge -greater_equal -gt -greater -isclose -isfinite -isin -isinf -isposinf -isneginf -isnan -isreal -kthvalue -le -less_equal -lt -less -maximum -minimum -fmax -fmin -ne -not_equal -sort -topk -msort -""" - -other_ops=""" -atleast_1d -atleast_2d -atleast_3d -bincount -block_diag -broadcast_tensors -broadcast_to -broadcast_shapes -bucketize -cartesian_prod -cdist -clone -combinations -corrcoef -cov -cross -cummax -cummin -cumprod -cumsum -diag -diag_embed -diagflat -diagonal -diff -einsum -flatten -flip -fliplr -flipud -kron -rot90 -gcd -histc -histogram -meshgrid -lcm -logcumsumexp -ravel -renorm -repeat_interleave -roll -searchsorted -tensordot -trace -tril -tril_indices -triu -triu_indices -vander -view_as_real -view_as_complex -resolve_conj -resolve_neg -""" - -creation_ops=[op.split('.')[-1] for op in creation_ops.split('\n') if op] -manipulation_ops=[op.split('.')[-1] for op in manipulation_ops.split('\n') if op] -convolution_ops=[op.split('.')[-1] for op in convolution_ops.split('\n') if op] -pooling_ops=[op.split('.')[-1] for op in pooling_ops.split('\n') if op] -activation_ops=[op.split('.')[-1] for op in activation_ops.split('\n') if op] -normalization_ops=[op.split('.')[-1] for op in normalization_ops.split('\n') if op] -linear_ops=[op.split('.')[-1] for op in linear_ops.split('\n') if op] -dropout_ops=[op.split('.')[-1] for op in dropout_ops.split('\n') if op] 
-vision_ops=[op.split('.')[-1] for op in vision_ops.split('\n') if op] -math_ops=[op.split('.')[-1] for op in math_ops.split('\n') if op] -reduction_ops=[op.split('.')[-1] for op in reduction_ops.split('\n') if op] -comparison_ops=[op.split('.')[-1] for op in comparison_ops.split('\n') if op] -other_ops=[op.split('.')[-1] for op in other_ops.split('\n') if op] - -ops_color={} -ops_color.update({op : '#707070' for op in creation_ops}) -ops_color.update({op : '#803080' for op in manipulation_ops}) -ops_color.update({op : '#3080c0' for op in convolution_ops}) -ops_color.update({op : '#109010' for op in pooling_ops}) -ops_color.update({op : '#b03030' for op in activation_ops}) -ops_color.update({op : '#6080a0' for op in normalization_ops}) -ops_color.update({op : '#30b060' for op in linear_ops}) -ops_color.update({op : '#c09020' for op in dropout_ops}) -ops_color.update({op : '#509090' for op in vision_ops}) -ops_color.update({op : '#d06000' for op in math_ops}) -ops_color.update({op : '#906000' for op in reduction_ops}) -ops_color.update({op : '#90a060' for op in comparison_ops}) -ops_color.update({op : '#b03070' for op in other_ops}) - -text_color="#f0f0f0" -default_color="#908070"; -input_color="#606060" -output_color="#808080" - -link_text_color="#d0d0d0" -link_color="#a0a0a0" - -app_socket = tc.connect() -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - - if msg_type == 'SetGraph': - nodes=eval(str(msg_data,'utf-8')) - - if msg_type == 'Render': - batch, legend = tc.decode_ints(msg_data) - filtered_nodes=copy.deepcopy(nodes) - - #merge referenced getitems - for id, node in nodes.items(): - filtered_nodes[id]['input_shape']={} - for input in node['inputs']: - if nodes[input]['op_module']=='operator' and nodes[input]['op']=='getitem': - for sub_input in nodes[input]['inputs']: - filtered_nodes[id]['inputs'].remove(input) - filtered_nodes[id]['inputs'].append(sub_input) - filtered_nodes[id]['input_shape'][sub_input]=nodes[input]['output_shape'] - del filtered_nodes[input] - #del non-referenced getitems - nodes=copy.deepcopy(filtered_nodes) - for id, node in nodes.items(): - if node['op_module']=='operator' and node['op']=='getitem': - del filtered_nodes[id] - graph = graphviz.Digraph(graph_attr={'peripheries':'0', 'dpi': '0.0', 'bgcolor': 'transparent', 'ranksep': '0.25', 'margin': '0'}, - node_attr={'style': 'filled', 'shape': 'Mrecord', 'fillcolor': default_color, 'penwidth':'0', 'fontcolor': text_color,'fontsize':'20', 'fontname':'Source Code Pro'}, - edge_attr={'color': link_color, 'fontcolor': link_text_color,'fontsize':'16', 'fontname':'Source Code Pro'}) - inputs_graph = graphviz.Digraph(name='cluster_input', node_attr={'shape': 'oval', 'fillcolor': input_color, 'margin': '0'}) - outputs_graph = graphviz.Digraph(name='cluster_output', node_attr={'shape': 'oval', 'fillcolor': output_color, 'margin': '0'}) - - for id, node in filtered_nodes.items(): - if node['type']=='input': - inputs_graph.node(id, '<'+node['name']+'>', tooltip=node['name']) - elif node['type']=='output': - outputs_graph.node(id, '<'+node['name']+'>') - else: - if node['op'] in ops_color: - node_color=ops_color[node['op']] - else: - node_color=default_color - label=node['op'] - label_start, label_end=('<','>') if node['type']=='function' else ('','') - if node['op']=='': - node_tooltip=node['name'] - else: - node_tooltip=(node['name']+' = ' if node['type']=='module' else '')+node['op_module']+"."+node['op']+'('+node['params']+')' - graph.node(id, label_start+label+label_end, {'fillcolor': node_color, 
'tooltip': node_tooltip}) - - graph.subgraph(inputs_graph) - graph.subgraph(outputs_graph) - - for id, node in filtered_nodes.items(): - for input in node['inputs']: - output_shape=filtered_nodes[input]['output_shape'] - if input in node['input_shape']: - output_shape=node['input_shape'][input] - if batch==1: - output_shape=('N,' if output_shape else 'N')+output_shape - graph.edge(input,id," "+output_shape.replace(',','\u00d7')) #replace comma by multiplication sign - - if legend==1: - with graph.subgraph(name='cluster_legend', node_attr={'shape': 'box', 'margin':'0', 'fontsize':'16', 'style':''}) as legend: - table= 'Input' - table+=' Creation' - table+='Manipulation' - table+=' Convolution' - table+='Pooling' - table+=' Activation' - table+='Normalization' - table+=' Linear' - table+='Dropout' - table+=' Vision' - table+='Math' - table+=' Reduction' - table+='Comparison' - table+=' Other' - table+='Unknown' - table+=' Output' - legend.node('legend', '<'+table+'
>') - - svg=graph.pipe(format='svg') - tc.send_msg(app_socket, 'SVGData', svg) - -# with open('/Users/divide/Documents/output.txt','w') as file: -# print(graph.source, file=file) -# with open('/Users/divide/Documents/output.svg','w') as file: -# print(str(svg, 'utf-8'), file=file) -# with open('/Users/divide/Documents/output.png','wb') as file: -# file.write(graph.pipe(format='png')) - - if msg_type == 'Exit': - break - +import torchstudio.tcpcodec as tc +import os +import graphviz +import copy + +#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/torch.rst +#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/nn.rst +#from https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/nn.functional.rst + +creation_ops=""" +tensor +sparse_coo_tensor +as_tensor +as_strided +from_numpy +frombuffer +zeros +zeros_like +ones +ones_like +arange +range +linspace +logspace +eye +empty +empty_like +empty_strided +full +full_like +quantize_per_tensor +quantize_per_channel +dequantize +complex +polar +heaviside +""" + +manipulation_ops=""" +cat +concat +conj +chunk +dsplit +column_stack +dstack +gather +hsplit +hstack +index_select +masked_select +movedim +moveaxis +narrow +nonzero +permute +reshape +row_stack +scatter +scatter_add +split +squeeze +stack +swapaxes +swapdims +t +take +take_along_dim +tensor_split +tile +transpose +unbind +unsqueeze +vsplit +vstack +where +""" + +convolution_ops=""" +nn.Conv1d +nn.Conv2d +nn.Conv3d +nn.ConvTranspose1d +nn.ConvTranspose2d +nn.ConvTranspose3d +nn.LazyConv1d +nn.LazyConv2d +nn.LazyConv3d +nn.LazyConvTranspose1d +nn.LazyConvTranspose2d +nn.LazyConvTranspose3d +nn.Unfold +nn.Fold +conv1d +conv2d +conv3d +conv_transpose1d +conv_transpose2d +conv_transpose3d +unfold +fold +""" + +pooling_ops=""" +nn.MaxPool1d +nn.MaxPool2d +nn.MaxPool3d +nn.MaxUnpool1d +nn.MaxUnpool2d +nn.MaxUnpool3d +nn.AvgPool1d +nn.AvgPool2d +nn.AvgPool3d +nn.FractionalMaxPool2d +nn.FractionalMaxPool3d +nn.LPPool1d +nn.LPPool2d +nn.AdaptiveMaxPool1d +nn.AdaptiveMaxPool2d +nn.AdaptiveMaxPool3d +nn.AdaptiveAvgPool1d +nn.AdaptiveAvgPool2d +nn.AdaptiveAvgPool3d +avg_pool1d +avg_pool2d +avg_pool3d +max_pool1d +max_pool2d +max_pool3d +max_unpool1d +max_unpool2d +max_unpool3d +lp_pool1d +lp_pool2d +adaptive_max_pool1d +adaptive_max_pool2d +adaptive_max_pool3d +adaptive_avg_pool1d +adaptive_avg_pool2d +adaptive_avg_pool3d +fractional_max_pool2d +fractional_max_pool3d +""" + +activation_ops=""" +nn.ELU +nn.Hardshrink +nn.Hardsigmoid +nn.Hardtanh +nn.Hardswish +nn.LeakyReLU +nn.LogSigmoid +nn.MultiheadAttention +nn.PReLU +nn.ReLU +nn.ReLU6 +nn.RReLU +nn.SELU +nn.CELU +nn.GELU +nn.Sigmoid +nn.SiLU +nn.Mish +nn.Softplus +nn.Softshrink +nn.Softsign +nn.Tanh +nn.Tanhshrink +nn.Threshold +nn.GLU +nn.Softmin +nn.Softmax +nn.Softmax2d +nn.LogSoftmax +nn.AdaptiveLogSoftmaxWithLoss +threshold +threshold_ +relu +relu_ +hardtanh +hardtanh_ +hardswish +relu6 +elu +elu_ +selu +celu +leaky_relu +leaky_relu_ +prelu +rrelu +rrelu_ +glu +gelu +logsigmoid +hardshrink +tanhshrink +softsign +softplus +softmin +softmax +softshrink +gumbel_softmax +log_softmax +tanh +sigmoid +hardsigmoid +silu +mish +batch_norm +group_norm +instance_norm +layer_norm +local_response_norm +normalize +""" + +normalization_ops=""" +nn.BatchNorm1d +nn.BatchNorm2d +nn.BatchNorm3d +nn.LazyBatchNorm1d +nn.LazyBatchNorm2d +nn.LazyBatchNorm3d +nn.GroupNorm +nn.SyncBatchNorm +nn.InstanceNorm1d +nn.InstanceNorm2d +nn.InstanceNorm3d +nn.LazyInstanceNorm1d +nn.LazyInstanceNorm2d 
+nn.LazyInstanceNorm3d +nn.LayerNorm +nn.LocalResponseNorm +""" + +linear_ops=""" +nn.Identity +nn.Linear +nn.Bilinear +nn.LazyLinear +linear +bilinear +""" + +dropout_ops=""" +nn.Dropout +nn.Dropout2d +nn.Dropout3d +nn.AlphaDropout +nn.FeatureAlphaDropout +dropout +alpha_dropout +feature_alpha_dropout +dropout2d +dropout3d +""" + +vision_ops=""" +nn.PixelShuffle +nn.PixelUnshuffle +nn.Upsample +nn.UpsamplingNearest2d +nn.UpsamplingBilinear2d +pixel_shuffle +pixel_unshuffle +pad +interpolate +upsample +upsample_nearest +upsample_bilinear +grid_sample +affine_grid +""" + +math_ops=""" +abs +absolute +acos +arccos +acosh +arccosh +add +addcdiv +addcmul +angle +asin +arcsin +asinh +arcsinh +atan +arctan +atanh +arctanh +atan2 +bitwise_not +bitwise_and +bitwise_or +bitwise_xor +bitwise_left_shift +bitwise_right_shift +ceil +clamp +clip +conj_physical +copysign +cos +cosh +deg2rad +div +divide +digamma +erf +erfc +erfinv +exp +exp2 +expm1 +fake_quantize_per_channel_affine +fake_quantize_per_tensor_affine +fix +float_power +floor +floor_divide +fmod +frac +frexp +gradient +imag +ldexp +lerp +lgamma +log +log10 +log1p +log2 +logaddexp +logaddexp2 +logical_and +logical_not +logical_or +logical_xor +logit +hypot +i0 +igamma +igammac +mul +multiply +mvlgamma +nan_to_num +neg +negative +nextafter +polygamma +positive +pow +quantized_batch_norm +quantized_max_pool1d +quantized_max_pool2d +rad2deg +real +reciprocal +remainder +round +rsqrt +sigmoid +sign +sgn +signbit +sin +sinc +sinh +sqrt +square +sub +subtract +tan +tanh +true_divide +trunc +xlogy +""" + +reduction_ops=""" +argmax +argmin +amax +amin +aminmax +all +any +max +min +dist +logsumexp +mean +nanmean +median +nanmedian +mode +norm +nansum +prod +quantile +nanquantile +std +std_mean +sum +unique +unique_consecutive +var +var_mean +count_nonzero +""" + +comparison_ops=""" +allclose +argsort +eq +equal +ge +greater_equal +gt +greater +isclose +isfinite +isin +isinf +isposinf +isneginf +isnan +isreal +kthvalue +le +less_equal +lt +less +maximum +minimum +fmax +fmin +ne +not_equal +sort +topk +msort +""" + +other_ops=""" +atleast_1d +atleast_2d +atleast_3d +bincount +block_diag +broadcast_tensors +broadcast_to +broadcast_shapes +bucketize +cartesian_prod +cdist +clone +combinations +corrcoef +cov +cross +cummax +cummin +cumprod +cumsum +diag +diag_embed +diagflat +diagonal +diff +einsum +flatten +flip +fliplr +flipud +kron +rot90 +gcd +histc +histogram +meshgrid +lcm +logcumsumexp +ravel +renorm +repeat_interleave +roll +searchsorted +tensordot +trace +tril +tril_indices +triu +triu_indices +vander +view_as_real +view_as_complex +resolve_conj +resolve_neg +""" + +creation_ops=[op.split('.')[-1] for op in creation_ops.split('\n') if op] +manipulation_ops=[op.split('.')[-1] for op in manipulation_ops.split('\n') if op] +convolution_ops=[op.split('.')[-1] for op in convolution_ops.split('\n') if op] +pooling_ops=[op.split('.')[-1] for op in pooling_ops.split('\n') if op] +activation_ops=[op.split('.')[-1] for op in activation_ops.split('\n') if op] +normalization_ops=[op.split('.')[-1] for op in normalization_ops.split('\n') if op] +linear_ops=[op.split('.')[-1] for op in linear_ops.split('\n') if op] +dropout_ops=[op.split('.')[-1] for op in dropout_ops.split('\n') if op] +vision_ops=[op.split('.')[-1] for op in vision_ops.split('\n') if op] +math_ops=[op.split('.')[-1] for op in math_ops.split('\n') if op] +reduction_ops=[op.split('.')[-1] for op in reduction_ops.split('\n') if op] +comparison_ops=[op.split('.')[-1] for op in 
comparison_ops.split('\n') if op] +other_ops=[op.split('.')[-1] for op in other_ops.split('\n') if op] + +ops_color={} +ops_color.update({op : '#707070' for op in creation_ops}) +ops_color.update({op : '#803080' for op in manipulation_ops}) +ops_color.update({op : '#3080c0' for op in convolution_ops}) +ops_color.update({op : '#109010' for op in pooling_ops}) +ops_color.update({op : '#b03030' for op in activation_ops}) +ops_color.update({op : '#6080a0' for op in normalization_ops}) +ops_color.update({op : '#30b060' for op in linear_ops}) +ops_color.update({op : '#c09020' for op in dropout_ops}) +ops_color.update({op : '#509090' for op in vision_ops}) +ops_color.update({op : '#d06000' for op in math_ops}) +ops_color.update({op : '#906000' for op in reduction_ops}) +ops_color.update({op : '#90a060' for op in comparison_ops}) +ops_color.update({op : '#b03070' for op in other_ops}) + +text_color="#f0f0f0" +default_color="#908070"; +input_color="#606060" +output_color="#808080" + +link_text_color="#d0d0d0" +link_color="#a0a0a0" + +app_socket = tc.connect() +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + + if msg_type == 'SetGraph': + nodes=eval(str(msg_data,'utf-8')) + + if msg_type == 'Render': + batch, legend = tc.decode_ints(msg_data) + filtered_nodes=copy.deepcopy(nodes) + + #merge referenced getitems + for id, node in nodes.items(): + filtered_nodes[id]['input_shape']={} + for input in node['inputs']: + if nodes[input]['op_module']=='operator' and nodes[input]['op']=='getitem': + for sub_input in nodes[input]['inputs']: + filtered_nodes[id]['inputs'].remove(input) + filtered_nodes[id]['inputs'].append(sub_input) + filtered_nodes[id]['input_shape'][sub_input]=nodes[input]['output_shape'] + del filtered_nodes[input] + #del non-referenced getitems + nodes=copy.deepcopy(filtered_nodes) + for id, node in nodes.items(): + if node['op_module']=='operator' and node['op']=='getitem': + del filtered_nodes[id] + graph = graphviz.Digraph(graph_attr={'peripheries':'0', 'dpi': '0.0', 'bgcolor': 'transparent', 'ranksep': '0.25', 'margin': '0'}, + node_attr={'style': 'filled', 'shape': 'Mrecord', 'fillcolor': default_color, 'penwidth':'0', 'fontcolor': text_color,'fontsize':'20', 'fontname':'Source Code Pro'}, + edge_attr={'color': link_color, 'fontcolor': link_text_color,'fontsize':'16', 'fontname':'Source Code Pro'}) + inputs_graph = graphviz.Digraph(name='cluster_input', node_attr={'shape': 'oval', 'fillcolor': input_color, 'margin': '0'}) + outputs_graph = graphviz.Digraph(name='cluster_output', node_attr={'shape': 'oval', 'fillcolor': output_color, 'margin': '0'}) + + for id, node in filtered_nodes.items(): + if node['type']=='input': + inputs_graph.node(id, '<'+node['name']+'>', tooltip=node['name']) + elif node['type']=='output': + outputs_graph.node(id, '<'+node['name']+'>') + else: + if node['op'] in ops_color: + node_color=ops_color[node['op']] + else: + node_color=default_color + label=node['op'] + label_start, label_end=('<','>') if node['type']=='function' else ('','') + if node['op']=='': + node_tooltip=node['name'] + else: + node_tooltip=(node['name']+' = ' if node['type']=='module' else '')+node['op_module']+"."+node['op']+'('+node['params']+')' + graph.node(id, label_start+label+label_end, {'fillcolor': node_color, 'tooltip': node_tooltip}) + + graph.subgraph(inputs_graph) + graph.subgraph(outputs_graph) + + for id, node in filtered_nodes.items(): + for input in node['inputs']: + output_shape=filtered_nodes[input]['output_shape'] + if input in node['input_shape']: + 
output_shape=node['input_shape'][input] + if batch==1: + output_shape=('N,' if output_shape else 'N')+output_shape + graph.edge(input,id," "+output_shape.replace(',','\u00d7')) #replace comma by multiplication sign + + if legend==1: + with graph.subgraph(name='cluster_legend', node_attr={'shape': 'box', 'margin':'0', 'fontsize':'16', 'style':''}) as legend: + table= 'Input' + table+=' Creation' + table+='Manipulation' + table+=' Convolution' + table+='Pooling' + table+=' Activation' + table+='Normalization' + table+=' Linear' + table+='Dropout' + table+=' Vision' + table+='Math' + table+=' Reduction' + table+='Comparison' + table+=' Other' + table+='Unknown' + table+=' Output' + legend.node('legend', '<'+table+'
>') + + svg=graph.pipe(format='svg') + tc.send_msg(app_socket, 'SVGData', svg) + +# with open('/Users/divide/Documents/output.txt','w') as file: +# print(graph.source, file=file) +# with open('/Users/divide/Documents/output.svg','w') as file: +# print(str(svg, 'utf-8'), file=file) +# with open('/Users/divide/Documents/output.png','wb') as file: +# file.write(graph.pipe(format='png')) + + if msg_type == 'Exit': + break + diff --git a/torchstudio/metricsplot.py b/torchstudio/metricsplot.py index 0b6c665..6c29630 100644 --- a/torchstudio/metricsplot.py +++ b/torchstudio/metricsplot.py @@ -1,185 +1,185 @@ -import torchstudio.tcpcodec as tc -import inspect -import sys -import os - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.ticker import MaxNLocator -import PIL - -def plot_metrics(prefix, size, dpi, samples=100, labels=[], - loss=[], loss_colors=[], loss_shift=(0,0), loss_scale=(1,1), - metric=[], metric_colors=[], metric_shift=(0,0), metric_scale=(1,1)): - """Metrics Plot - - Usage: - Drag: pan - Scroll: zoom vertically - """ - #set up matplotlib renderer, style, figure and axis - mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/ - plt.style.use('dark_background') - plt.rcParams.update({'font.size': 7}) - - fig, [ax1, ax2] = plt.subplots(1 if size[0]>size[1] else 2, 2 if size[0]>size[1] else 1, figsize=(size[0]/dpi, size[1]/dpi), dpi=dpi) - - #LOSS - ax1.set_title(prefix+"Loss") - - #fit - loss_xmin=0 - loss_xmax=samples - loss_ymin=0 - loss_ymax=1 - for l in loss: - loss_xmax=max(loss_xmax,len(l)) -# if(len(l)>0): -# loss_ymax=max(loss_ymax,max(l)) - -# #shift -# render_size=(loss_xmax-loss_xmin,loss_ymax-loss_ymin) -# loss_xmin-=loss_shift[0]/loss_scale[0]*render_size[0] -# loss_xmax-=loss_shift[0]/loss_scale[0]*render_size[0] -# loss_ymin-=loss_shift[1]/loss_scale[1]*render_size[1] -# loss_ymax-=loss_shift[1]/loss_scale[1]*render_size[1] - -# #scale -# render_center=(loss_xmin+render_size[0]/2,loss_ymin+render_size[1]/2) -# loss_xmin=render_center[0]-(render_size[0]/loss_scale[0]/2) -# loss_xmax=render_center[0]+(render_size[0]/loss_scale[0]/2) -# loss_ymin=render_center[1]-(render_size[1]/loss_scale[1]/2) -# loss_ymax=render_center[1]+(render_size[1]/loss_scale[1]/2) - -# loss_xmin=max(0,loss_xmin) -# loss_ymin=max(0,loss_ymin) - - loss_ymin-=loss_shift[1]/loss_scale[1] - loss_ymax-=loss_shift[1]/loss_scale[1] - loss_ymax=loss_ymax/loss_scale[1] - - ax1.axis(xmin=loss_xmin,xmax=loss_xmax,ymin=loss_ymin,ymax=loss_ymax) - ax1.spines['top'].set_visible(False) - ax1.spines['right'].set_visible(False) - ax1.spines['left'].set_color('#707070') - ax1.spines['bottom'].set_color('#707070') - for i in range(len(loss)): - ax1.plot(loss[i],label=str(i) if i>=len(labels) else labels[i],color=loss_colors[i%len(loss_colors)]) - if labels and loss and loss[0]: - ax1.legend(bbox_to_anchor=(1, 1), loc='upper right', ncol=1, prop={'size': 8}) - ax1.grid(color = '#303030') - ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) - ax1.ticklabel_format(axis="y", style="sci", scilimits=(0,0)) - - #METRIC - ax2.set_title(prefix+"Metric") - - #fit - metric_xmin=0 - metric_xmax=samples - metric_ymin=0 - metric_ymax=1 - for m in metric: - metric_xmax=max(metric_xmax,len(m)) - -# #shift -# render_size=(metric_xmax-metric_xmin,metric_ymax-metric_ymin) -# metric_xmin-=metric_shift[0]/metric_scale[0]*render_size[0] -# metric_xmax-=metric_shift[0]/metric_scale[0]*render_size[0] -# metric_ymin-=metric_shift[1]/metric_scale[1]*render_size[1] -# 
metric_ymax-=metric_shift[1]/metric_scale[1]*render_size[1] - -# #scale -# render_center=(metric_xmin+render_size[0]/2,metric_ymin+render_size[1]/2) -# metric_xmin=render_center[0]-(render_size[0]/metric_scale[0]/2) -# metric_xmax=render_center[0]+(render_size[0]/metric_scale[0]/2) -# metric_ymin=render_center[1]-(render_size[1]/metric_scale[1]/2) -# metric_ymax=render_center[1]+(render_size[1]/metric_scale[1]/2) - -# metric_xmin=max(0,metric_xmin) - - metric_ymin-=metric_shift[1]/metric_scale[1] - metric_ymax-=metric_shift[1]/metric_scale[1] - metric_ymin=(metric_ymin-metric_ymax)/metric_scale[1]+metric_ymax - - ax2.axis(xmin=metric_xmin,xmax=metric_xmax,ymin=metric_ymin,ymax=metric_ymax) - ax2.spines['top'].set_visible(False) - ax2.spines['right'].set_visible(False) - ax2.spines['left'].set_color('#707070') - ax2.spines['bottom'].set_color('#707070') - for i in range(len(metric)): - ax2.plot(metric[i],color=metric_colors[i%len(metric_colors)]) - ax2.grid(color = '#303030') - ax2.xaxis.set_major_locator(MaxNLocator(integer=True)) - - plt.tight_layout(pad=0) - - canvas = plt.get_current_fig_manager().canvas - canvas.draw() - img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb()) - plt.close() - return img - - -prefix = '' -resolution = (256,256, 96) -samples=100 -labels = [] - -loss=[] -loss_colors=[] -loss_shift = (0,0) -loss_scale = (1,1) - -metric=[] -metric_colors=[] -metric_labels = [] -metric_shift = (0,0) -metric_scale = (1,1) - -app_socket = tc.connect() -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - - if msg_type == 'RequestDocumentation': - tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_metrics.__doc__))) - if msg_type == 'SetPrefix': - prefix=tc.decode_strings(msg_data)[0] - - if msg_type == 'SetResolution': - resolution = tc.decode_ints(msg_data) - - if msg_type == 'NumSamples': - samples = tc.decode_ints(msg_data)[0] - if msg_type == 'SetLabels': - labels=tc.decode_strings(msg_data) - - if msg_type == 'ClearLoss': - loss=[] - if msg_type == 'AppendLoss': - loss.append(tc.decode_floats(msg_data)) - if msg_type == 'SetLossColors': - loss_colors=tc.decode_strings(msg_data) - if msg_type == 'SetLossShift': - loss_shift = tc.decode_floats(msg_data) - if msg_type == 'SetLossScale': - loss_scale = tc.decode_floats(msg_data) - - if msg_type == 'ClearMetric': - metric=[] - if msg_type == 'AppendMetric': - metric.append(tc.decode_floats(msg_data)) - if msg_type == 'SetMetricColors': - metric_colors=tc.decode_strings(msg_data) - if msg_type == 'SetMetricShift': - metric_shift = tc.decode_floats(msg_data) - if msg_type == 'SetMetricScale': - metric_scale = tc.decode_floats(msg_data) - - if msg_type == 'Render': - if resolution[0]>0 and resolution[1]>0: - img=plot_metrics(prefix,resolution[0:2],resolution[2],samples,labels,loss,loss_colors,loss_shift,loss_scale,metric,metric_colors,metric_shift,metric_scale) - tc.send_msg(app_socket, 'ImageData', tc.encode_image(img)) - - if msg_type == 'Exit': - break +import torchstudio.tcpcodec as tc +import inspect +import sys +import os + +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.ticker import MaxNLocator +import PIL + +def plot_metrics(prefix, size, dpi, samples=100, labels=[], + loss=[], loss_colors=[], loss_shift=(0,0), loss_scale=(1,1), + metric=[], metric_colors=[], metric_shift=(0,0), metric_scale=(1,1)): + """Metrics Plot + + Usage: + Drag: pan + Scroll: zoom vertically + """ + #set up matplotlib renderer, style, figure and axis + 
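Two details of the off-screen rendering set up below are worth spelling out: figsize is given in inches, so dividing the requested pixel size by the dpi yields a pixel-exact canvas, and the Agg backend draws without any display so the canvas bytes can be handed straight to Pillow. A minimal sketch of that pattern, assuming a hypothetical render_to_image helper that is not part of the patch:

    import matplotlib
    matplotlib.use('agg')                      # off-screen backend, no window or display needed
    import matplotlib.pyplot as plt
    import PIL.Image

    def render_to_image(width, height, dpi=96):   # illustrative helper, not from the patch
        # figsize is in inches, so width/dpi x height/dpi gives an exact pixel size
        fig, ax = plt.subplots(figsize=(width/dpi, height/dpi), dpi=dpi)
        ax.plot([0, 1, 2], [0, 1, 4])
        fig.canvas.draw()
        # tostring_rgb() matches the patch; Matplotlib >= 3.8 deprecates it in favor of buffer_rgba()
        img = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        plt.close(fig)
        return img

    render_to_image(256, 256).save('metrics.png')   # e.g. a 256x256 pixel bitmap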
mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/ + plt.style.use('dark_background') + plt.rcParams.update({'font.size': 7}) + + fig, [ax1, ax2] = plt.subplots(1 if size[0]>size[1] else 2, 2 if size[0]>size[1] else 1, figsize=(size[0]/dpi, size[1]/dpi), dpi=dpi) + + #LOSS + ax1.set_title(prefix+"Loss") + + #fit + loss_xmin=0 + loss_xmax=samples + loss_ymin=0 + loss_ymax=1 + for l in loss: + loss_xmax=max(loss_xmax,len(l)) +# if(len(l)>0): +# loss_ymax=max(loss_ymax,max(l)) + +# #shift +# render_size=(loss_xmax-loss_xmin,loss_ymax-loss_ymin) +# loss_xmin-=loss_shift[0]/loss_scale[0]*render_size[0] +# loss_xmax-=loss_shift[0]/loss_scale[0]*render_size[0] +# loss_ymin-=loss_shift[1]/loss_scale[1]*render_size[1] +# loss_ymax-=loss_shift[1]/loss_scale[1]*render_size[1] + +# #scale +# render_center=(loss_xmin+render_size[0]/2,loss_ymin+render_size[1]/2) +# loss_xmin=render_center[0]-(render_size[0]/loss_scale[0]/2) +# loss_xmax=render_center[0]+(render_size[0]/loss_scale[0]/2) +# loss_ymin=render_center[1]-(render_size[1]/loss_scale[1]/2) +# loss_ymax=render_center[1]+(render_size[1]/loss_scale[1]/2) + +# loss_xmin=max(0,loss_xmin) +# loss_ymin=max(0,loss_ymin) + + loss_ymin-=loss_shift[1]/loss_scale[1] + loss_ymax-=loss_shift[1]/loss_scale[1] + loss_ymax=loss_ymax/loss_scale[1] + + ax1.axis(xmin=loss_xmin,xmax=loss_xmax,ymin=loss_ymin,ymax=loss_ymax) + ax1.spines['top'].set_visible(False) + ax1.spines['right'].set_visible(False) + ax1.spines['left'].set_color('#707070') + ax1.spines['bottom'].set_color('#707070') + for i in range(len(loss)): + ax1.plot(loss[i],label=str(i) if i>=len(labels) else labels[i],color=loss_colors[i%len(loss_colors)]) + if labels and loss and loss[0]: + ax1.legend(bbox_to_anchor=(1, 1), loc='upper right', ncol=1, prop={'size': 8}) + ax1.grid(color = '#303030') + ax1.xaxis.set_major_locator(MaxNLocator(nbins='auto', integer=True)) + ax1.ticklabel_format(axis="y", style="sci", scilimits=(0,0)) + + #METRIC + ax2.set_title(prefix+"Metric") + + #fit + metric_xmin=0 + metric_xmax=samples + metric_ymin=0 + metric_ymax=1 + for m in metric: + metric_xmax=max(metric_xmax,len(m)) + +# #shift +# render_size=(metric_xmax-metric_xmin,metric_ymax-metric_ymin) +# metric_xmin-=metric_shift[0]/metric_scale[0]*render_size[0] +# metric_xmax-=metric_shift[0]/metric_scale[0]*render_size[0] +# metric_ymin-=metric_shift[1]/metric_scale[1]*render_size[1] +# metric_ymax-=metric_shift[1]/metric_scale[1]*render_size[1] + +# #scale +# render_center=(metric_xmin+render_size[0]/2,metric_ymin+render_size[1]/2) +# metric_xmin=render_center[0]-(render_size[0]/metric_scale[0]/2) +# metric_xmax=render_center[0]+(render_size[0]/metric_scale[0]/2) +# metric_ymin=render_center[1]-(render_size[1]/metric_scale[1]/2) +# metric_ymax=render_center[1]+(render_size[1]/metric_scale[1]/2) + +# metric_xmin=max(0,metric_xmin) + + metric_ymin-=metric_shift[1]/metric_scale[1] + metric_ymax-=metric_shift[1]/metric_scale[1] + metric_ymin=(metric_ymin-metric_ymax)/metric_scale[1]+metric_ymax + + ax2.axis(xmin=metric_xmin,xmax=metric_xmax,ymin=metric_ymin,ymax=metric_ymax) + ax2.spines['top'].set_visible(False) + ax2.spines['right'].set_visible(False) + ax2.spines['left'].set_color('#707070') + ax2.spines['bottom'].set_color('#707070') + for i in range(len(metric)): + ax2.plot(metric[i],color=metric_colors[i%len(metric_colors)]) + ax2.grid(color = '#303030') + ax2.xaxis.set_major_locator(MaxNLocator(nbins='auto', integer=True)) + + plt.tight_layout(pad=0) + + canvas = 
plt.get_current_fig_manager().canvas + canvas.draw() + img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb()) + plt.close() + return img + + +prefix = '' +resolution = (256,256, 96) +samples=100 +labels = [] + +loss=[] +loss_colors=[] +loss_shift = (0,0) +loss_scale = (1,1) + +metric=[] +metric_colors=[] +metric_labels = [] +metric_shift = (0,0) +metric_scale = (1,1) + +app_socket = tc.connect() +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + + if msg_type == 'RequestDocumentation': + tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_metrics.__doc__))) + if msg_type == 'SetPrefix': + prefix=tc.decode_strings(msg_data)[0] + + if msg_type == 'SetResolution': + resolution = tc.decode_ints(msg_data) + + if msg_type == 'NumSamples': + samples = tc.decode_ints(msg_data)[0] + if msg_type == 'SetLabels': + labels=tc.decode_strings(msg_data) + + if msg_type == 'ClearLoss': + loss=[] + if msg_type == 'AppendLoss': + loss.append(tc.decode_floats(msg_data)) + if msg_type == 'SetLossColors': + loss_colors=tc.decode_strings(msg_data) + if msg_type == 'SetLossShift': + loss_shift = tc.decode_floats(msg_data) + if msg_type == 'SetLossScale': + loss_scale = tc.decode_floats(msg_data) + + if msg_type == 'ClearMetric': + metric=[] + if msg_type == 'AppendMetric': + metric.append(tc.decode_floats(msg_data)) + if msg_type == 'SetMetricColors': + metric_colors=tc.decode_strings(msg_data) + if msg_type == 'SetMetricShift': + metric_shift = tc.decode_floats(msg_data) + if msg_type == 'SetMetricScale': + metric_scale = tc.decode_floats(msg_data) + + if msg_type == 'Render': + if resolution[0]>0 and resolution[1]>0: + img=plot_metrics(prefix,resolution[0:2],resolution[2],samples,labels,loss,loss_colors,loss_shift,loss_scale,metric,metric_colors,metric_shift,metric_scale) + tc.send_msg(app_socket, 'ImageData', tc.encode_image(img)) + + if msg_type == 'Exit': + break diff --git a/torchstudio/modelbuild.py b/torchstudio/modelbuild.py index ad9a079..5f04c3e 100644 --- a/torchstudio/modelbuild.py +++ b/torchstudio/modelbuild.py @@ -1,246 +1,246 @@ -#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490 -import os -os.environ['KMP_DUPLICATE_LIB_OK']='True' - -import sys -print("Loading PyTorch...\n", file=sys.stderr) - -import torch -import torch.fx -from torch.fx.passes.shape_prop import ShapeProp -from torch.fx.graph_module import GraphModule -import torchstudio.tcpcodec as tc -from torchstudio.modules import safe_exec -import sys -import os -import io -import re -import graphviz -import linecache -import inspect - -#monkey patch ssl to fix ssl certificate fail when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error -import ssl -ssl._create_default_https_context = ssl._create_unverified_context - -original_path=sys.path -original_dir=os.getcwd() - -level=0 -max_depth=0 - -app_socket = tc.connect() -print("Build script connected\n", file=sys.stderr) -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - - if msg_type == 'SetCurrentDir': - new_dir=tc.decode_strings(msg_data)[0] - sys.path=original_path - os.chdir(original_dir) - if new_dir: - sys.path.append(new_dir) - os.chdir(new_dir) - - if msg_type == 'SetDataDir': - data_dir=tc.decode_strings(msg_data)[0] - torch.hub.set_dir(data_dir) - - if msg_type == 'SetModelCode': - model_code=tc.decode_strings(msg_data)[0] - - #create a module space for the model definition - 
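The block that follows execs the received model source inside a synthetic module rather than a bare dict, so the resulting objects belong to a module registered in sys.modules, which pickling and torch.package rely on. A rough sketch of the same technique, with an illustrative model_source string standing in for the code received over the socket:

    import sys
    from types import ModuleType

    model_source = "import torch.nn as nn\nmodel = nn.Linear(4, 2)"   # stand-in for the received code

    mod = ModuleType("modelmodule")
    mod.__file__ = mod.__name__ + ".py"      # fake file name for tools that expect one
    sys.modules[mod.__name__] = mod          # register it so pickling/inspection can find it by name
    exec(compile(model_source, mod.__file__, "exec"), vars(mod))

    print(type(mod.model))                   # the snippet's globals now live on the module object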
#see https://stackoverflow.com/questions/5122465/can-i-fake-a-package-or-at-least-a-module-in-python-for-testing-purposes/27476659#27476659 - from types import ModuleType - modelmodule = ModuleType("modelmodule") - modelmodule.__file__ = modelmodule.__name__ + ".py" - sys.modules[modelmodule.__name__] = modelmodule - - error_msg, model_env = safe_exec(model_code, context=vars(modelmodule), output=vars(modelmodule), description='model definition') - if error_msg is not None or 'model' not in model_env: - print("Unknown model definition error" if error_msg is None else error_msg, file=sys.stderr) - - if msg_type == 'InputTensorsID': - input_tensors = tc.decode_torch_tensors(msg_data) - for i, tensor in enumerate(input_tensors): - input_tensors[i]=torch.unsqueeze(tensor, 0) #add batch dimension - - if msg_type == 'OutputTensorsID': - output_tensors = tc.decode_torch_tensors(msg_data) - for i, tensor in enumerate(output_tensors): - output_tensors[i]=torch.unsqueeze(tensor, 0) #add batch dimension - - if msg_type == 'SetLabels': - labels=tc.decode_strings(msg_data) - - if msg_type == 'Build': #generate the torchscript, graph, and suggest hyperparameters - if 'model' in model_env and input_tensors and output_tensors: - print("Building model...\n", file=sys.stderr) - - build_mode=tc.decode_strings(msg_data)[0] - - buffer=io.BytesIO() - torchscript_model=None - if build_mode=='package': #packaging - with torch.package.PackageExporter(buffer) as exp: - intern_list=[] - for path in os.listdir(): - if path.endswith(".py") and os.path.isfile(path): - intern_list.append(path[:-3]+".**") - if os.path.isdir(path): - intern_list.append(path+".**") - exp.extern('**',exclude=intern_list) - exp.intern(intern_list) - exp.save_source_string(modelmodule.__name__, model_code) - exp.save_pickle('model', 'model.pkl', modelmodule.model) - elif build_mode=='script': #scripting - #monkey patch linecache.getlines so that inspect.getsource called by torch.jit.script can work with a module coming from a string and not a file - def monkey_patch(filename, module_globals=None): - if filename == '': - return model_code.splitlines(keepends=True) - else: - return getlines(filename, module_globals) - getlines = linecache.getlines - linecache.getlines = monkey_patch - error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':modelmodule.model}, description='model scripting') - linecache.getlines = getlines - else: #tracing - error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':modelmodule.model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing') - - if error_msg: - print(error_msg, file=sys.stderr) - else: - if torchscript_model: - torch.jit.save(torchscript_model,buffer) - tc.send_msg(app_socket, 'TorchScriptData', buffer.getvalue()) - else: - tc.send_msg(app_socket, 'PackageData', buffer.getvalue()) - - print("Building graph...\n", file=sys.stderr) - - level=0 - max_depth=0 - - while level<=max_depth: - class LevelTracer(torch.fx.Tracer): - def is_leaf_module(self, m, qualname): - depth=re.sub(r'.[0-9]+', '', qualname).count('.') - if super().is_leaf_module(m, qualname)==False: - depth=depth+1 - global max_depth - max_depth=max(max_depth,depth) - if depth>max_depth-level: - return True - else: - return super().is_leaf_module(m, qualname) - - def level_trace(root): - tracer = LevelTracer() - graph = tracer.trace(root) - name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ - return GraphModule(tracer.root, graph, name) - - error_msg, gm = 
safe_exec(level_trace,(model_env['model'],), description='model graph') - if error_msg or gm is None: - print("Unknown model graph error" if error_msg is None else error_msg, file=sys.stderr) - else: - modules = dict(gm.named_modules()) - ShapeProp(gm).propagate(*input_tensors) - - parsed_nodes={} - for rank, node in enumerate(gm.graph.nodes): - id=node.name - name=node.name - inputs=[str(i) for i in list(node.all_input_nodes)] - output_dtype='' - output_shape='' - if 'tensor_meta' in node.meta: - if type(node.meta['tensor_meta']) is tuple or type(node.meta['tensor_meta']) is list: - for tensor_meta in node.meta['tensor_meta']: - output_dtype+=str(tensor_meta.dtype)+' ' - output_shape+=','.join([str(i) for i in list(tensor_meta.shape)[1:]])+' ' - output_dtype=output_dtype[:-1] - output_shape=output_shape[:-1] - else: - output_dtype=str(node.meta['tensor_meta'].dtype) - output_shape=','.join([str(i) for i in list(node.meta['tensor_meta'].shape)[1:]]) - - if node.op == 'placeholder': - node_type='input' - op_module='' - op='' - params='' - elif node.op == 'call_module': - node_type='module' - name=re.sub('\.([0-9]+)', r'[\1]', node.target) - op_module=modules[node.target].__module__ - op_module='torch.nn' #prefer this shortcut for all modules - op=modules[node.target].__class__.__name__ - params=modules[node.target].extra_repr() - elif node.op == 'call_function': - node_type='function' - op_module=node.target.__module__ if node.target.__module__ is not None else "torch" - op_module='operator' if op_module=='_operator' else op_module - op=node.target.__name__ - params_list=[str(x) for x in node.args] - params_list.extend([f'{key}={value}' for key, value in node._kwargs.items()]) - params=', '.join(params_list) - elif node.op == 'call_method': - node_type='function' - op_module="torch" - op=node.target - params_list=[str(x) for x in node.args] - params_list.extend([f'{key}={value}' for key, value in node._kwargs.items()]) - params=', '.join(params_list) - elif node.op == 'output': - node_type='output' - op_module='' - op='' - params='' - for input_node in node.all_input_nodes: - input_op="" - if input_node.op == 'call_module': - input_op=modules[input_node.target].__class__.__name__ - elif input_node.op == 'call_function': - input_op=input_node.target.__name__ - else: - node_type='unknown' - op_module='' - op='' - params='' - - if node_type=='output' and len(inputs)>1: - for i, input in enumerate(inputs): - parsed_nodes[id+"_"+str(i)]={'name':name+"["+str(i)+"]", 'type':node_type, 'op_module':op_module, 'op':op, 'params':params, 'output_dtype':output_dtype, 'output_shape':output_shape, 'inputs':inputs} - else: - parsed_nodes[id]={'name':name, 'type':node_type, 'op_module':op_module, 'op':op, 'params':params, 'output_dtype':output_dtype, 'output_shape':output_shape, 'inputs':inputs} - tc.send_msg(app_socket, 'GraphData', bytes(str(parsed_nodes),'utf-8')) - level+=1 - tc.send_msg(app_socket, 'GraphDataEnd') - - print("Model built ("+format(sum(p.numel() for p in model_env['model'].parameters() if p.requires_grad), ',d')+" parameters)") #from https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=robin_lobel - - #suggest loss names - loss=[] - for i, tensor in enumerate(output_tensors): - if len(tensor.shape)==1 and "int" in str(tensor.dtype): - #multiclass crossentropy classification - loss.append("CrossEntropy") - elif len(tensor.shape)==2 and tensor.shape[1]==len(labels): - #multiclass multilabel classification - loss.append("BinaryCrossEntropy") - else: - 
#default back to MSE for everything else - loss.append("MeanSquareError") - - #suggest metric names - metric=[] - for tensor in output_tensors: - metric.append("Accuracy") - - tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([128,0,100,1,1,1])) - tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['Adam','Step'])) - - if msg_type == 'Exit': - break - +#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490 +import os +os.environ['KMP_DUPLICATE_LIB_OK']='True' + +import sys +print("Loading PyTorch...\n", file=sys.stderr) + +import torch +import torch.fx +from torch.fx.passes.shape_prop import ShapeProp +from torch.fx.graph_module import GraphModule +import torchstudio.tcpcodec as tc +from torchstudio.modules import safe_exec +import sys +import os +import io +import re +import graphviz +import linecache +import inspect + +#monkey patch ssl to fix ssl certificate fail when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error +import ssl +ssl._create_default_https_context = ssl._create_unverified_context + +original_path=sys.path +original_dir=os.getcwd() + +level=0 +max_depth=0 + +app_socket = tc.connect() +print("Build script connected\n", file=sys.stderr) +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + + if msg_type == 'SetCurrentDir': + new_dir=tc.decode_strings(msg_data)[0] + sys.path=original_path + os.chdir(original_dir) + if new_dir: + sys.path.append(new_dir) + os.chdir(new_dir) + + if msg_type == 'SetDataDir': + data_dir=tc.decode_strings(msg_data)[0] + torch.hub.set_dir(data_dir) + + if msg_type == 'SetModelCode': + model_code=tc.decode_strings(msg_data)[0] + + #create a module space for the model definition + #see https://stackoverflow.com/questions/5122465/can-i-fake-a-package-or-at-least-a-module-in-python-for-testing-purposes/27476659#27476659 + from types import ModuleType + modelmodule = ModuleType("modelmodule") + modelmodule.__file__ = modelmodule.__name__ + ".py" + sys.modules[modelmodule.__name__] = modelmodule + + error_msg, model_env = safe_exec(model_code, context=vars(modelmodule), output=vars(modelmodule), description='model definition') + if error_msg is not None or 'model' not in model_env: + print("Unknown model definition error" if error_msg is None else error_msg, file=sys.stderr) + + if msg_type == 'InputTensorsID': + input_tensors = tc.decode_torch_tensors(msg_data) + for i, tensor in enumerate(input_tensors): + input_tensors[i]=torch.unsqueeze(tensor, 0) #add batch dimension + + if msg_type == 'OutputTensorsID': + output_tensors = tc.decode_torch_tensors(msg_data) + for i, tensor in enumerate(output_tensors): + output_tensors[i]=torch.unsqueeze(tensor, 0) #add batch dimension + + if msg_type == 'SetLabels': + labels=tc.decode_strings(msg_data) + + if msg_type == 'Build': #generate the torchscript, graph, and suggest hyperparameters + if 'model' in model_env and input_tensors and output_tensors: + print("Building model...\n", file=sys.stderr) + + build_mode=tc.decode_strings(msg_data)[0] + + buffer=io.BytesIO() + torchscript_model=None + if build_mode=='package': #packaging + with torch.package.PackageExporter(buffer) as exp: + intern_list=[] + for path in os.listdir(): + if path.endswith(".py") and os.path.isfile(path): + intern_list.append(path[:-3]+".**") + if os.path.isdir(path): + intern_list.append(path+".**") + exp.extern('**',exclude=intern_list) + exp.intern(intern_list) 
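            # extern('**', exclude=intern_list) marks every module outside the project directory
            # as an external dependency that must be importable when the package is loaded, while
            # intern(intern_list) bundles the source of the local .py files and packages into the
            # archive itself, alongside the pickled model saved below.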
+ exp.save_source_string(modelmodule.__name__, model_code) + exp.save_pickle('model', 'model.pkl', modelmodule.model) + elif build_mode=='script': #scripting + #monkey patch linecache.getlines so that inspect.getsource called by torch.jit.script can work with a module coming from a string and not a file + def monkey_patch(filename, module_globals=None): + if filename == '': + return model_code.splitlines(keepends=True) + else: + return getlines(filename, module_globals) + getlines = linecache.getlines + linecache.getlines = monkey_patch + error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':modelmodule.model}, description='model scripting') + linecache.getlines = getlines + else: #tracing + error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':modelmodule.model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing') + + if error_msg: + print(error_msg, file=sys.stderr) + else: + if torchscript_model: + torch.jit.save(torchscript_model,buffer) + tc.send_msg(app_socket, 'TorchScriptData', buffer.getvalue()) + else: + tc.send_msg(app_socket, 'PackageData', buffer.getvalue()) + + print("Building graph...\n", file=sys.stderr) + + level=0 + max_depth=0 + + while level<=max_depth: + class LevelTracer(torch.fx.Tracer): + def is_leaf_module(self, m, qualname): + depth=re.sub(r'.[0-9]+', '', qualname).count('.') + if super().is_leaf_module(m, qualname)==False: + depth=depth+1 + global max_depth + max_depth=max(max_depth,depth) + if depth>max_depth-level: + return True + else: + return super().is_leaf_module(m, qualname) + + def level_trace(root): + tracer = LevelTracer() + graph = tracer.trace(root) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return GraphModule(tracer.root, graph, name) + + error_msg, gm = safe_exec(level_trace,(model_env['model'],), description='model graph') + if error_msg or gm is None: + print("Unknown model graph error" if error_msg is None else error_msg, file=sys.stderr) + else: + modules = dict(gm.named_modules()) + ShapeProp(gm).propagate(*input_tensors) + + parsed_nodes={} + for rank, node in enumerate(gm.graph.nodes): + id=node.name + name=node.name + inputs=[str(i) for i in list(node.all_input_nodes)] + output_dtype='' + output_shape='' + if 'tensor_meta' in node.meta: + if type(node.meta['tensor_meta']) is tuple or type(node.meta['tensor_meta']) is list: + for tensor_meta in node.meta['tensor_meta']: + output_dtype+=str(tensor_meta.dtype)+' ' + output_shape+=','.join([str(i) for i in list(tensor_meta.shape)[1:]])+' ' + output_dtype=output_dtype[:-1] + output_shape=output_shape[:-1] + else: + output_dtype=str(node.meta['tensor_meta'].dtype) + output_shape=','.join([str(i) for i in list(node.meta['tensor_meta'].shape)[1:]]) + + if node.op == 'placeholder': + node_type='input' + op_module='' + op='' + params='' + elif node.op == 'call_module': + node_type='module' + name=re.sub('\.([0-9]+)', r'[\1]', node.target) + op_module=modules[node.target].__module__ + op_module='torch.nn' #prefer this shortcut for all modules + op=modules[node.target].__class__.__name__ + params=modules[node.target].extra_repr() + elif node.op == 'call_function': + node_type='function' + op_module=node.target.__module__ if node.target.__module__ is not None else "torch" + op_module='operator' if op_module=='_operator' else op_module + op=node.target.__name__ + params_list=[str(x) for x in node.args] + params_list.extend([f'{key}={value}' for key, value in node._kwargs.items()]) + params=', 
'.join(params_list) + elif node.op == 'call_method': + node_type='function' + op_module="torch" + op=node.target + params_list=[str(x) for x in node.args] + params_list.extend([f'{key}={value}' for key, value in node._kwargs.items()]) + params=', '.join(params_list) + elif node.op == 'output': + node_type='output' + op_module='' + op='' + params='' + for input_node in node.all_input_nodes: + input_op="" + if input_node.op == 'call_module': + input_op=modules[input_node.target].__class__.__name__ + elif input_node.op == 'call_function': + input_op=input_node.target.__name__ + else: + node_type='unknown' + op_module='' + op='' + params='' + + if node_type=='output' and len(inputs)>1: + for i, input in enumerate(inputs): + parsed_nodes[id+"_"+str(i)]={'name':name+"["+str(i)+"]", 'type':node_type, 'op_module':op_module, 'op':op, 'params':params, 'output_dtype':output_dtype, 'output_shape':output_shape, 'inputs':inputs} + else: + parsed_nodes[id]={'name':name, 'type':node_type, 'op_module':op_module, 'op':op, 'params':params, 'output_dtype':output_dtype, 'output_shape':output_shape, 'inputs':inputs} + tc.send_msg(app_socket, 'GraphData', bytes(str(parsed_nodes),'utf-8')) + level+=1 + tc.send_msg(app_socket, 'GraphDataEnd') + + print("Model built ("+format(sum(p.numel() for p in model_env['model'].parameters() if p.requires_grad), ',d')+" parameters)") #from https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=robin_lobel + + #suggest loss names + loss=[] + for i, tensor in enumerate(output_tensors): + if len(tensor.shape)==1 and "int" in str(tensor.dtype): + #multiclass crossentropy classification + loss.append("CrossEntropy") + elif len(tensor.shape)==2 and tensor.shape[1]==len(labels): + #multiclass multilabel classification + loss.append("BinaryCrossEntropy") + else: + #default back to MSE for everything else + loss.append("MeanSquareError") + + #suggest metric names + metric=[] + for tensor in output_tensors: + metric.append("Accuracy") + + tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([128,0,100,1,1])) + tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['Adam','Step'])) + + if msg_type == 'Exit': + break + diff --git a/torchstudio/models/unet1d.py b/torchstudio/models/unet1d.py index ba778c0..78a0baa 100644 --- a/torchstudio/models/unet1d.py +++ b/torchstudio/models/unet1d.py @@ -1,143 +1,143 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py -def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False): - sequence = [] - for i in range(conv_per_block): - sequence.append(nn.Conv1d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2)) - sequence.append(nn.ReLU(inplace=True)) - if batch_norm: - #BatchNorm best after ReLU: - #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/ - #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999 - #https://github.com/cvjena/cnn-models/issues/3 - sequence.append(nn.BatchNorm1d(out_channels)) - return nn.Sequential(*sequence) - -class DownConv(nn.Module): - def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True): - super().__init__() - - self.pooling = pooling - - self.block = block(in_channels, out_channels, conv_per_block, 
kernel_size, batch_norm) - - if self.pooling: - if not conv_downscaling: - self.pool = nn.MaxPool1d(kernel_size=2, stride=2) - else: - self.pool = nn.Conv1d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2) - - def forward(self, x): - x = self.block(x) - before_pool = x - if self.pooling: - x = self.pool(x) - return x, before_pool - - -class UpConv(nn.Module): - def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, - add_merging, conv_upscaling): - super().__init__() - - self.add_merging = add_merging - - if not conv_upscaling: - self.upconv = nn.ConvTranspose1d(in_channels,out_channels,kernel_size=2,stride=2) - else: - self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2), - nn.Conv1d(in_channels, out_channels,kernel_size=1,groups=1,stride=1)) - - - self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm) - - def forward(self, from_down, from_up): - from_up = self.upconv(from_up) - if not self.add_merging: - x = torch.cat((from_up, from_down), 1) - else: - x = from_up + from_down - x = self.block(x) - return x - - -class UNet1D(nn.Module): - """ `UNet` class is based on https://arxiv.org/abs/1505.04597 - UNet is a convolutional encoder-decoder neural network. - - This 1D variant is inspired by 1D Unet are inspired by the - Wave UNet ( https://arxiv.org/pdf/1806.03185.pdf ) - Default parameters correspond to the Wave UNet. - Convolutions use padding to preserve the original size. - - Args: - in_channels: number of channels in the input tensor. - out_channels: number of channels in the output tensor. - feature_channels: number of channels in the first and last hidden feature layer. - depth: number of levels - conv_per_block: number of convolutions per level block - kernel_size: kernel size for all block convolutions - batch_norm: add a batch norm after ReLU - conv_upscaling: use a nearest upsize+conv instead of transposed convolution - conv_downscaling: use a strided convolution instead of maxpooling - add_merging: merge layers from different levels using a add instead of a concat - """ - - def __init__(self, in_channels=1, out_channels=1, feature_channels=24, - depth=12, conv_per_block=1, kernel_size=5, batch_norm=False, - conv_upscaling=False, conv_downscaling=False, add_merging=False): - super().__init__() - - self.out_channels = out_channels - self.in_channels = in_channels - self.feature_channels = feature_channels - self.depth = depth - - self.down_convs = [] - self.up_convs = [] - - # create the encoder pathway and add to a list - for i in range(depth): - ins = self.in_channels if i == 0 else outs - outs = self.feature_channels*(i+1) - pooling = True if i < depth-1 else False - - down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm, - conv_downscaling, pooling=pooling) - self.down_convs.append(down_conv) - - # create the decoder pathway and add to a list - # - careful! 
decoding only requires depth-1 blocks - for i in range(depth-1): - ins = outs - outs = ins - self.feature_channels - up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm, - conv_upscaling=conv_upscaling, add_merging=add_merging) - self.up_convs.append(up_conv) - - self.conv_final = nn.Conv1d(outs, self.out_channels,kernel_size=1,groups=1,stride=1) - - # add the list of modules to current module - self.down_convs = nn.ModuleList(self.down_convs) - self.up_convs = nn.ModuleList(self.up_convs) - - def forward(self, x): - encoder_outs = [] - - # encoder pathway, save outputs for merging - for i, module in enumerate(self.down_convs): - x, before_pool = module(x) - encoder_outs.append(before_pool) - - for i, module in enumerate(self.up_convs): - before_pool = encoder_outs[-(i+2)] - x = module(before_pool, x) - - # No softmax is used. This means you need to use - # nn.CrossEntropyLoss is your training script, - # as this module includes a softmax already. - x = self.conv_final(x) - return x +import torch +import torch.nn as nn +import torch.nn.functional as F + +#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py +def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False): + sequence = [] + for i in range(conv_per_block): + sequence.append(nn.Conv1d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2)) + sequence.append(nn.ReLU(inplace=True)) + if batch_norm: + #BatchNorm best after ReLU: + #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/ + #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999 + #https://github.com/cvjena/cnn-models/issues/3 + sequence.append(nn.BatchNorm1d(out_channels)) + return nn.Sequential(*sequence) + +class DownConv(nn.Module): + def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True): + super().__init__() + + self.pooling = pooling + + self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm) + + if self.pooling: + if not conv_downscaling: + self.pool = nn.MaxPool1d(kernel_size=2, stride=2) + else: + self.pool = nn.Conv1d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2) + + def forward(self, x): + x = self.block(x) + before_pool = x + if self.pooling: + x = self.pool(x) + return x, before_pool + + +class UpConv(nn.Module): + def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, + add_merging, conv_upscaling): + super().__init__() + + self.add_merging = add_merging + + if not conv_upscaling: + self.upconv = nn.ConvTranspose1d(in_channels,out_channels,kernel_size=2,stride=2) + else: + self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2), + nn.Conv1d(in_channels, out_channels,kernel_size=1,groups=1,stride=1)) + + + self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm) + + def forward(self, from_down, from_up): + from_up = self.upconv(from_up) + if not self.add_merging: + x = torch.cat((from_up, from_down), 1) + else: + x = from_up + from_down + x = self.block(x) + return x + + +class UNet1D(nn.Module): + """ `UNet` class is based on https://arxiv.org/abs/1505.04597 + UNet is a convolutional encoder-decoder neural network. 
+ + This 1D variant is inspired by 1D Unet are inspired by the + Wave UNet ( https://arxiv.org/pdf/1806.03185.pdf ) + Default parameters correspond to the Wave UNet. + Convolutions use padding to preserve the original size. + + Args: + in_channels: number of channels in the input tensor. + out_channels: number of channels in the output tensor. + feature_channels: number of channels in the first and last hidden feature layer. + depth: number of levels + conv_per_block: number of convolutions per level block + kernel_size: kernel size for all block convolutions + batch_norm: add a batch norm after ReLU + conv_upscaling: use a nearest upsize+conv instead of transposed convolution + conv_downscaling: use a strided convolution instead of maxpooling + add_merging: merge layers from different levels using a add instead of a concat + """ + + def __init__(self, in_channels=1, out_channels=1, feature_channels=24, + depth=12, conv_per_block=1, kernel_size=5, batch_norm=False, + conv_upscaling=False, conv_downscaling=False, add_merging=False): + super().__init__() + + self.out_channels = out_channels + self.in_channels = in_channels + self.feature_channels = feature_channels + self.depth = depth + + self.down_convs = [] + self.up_convs = [] + + # create the encoder pathway and add to a list + for i in range(depth): + ins = self.in_channels if i == 0 else outs + outs = self.feature_channels*(i+1) + pooling = True if i < depth-1 else False + + down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm, + conv_downscaling, pooling=pooling) + self.down_convs.append(down_conv) + + # create the decoder pathway and add to a list + # - careful! decoding only requires depth-1 blocks + for i in range(depth-1): + ins = outs + outs = ins - self.feature_channels + up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm, + conv_upscaling=conv_upscaling, add_merging=add_merging) + self.up_convs.append(up_conv) + + self.conv_final = nn.Conv1d(outs, self.out_channels,kernel_size=1,groups=1,stride=1) + + # add the list of modules to current module + self.down_convs = nn.ModuleList(self.down_convs) + self.up_convs = nn.ModuleList(self.up_convs) + + def forward(self, x): + encoder_outs = [] + + # encoder pathway, save outputs for merging + for i, module in enumerate(self.down_convs): + x, before_pool = module(x) + encoder_outs.append(before_pool) + + for i, module in enumerate(self.up_convs): + before_pool = encoder_outs[-(i+2)] + x = module(before_pool, x) + + # No softmax is used. This means you need to use + # nn.CrossEntropyLoss is your training script, + # as this module includes a softmax already. 
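        # i.e. conv_final below returns raw logits; nn.CrossEntropyLoss applies log-softmax
        # internally, which is why no softmax layer is added to the network itself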
+ x = self.conv_final(x) + return x diff --git a/torchstudio/models/unet2d.py b/torchstudio/models/unet2d.py index ca1c91e..1b58ae0 100644 --- a/torchstudio/models/unet2d.py +++ b/torchstudio/models/unet2d.py @@ -1,166 +1,166 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py -def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False): - sequence = [] - for i in range(conv_per_block): - sequence.append(nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2)) - sequence.append(nn.ReLU(inplace=True)) - if batch_norm: - #BatchNorm best after ReLU: - #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/ - #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999 - #https://github.com/cvjena/cnn-models/issues/3 - sequence.append(nn.BatchNorm2d(out_channels)) - return nn.Sequential(*sequence) - -class DownConv(nn.Module): - def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True): - super().__init__() - - self.in_channels=in_channels - self.out_channels=out_channels - self.conv_per_block=conv_per_block - self.kernel_size=kernel_size - self.batch_norm=batch_norm - self.conv_downscaling=conv_downscaling - self.pooling = pooling - - self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm) - - if self.pooling: - if not conv_downscaling: - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - else: - self.pool = nn.Conv2d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2) - - def forward(self, x): - x = self.block(x) - before_pool = x - if self.pooling: - x = self.pool(x) - return x, before_pool - - def extra_repr(self): - # (Optional)Set the extra information about this module. You can test - # it by printing an object of this class. - return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, conv_downscaling={}, pooling={}'.format( - self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.conv_downscaling, self.pooling - ) - - -class UpConv(nn.Module): - def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, - add_merging, conv_upscaling): - super().__init__() - - self.in_channels=in_channels - self.out_channels=out_channels - self.conv_per_block=conv_per_block - self.kernel_size=kernel_size - self.batch_norm=batch_norm - self.add_merging = add_merging - self.conv_upscaling = conv_upscaling - - if not conv_upscaling: - self.upconv = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2) - else: - self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2), - nn.Conv2d(in_channels, out_channels,kernel_size=1,groups=1,stride=1)) - - - self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm) - - def forward(self, from_down, from_up): - from_up = self.upconv(from_up) - if not self.add_merging: - x = torch.cat((from_up, from_down), 1) - else: - x = from_up + from_down - x = self.block(x) - return x - - def extra_repr(self): - # (Optional)Set the extra information about this module. You can test - # it by printing an object of this class. 
- return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, add_merging={}, conv_upscaling={}'.format( - self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.add_merging, self.conv_upscaling - ) - -class UNet2D(nn.Module): - """ `UNet` class is based on https://arxiv.org/abs/1505.04597 - UNet is a convolutional encoder-decoder neural network. - - Default parameters correspond to the original UNet, except - convolutions use padding to preserve the original size. - - Args: - in_channels: number of channels in the input tensor. - out_channels: number of channels in the output tensor. - feature_channels: number of channels in the first and last hidden feature layer. - depth: number of levels - conv_per_block: number of convolutions per level block - kernel_size: kernel size for all block convolutions - batch_norm: add a batch norm after ReLU - conv_upscaling: use a nearest upscale+conv instead of transposed convolution - conv_downscaling: use a strided convolution instead of maxpooling - add_merging: merge layers from different levels using a add instead of a concat - """ - - def __init__(self, in_channels=1, out_channels=2, feature_channels=64, - depth=5, conv_per_block=2, kernel_size=3, batch_norm=False, - conv_upscaling=False, conv_downscaling=False, add_merging=False): - super().__init__() - - self.out_channels = out_channels - self.in_channels = in_channels - self.feature_channels = feature_channels - self.depth = depth - - self.down_convs = [] - self.up_convs = [] - - # create the encoder pathway and add to a list - for i in range(depth): - ins = self.in_channels if i == 0 else outs - outs = self.feature_channels*(2**i) - pooling = True if i < depth-1 else False - - down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm, - conv_downscaling, pooling=pooling) - self.down_convs.append(down_conv) - - # create the decoder pathway and add to a list - # - careful! decoding only requires depth-1 blocks - for i in range(depth-1): - ins = outs - outs = ins // 2 - up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm, - conv_upscaling=conv_upscaling, add_merging=add_merging) - self.up_convs.append(up_conv) - - self.conv_final = nn.Conv2d(outs, self.out_channels,kernel_size=1,groups=1,stride=1) - - # add the list of modules to current module - self.down_convs = nn.ModuleList(self.down_convs) - self.up_convs = nn.ModuleList(self.up_convs) - - def forward(self, x): - encoder_outs = [] - - # encoder pathway, save outputs for merging - for i, module in enumerate(self.down_convs): - x, before_pool = module(x) - encoder_outs.append(before_pool) - - for i, module in enumerate(self.up_convs): - before_pool = encoder_outs[-(i+2)] - x = module(before_pool, x) - - # No softmax is used. This means you need to use - # nn.CrossEntropyLoss is your training script, - # as this module includes a softmax already. 
- x = self.conv_final(x) - return x +import torch +import torch.nn as nn +import torch.nn.functional as F + +#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py +def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False): + sequence = [] + for i in range(conv_per_block): + sequence.append(nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2)) + sequence.append(nn.ReLU(inplace=True)) + if batch_norm: + #BatchNorm best after ReLU: + #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/ + #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999 + #https://github.com/cvjena/cnn-models/issues/3 + sequence.append(nn.BatchNorm2d(out_channels)) + return nn.Sequential(*sequence) + +class DownConv(nn.Module): + def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True): + super().__init__() + + self.in_channels=in_channels + self.out_channels=out_channels + self.conv_per_block=conv_per_block + self.kernel_size=kernel_size + self.batch_norm=batch_norm + self.conv_downscaling=conv_downscaling + self.pooling = pooling + + self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm) + + if self.pooling: + if not conv_downscaling: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + else: + self.pool = nn.Conv2d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2) + + def forward(self, x): + x = self.block(x) + before_pool = x + if self.pooling: + x = self.pool(x) + return x, before_pool + + def extra_repr(self): + # (Optional)Set the extra information about this module. You can test + # it by printing an object of this class. + return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, conv_downscaling={}, pooling={}'.format( + self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.conv_downscaling, self.pooling + ) + + +class UpConv(nn.Module): + def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, + add_merging, conv_upscaling): + super().__init__() + + self.in_channels=in_channels + self.out_channels=out_channels + self.conv_per_block=conv_per_block + self.kernel_size=kernel_size + self.batch_norm=batch_norm + self.add_merging = add_merging + self.conv_upscaling = conv_upscaling + + if not conv_upscaling: + self.upconv = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2) + else: + self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2), + nn.Conv2d(in_channels, out_channels,kernel_size=1,groups=1,stride=1)) + + + self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm) + + def forward(self, from_down, from_up): + from_up = self.upconv(from_up) + if not self.add_merging: + x = torch.cat((from_up, from_down), 1) + else: + x = from_up + from_down + x = self.block(x) + return x + + def extra_repr(self): + # (Optional)Set the extra information about this module. You can test + # it by printing an object of this class. 
+ return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, add_merging={}, conv_upscaling={}'.format( + self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.add_merging, self.conv_upscaling + ) + +class UNet2D(nn.Module): + """ `UNet` class is based on https://arxiv.org/abs/1505.04597 + UNet is a convolutional encoder-decoder neural network. + + Default parameters correspond to the original UNet, except + convolutions use padding to preserve the original size. + + Args: + in_channels: number of channels in the input tensor. + out_channels: number of channels in the output tensor. + feature_channels: number of channels in the first and last hidden feature layer. + depth: number of levels + conv_per_block: number of convolutions per level block + kernel_size: kernel size for all block convolutions + batch_norm: add a batch norm after ReLU + conv_upscaling: use a nearest upscale+conv instead of transposed convolution + conv_downscaling: use a strided convolution instead of maxpooling + add_merging: merge layers from different levels using a add instead of a concat + """ + + def __init__(self, in_channels=1, out_channels=2, feature_channels=64, + depth=5, conv_per_block=2, kernel_size=3, batch_norm=False, + conv_upscaling=False, conv_downscaling=False, add_merging=False): + super().__init__() + + self.out_channels = out_channels + self.in_channels = in_channels + self.feature_channels = feature_channels + self.depth = depth + + self.down_convs = [] + self.up_convs = [] + + # create the encoder pathway and add to a list + for i in range(depth): + ins = self.in_channels if i == 0 else outs + outs = self.feature_channels*(2**i) + pooling = True if i < depth-1 else False + + down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm, + conv_downscaling, pooling=pooling) + self.down_convs.append(down_conv) + + # create the decoder pathway and add to a list + # - careful! decoding only requires depth-1 blocks + for i in range(depth-1): + ins = outs + outs = ins // 2 + up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm, + conv_upscaling=conv_upscaling, add_merging=add_merging) + self.up_convs.append(up_conv) + + self.conv_final = nn.Conv2d(outs, self.out_channels,kernel_size=1,groups=1,stride=1) + + # add the list of modules to current module + self.down_convs = nn.ModuleList(self.down_convs) + self.up_convs = nn.ModuleList(self.up_convs) + + def forward(self, x): + encoder_outs = [] + + # encoder pathway, save outputs for merging + for i, module in enumerate(self.down_convs): + x, before_pool = module(x) + encoder_outs.append(before_pool) + + for i, module in enumerate(self.up_convs): + before_pool = encoder_outs[-(i+2)] + x = module(before_pool, x) + + # No softmax is used. This means you need to use + # nn.CrossEntropyLoss is your training script, + # as this module includes a softmax already. 
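        # in the 2D case this means the (N, out_channels, H, W) logits returned below pair
        # directly with nn.CrossEntropyLoss and an integer class-index target of shape (N, H, W)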
+ x = self.conv_final(x) + return x diff --git a/torchstudio/modeltrain.py b/torchstudio/modeltrain.py index 3b720b8..2bf6558 100644 --- a/torchstudio/modeltrain.py +++ b/torchstudio/modeltrain.py @@ -1,360 +1,364 @@ -#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490 -import os -os.environ['KMP_DUPLICATE_LIB_OK']='True' - -import sys - -print("Loading PyTorch...\n", file=sys.stderr) - -import torch -from torch.utils.data import Dataset -import torchstudio.tcpcodec as tc -from torchstudio.modules import safe_exec -import os -import sys -import io -import tempfile -from tqdm.auto import tqdm -from collections.abc import Iterable - - -class CachedDataset(Dataset): - def __init__(self, disk_cache=False): - self.reset(disk_cache) - - def add_sample(self, data): - if self.disk_cache: - file=tempfile.TemporaryFile(prefix='torchstudio.'+str(len(self.index))+'.') #guaranteed to be deleted on win/mac/linux: https://bugs.python.org/issue4928 - file.write(data) - self.index.append(file) - else: - self.index.append(tc.decode_torch_tensors(data)) - - def reset(self, disk_cache=False): - self.index = [] - self.disk_cache=disk_cache - - def __len__(self): - return len(self.index) - - def __getitem__(self, id): - if id<0 or id>=len(self): - raise IndexError - - if self.disk_cache: - file=self.index[id] - file.seek(0) - sample=tc.decode_torch_tensors(file.read()) - else: - sample=self.index[id] - return sample - -def deepcopy_cpu(value): - if isinstance(value, torch.Tensor): - value = value.to("cpu") - return value - elif isinstance(value, dict): - return {k: deepcopy_cpu(v) for k, v in value.items()} - elif isinstance(value, Iterable): - return type(value)(deepcopy_cpu(v) for v in value) - else: - return value - -modules_valid=True - -train_dataset = CachedDataset() -valid_dataset = CachedDataset() -train_bar = None - -model = None - -app_socket = tc.connect() -print("Training script connected\n", file=sys.stderr) -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - - if msg_type == 'SetDevice': - print("Setting device...\n", file=sys.stderr) - device_id=tc.decode_strings(msg_data)[0] - device = torch.device(device_id) - pin_memory = True if 'cuda' in device_id else False - - if msg_type == 'SetTorchScriptModel' and modules_valid: - if msg_data: - print("Setting torchscript model...\n", file=sys.stderr) - buffer=io.BytesIO(msg_data) - model = torch.jit.load(buffer) - - if msg_type == 'SetPackageModel' and modules_valid: - if msg_data: - print("Setting package model...\n", file=sys.stderr) - buffer=io.BytesIO(msg_data) - model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl') - - if msg_type == 'SetModelState' and modules_valid: - if model is not None: - if msg_data: - buffer=io.BytesIO(msg_data) - model.load_state_dict(torch.load(buffer)) - model.to(device) - - if msg_type == 'SetLossCodes' and modules_valid: - print("Setting loss code...\n", file=sys.stderr) - loss_definitions=tc.decode_strings(msg_data) - criteria = [] - for definition in loss_definitions: - error_msg, loss_env = safe_exec(definition, description='loss definition') - if error_msg is not None or 'loss' not in loss_env: - print("Unknown loss definition error" if error_msg is None else error_msg, file=sys.stderr) - modules_valid=False - tc.send_msg(app_socket, 'TrainingError') - break - else: - criteria.append(loss_env['loss']) - - if msg_type == 'SetMetricCodes' and modules_valid: - print("Setting metrics code...\n", file=sys.stderr) - 
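Each loss (and, just below, metric) definition arrives as a short Python snippet that is executed and must bind a top-level name. A simplified illustration of what the handler expects; the definition string here is made up, and the patch uses its safe_exec wrapper rather than a bare exec:

    loss_definition = "import torch.nn as nn\nloss = nn.CrossEntropyLoss()"   # illustrative snippet

    env = {}
    exec(loss_definition, env)     # safe_exec() in the patch additionally captures and reports errors
    criterion = env['loss']        # the handler appends env['loss'] (or env['metric']) to its list
    print(criterion)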
metric_definitions=tc.decode_strings(msg_data) - metrics = [] - for definition in metric_definitions: - error_msg, metric_env = safe_exec(definition, description='metric definition') - if error_msg is not None or 'metric' not in metric_env: - print("Unknown metric definition error" if error_msg is None else error_msg, file=sys.stderr) - modules_valid=False - tc.send_msg(app_socket, 'TrainingError') - break - else: - metrics.append(metric_env['metric']) - - if msg_type == 'SetOptimizerCode' and modules_valid: - print("Setting optimizer code...\n", file=sys.stderr) - error_msg, optimizer_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='optimizer definition') - if error_msg is not None or 'optimizer' not in optimizer_env: - print("Unknown optimizer definition error" if error_msg is None else error_msg, file=sys.stderr) - modules_valid=False - tc.send_msg(app_socket, 'TrainingError') - else: - optimizer = optimizer_env['optimizer'] - - if msg_type == 'SetOptimizerState' and modules_valid: - if msg_data: - buffer=io.BytesIO(msg_data) - optimizer.load_state_dict(torch.load(buffer)) - - if msg_type == 'SetSchedulerCode' and modules_valid: - print("Setting scheduler code...\n", file=sys.stderr) - error_msg, scheduler_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='scheduler definition') - if error_msg is not None or 'scheduler' not in scheduler_env: - print("Unknown scheduler definition error" if error_msg is None else error_msg, file=sys.stderr) - modules_valid=False - tc.send_msg(app_socket, 'TrainingError') - else: - scheduler = scheduler_env['scheduler'] - - if msg_type == 'SetHyperParametersValues' and modules_valid: #set other hyperparameters values - batch_size, shuffle, epochs, early_stop, monitor_metric, restore_best = tc.decode_ints(msg_data) - shuffle=True if shuffle==1 else False - early_stop=True if early_stop==1 else False - monitor_metric=True if monitor_metric==1 else False - restore_best=True if restore_best==1 else False - - if msg_type == 'StartTrainingServer' and modules_valid: - print("Caching...\n", file=sys.stderr) - - sshaddress, sshport, username, password, keydata = tc.decode_strings(msg_data) - - training_server, address = tc.generate_server() - - if sshaddress and sshport and username: - import socket - import paramiko - import torchstudio.sshtunnel as sshtunnel - - if not password: - password=None - if not keydata: - pkey=None - else: - import io - keybuffer=io.StringIO(keydata) - pkey=paramiko.RSAKey.from_private_key(keybuffer) - - sshclient = paramiko.SSHClient() - sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5) - - reverse_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ReverseTunnel, 'localhost', 0, 'localhost', int(address[1])) - address[1]=str(reverse_tunnel.lport) - - tc.send_msg(app_socket, 'TrainingServerRequestingAllSamples', tc.encode_strings(address)) - dataset_socket=tc.start_server(training_server) - train_dataset.reset() - valid_dataset.reset() - - while True: - dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket) - - if dataset_msg_type == 'NumSamples': - num_samples=tc.decode_ints(dataset_msg_data)[0] - pbar=tqdm(total=num_samples, desc='Caching...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters - - if dataset_msg_type == 'InputTensorsID' and modules_valid: - input_tensors_id = 
tc.decode_ints(dataset_msg_data) - - if dataset_msg_type == 'OutputTensorsID' and modules_valid: - output_tensors_id = tc.decode_ints(dataset_msg_data) - - if dataset_msg_type == 'TrainingSample': - train_dataset.add_sample(dataset_msg_data) - pbar.update(1) - - if dataset_msg_type == 'ValidationSample': - valid_dataset.add_sample(dataset_msg_data) - pbar.update(1) - - if dataset_msg_type == 'DoneSending': - pbar.close() - tc.send_msg(dataset_socket, 'DoneReceiving') - dataset_socket.close() - training_server.close() - if sshaddress and sshport and username: - sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong - break - - train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory) - valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size, shuffle=False, pin_memory=pin_memory) - - if msg_type == 'StartTraining' and modules_valid: - print("Training... epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr) - - if msg_type == 'TrainOneEpoch' and modules_valid: - #training - model.train() - train_loss = 0 - train_metrics = [] - for metric in metrics: - metric.reset() - for batch_id, tensors in enumerate(train_loader): - inputs = [tensors[i].to(device) for i in input_tensors_id] - targets = [tensors[i].to(device) for i in output_tensors_id] - optimizer.zero_grad() - outputs = model(*inputs) - outputs = outputs if type(outputs) is not torch.Tensor else [outputs] - loss = 0 - for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440 - loss = loss + criterion(output, target) - loss.backward() - optimizer.step() - train_loss += loss.item() * inputs[0].size(0) - - with torch.set_grad_enabled(False): - for output, target, metric in zip(outputs, targets, metrics): - metric.update(output, target) - - train_loss = train_loss/len(train_dataset) - train_metrics = 0 - for metric in metrics: - train_metrics = train_metrics+metric.compute().item() - train_metrics/=len(metrics) - scheduler.step() - - #validation - model.eval() - valid_loss = 0 - valid_metrics = [] - for metric in metrics: - metric.reset() - with torch.set_grad_enabled(False): - for batch_id, tensors in enumerate(valid_loader): - inputs = [tensors[i].to(device) for i in input_tensors_id] - targets = [tensors[i].to(device) for i in output_tensors_id] - outputs = model(*inputs) - outputs = outputs if type(outputs) is not torch.Tensor else [outputs] - loss = 0 - for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440 - loss = loss + criterion(output, target) - valid_loss += loss.item() * inputs[0].size(0) - - for output, target, metric in zip(outputs, targets, metrics): - metric.update(output, target) - - valid_loss = valid_loss/len(valid_dataset) - valid_metrics = 0 - for metric in metrics: - valid_metrics = valid_metrics+metric.compute().item() - valid_metrics/=len(metrics) - - tc.send_msg(app_socket, 'TrainingLoss', tc.encode_floats(train_loss)) - tc.send_msg(app_socket, 'ValidationLoss', tc.encode_floats(valid_loss)) - tc.send_msg(app_socket, 'TrainingMetric', tc.encode_floats(train_metrics)) - tc.send_msg(app_socket, 'ValidationMetric', tc.encode_floats(valid_metrics)) - - buffer=io.BytesIO() - torch.save(deepcopy_cpu(model.state_dict()), buffer) - tc.send_msg(app_socket, 'ModelState', buffer.getvalue()) 
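The ModelState message above (and the OptimizerState message that follows) carries a state dict that deepcopy_cpu first moves to the CPU and torch.save then serializes into an in-memory buffer. A minimal sketch of that round trip, using a stand-in nn.Linear model (the real model is whatever SetTorchScriptModel or SetPackageModel loaded earlier):

import io
import torch
import torch.nn as nn

# Stand-in model for illustration; in practice the trained model is used.
model = nn.Linear(4, 2)

# Move every tensor to the CPU first, mirroring what deepcopy_cpu does
# recursively for dicts and other iterables.
cpu_state = {k: v.to("cpu") for k, v in model.state_dict().items()}

# Serialize to raw bytes, as done before tc.send_msg(..., 'ModelState', ...).
buffer = io.BytesIO()
torch.save(cpu_state, buffer)
payload = buffer.getvalue()

# Receiving side: rebuild the state dict from the bytes and load it,
# as SetModelState does with torch.load on a BytesIO buffer.
restored = torch.load(io.BytesIO(payload))
model.load_state_dict(restored)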
- - buffer=io.BytesIO() - torch.save(deepcopy_cpu(optimizer.state_dict()), buffer) - tc.send_msg(app_socket, 'OptimizerState', buffer.getvalue()) - - tc.send_msg(app_socket, 'Trained') - - #create train_bar only after first successful training to avoid ghost progress message after an error - if train_bar is not None: - train_bar.bar_format='{desc} epoch {n_fmt} | {remaining} left |{rate_fmt}\n\n' - else: - train_bar = tqdm(total=epochs, desc='Training...', bar_format='{desc} epoch '+str(scheduler.last_epoch)+'\n\n', initial=scheduler.last_epoch-1) - train_bar.update(1) - - if msg_type == 'StopTraining' and modules_valid: - if train_bar is not None: - train_bar.close() - train_bar=None - print("Training stopped at epoch "+str(scheduler.last_epoch-1), file=sys.stderr) - - if msg_type == 'SetInputTensors' or msg_type == 'InferTensors': - input_tensors = tc.decode_torch_tensors(msg_data) - for i, tensor in enumerate(input_tensors): - input_tensors[i]=torch.unsqueeze(tensor, 0).to(device) #add batch dimension - - if msg_type == 'InferTensors': - if model is not None: - with torch.set_grad_enabled(False): - model.eval() - output_tensors=model(*input_tensors) - output_tensors=[output.cpu() for output in output_tensors] - tc.send_msg(app_socket, 'InferedTensors', tc.encode_torch_tensors(output_tensors)) - - if msg_type == 'SaveTorchScript': - path, mode = tc.decode_strings(msg_data) - if "torch.jit" in str(type(model)): - torch.jit.save(model, path) #already a torchscript, save as is - print("Export complete") - else: - if mode=="trace": - error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing') - else: - error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':model}, description='model scripting') - if error_msg: - print("Error exporting:", error_msg, file=sys.stderr) - else: - torch.jit.save(torchscript_model, path) - print("Export complete") - - if msg_type == 'SaveONNX': - error_msg=None - torchscript_model=model - if not "torch.jit" in str(type(model)): - error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing') - if error_msg: - print("Error exporting:", error_msg, file=sys.stderr) - else: - error_msg, torchscript_model = safe_exec(torch.onnx.export,{'model':torchscript_model, 'args':input_tensors, 'f':tc.decode_strings(msg_data)[0], 'opset_version':12}) - if error_msg: - print("Error exporting:", error_msg, file=sys.stderr) - else: - print("Export complete") - - if msg_type == 'Exit': - break - +#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490 +import os +os.environ['KMP_DUPLICATE_LIB_OK']='True' + +import sys + +print("Loading PyTorch...\n", file=sys.stderr) + +import torch +from torch.utils.data import Dataset +import torchstudio.tcpcodec as tc +from torchstudio.modules import safe_exec +import os +import sys +import io +import tempfile +from tqdm.auto import tqdm +from collections.abc import Iterable + + +class CachedDataset(Dataset): + def __init__(self, disk_cache=False): + self.reset(disk_cache) + + def add_sample(self, data): + if self.disk_cache: + file=tempfile.TemporaryFile(prefix='torchstudio.'+str(len(self.index))+'.') #guaranteed to be deleted on win/mac/linux: https://bugs.python.org/issue4928 + file.write(data) + self.index.append(file) + else: + self.index.append(tc.decode_torch_tensors(data)) + + def reset(self, 
disk_cache=False): + self.index = [] + self.disk_cache=disk_cache + + def __len__(self): + return len(self.index) + + def __getitem__(self, id): + if id<0 or id>=len(self): + raise IndexError + + if self.disk_cache: + file=self.index[id] + file.seek(0) + sample=tc.decode_torch_tensors(file.read()) + else: + sample=self.index[id] + return sample + +def deepcopy_cpu(value): + if isinstance(value, torch.Tensor): + value = value.to("cpu") + return value + elif isinstance(value, dict): + return {k: deepcopy_cpu(v) for k, v in value.items()} + elif isinstance(value, Iterable): + return type(value)(deepcopy_cpu(v) for v in value) + else: + return value + +modules_valid=True + +train_dataset = CachedDataset() +valid_dataset = CachedDataset() +train_bar = None + +model = None + +app_socket = tc.connect() +print("Training script connected\n", file=sys.stderr) +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + + if msg_type == 'SetDevice': + print("Setting device...\n", file=sys.stderr) + device_id=tc.decode_strings(msg_data)[0] + device = torch.device(device_id) + pin_memory = True if 'cuda' in device_id else False + + if msg_type == 'SetTorchScriptModel' and modules_valid: + if msg_data: + print("Setting torchscript model...\n", file=sys.stderr) + buffer=io.BytesIO(msg_data) + model = torch.jit.load(buffer) + + if msg_type == 'SetPackageModel' and modules_valid: + if msg_data: + print("Setting package model...\n", file=sys.stderr) + buffer=io.BytesIO(msg_data) + model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl') + + if msg_type == 'SetModelState' and modules_valid: + if model is not None: + if msg_data: + buffer=io.BytesIO(msg_data) + model.load_state_dict(torch.load(buffer)) + model.to(device) + + if msg_type == 'SetLossCodes' and modules_valid: + print("Setting loss code...\n", file=sys.stderr) + loss_definitions=tc.decode_strings(msg_data) + criteria = [] + for definition in loss_definitions: + error_msg, loss_env = safe_exec(definition, description='loss definition') + if error_msg is not None or 'loss' not in loss_env: + print("Unknown loss definition error" if error_msg is None else error_msg, file=sys.stderr) + modules_valid=False + tc.send_msg(app_socket, 'TrainingError') + break + else: + criteria.append(loss_env['loss']) + + if msg_type == 'SetMetricCodes' and modules_valid: + print("Setting metrics code...\n", file=sys.stderr) + metric_definitions=tc.decode_strings(msg_data) + metrics = [] + for definition in metric_definitions: + error_msg, metric_env = safe_exec(definition, description='metric definition') + if error_msg is not None or 'metric' not in metric_env: + print("Unknown metric definition error" if error_msg is None else error_msg, file=sys.stderr) + modules_valid=False + tc.send_msg(app_socket, 'TrainingError') + break + else: + metrics.append(metric_env['metric']) + + if msg_type == 'SetOptimizerCode' and modules_valid: + print("Setting optimizer code...\n", file=sys.stderr) + error_msg, optimizer_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='optimizer definition') + if error_msg is not None or 'optimizer' not in optimizer_env: + print("Unknown optimizer definition error" if error_msg is None else error_msg, file=sys.stderr) + modules_valid=False + tc.send_msg(app_socket, 'TrainingError') + else: + optimizer = optimizer_env['optimizer'] + + if msg_type == 'SetOptimizerState' and modules_valid: + if msg_data: + buffer=io.BytesIO(msg_data) + optimizer.load_state_dict(torch.load(buffer)) + + if msg_type 
== 'SetSchedulerCode' and modules_valid: + print("Setting scheduler code...\n", file=sys.stderr) + error_msg, scheduler_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='scheduler definition') + if error_msg is not None or 'scheduler' not in scheduler_env: + print("Unknown scheduler definition error" if error_msg is None else error_msg, file=sys.stderr) + modules_valid=False + tc.send_msg(app_socket, 'TrainingError') + else: + scheduler = scheduler_env['scheduler'] + + if msg_type == 'SetHyperParametersValues' and modules_valid: #set other hyperparameters values + batch_size, shuffle, epochs, early_stop, restore_best = tc.decode_ints(msg_data) + shuffle=True if shuffle==1 else False + early_stop=True if early_stop==1 else False + restore_best=True if restore_best==1 else False + + if msg_type == 'StartTrainingServer' and modules_valid: + print("Caching...\n", file=sys.stderr) + + sshaddress, sshport, username, password, keydata = tc.decode_strings(msg_data) + + training_server, address = tc.generate_server() + + if sshaddress and sshport and username: + import socket + import paramiko + import torchstudio.sshtunnel as sshtunnel + + if not password: + password=None + if not keydata: + pkey=None + else: + import io + keybuffer=io.StringIO(keydata) + pkey=paramiko.RSAKey.from_private_key(keybuffer) + + sshclient = paramiko.SSHClient() + sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5) + + reverse_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ReverseTunnel, 'localhost', 0, 'localhost', int(address[1])) + address[1]=str(reverse_tunnel.lport) + + tc.send_msg(app_socket, 'TrainingServerRequestingAllSamples', tc.encode_strings(address)) + dataset_socket=tc.start_server(training_server) + train_dataset.reset() + valid_dataset.reset() + + while True: + dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket) + + if dataset_msg_type == 'NumSamples': + num_samples=tc.decode_ints(dataset_msg_data)[0] + pbar=tqdm(total=num_samples, desc='Caching...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters + + if dataset_msg_type == 'InputTensorsID' and modules_valid: + input_tensors_id = tc.decode_ints(dataset_msg_data) + + if dataset_msg_type == 'OutputTensorsID' and modules_valid: + output_tensors_id = tc.decode_ints(dataset_msg_data) + + if dataset_msg_type == 'TrainingSample': + train_dataset.add_sample(dataset_msg_data) + pbar.update(1) + + if dataset_msg_type == 'ValidationSample': + valid_dataset.add_sample(dataset_msg_data) + pbar.update(1) + + if dataset_msg_type == 'DoneSending': + pbar.close() + tc.send_msg(dataset_socket, 'DoneReceiving') + dataset_socket.close() + training_server.close() + if sshaddress and sshport and username: + sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong + break + + train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory) + valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size, shuffle=False, pin_memory=pin_memory) + + if msg_type == 'StartTraining' and modules_valid: + print("Training... 
epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr) + + if msg_type == 'TrainOneEpoch' and modules_valid: + #training + model.train() + train_loss = 0 + train_metrics = [] + for metric in metrics: + metric.reset() + for batch_id, tensors in enumerate(train_loader): + inputs = [tensors[i].to(device) for i in input_tensors_id] + targets = [tensors[i].to(device) for i in output_tensors_id] + optimizer.zero_grad() + outputs = model(*inputs) + outputs = outputs if type(outputs) is not torch.Tensor else [outputs] + loss = 0 + for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440 + loss = loss + criterion(output, target) + loss.backward() + optimizer.step() + train_loss += loss.item() * inputs[0].size(0) + + with torch.set_grad_enabled(False): + for output, target, metric in zip(outputs, targets, metrics): + metric.update(output, target) + + train_loss = train_loss/len(train_dataset) + train_metrics = 0 + for metric in metrics: + train_metrics = train_metrics+metric.compute().item() + train_metrics/=len(metrics) + scheduler.step() + + #validation + model.eval() + valid_loss = 0 + valid_metrics = [] + for metric in metrics: + metric.reset() + with torch.set_grad_enabled(False): + for batch_id, tensors in enumerate(valid_loader): + inputs = [tensors[i].to(device) for i in input_tensors_id] + targets = [tensors[i].to(device) for i in output_tensors_id] + outputs = model(*inputs) + outputs = outputs if type(outputs) is not torch.Tensor else [outputs] + loss = 0 + for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440 + loss = loss + criterion(output, target) + valid_loss += loss.item() * inputs[0].size(0) + + for output, target, metric in zip(outputs, targets, metrics): + metric.update(output, target) + + valid_loss = valid_loss/len(valid_dataset) + valid_metrics = 0 + for metric in metrics: + valid_metrics = valid_metrics+metric.compute().item() + valid_metrics/=len(metrics) + + tc.send_msg(app_socket, 'TrainingLoss', tc.encode_floats(train_loss)) + tc.send_msg(app_socket, 'ValidationLoss', tc.encode_floats(valid_loss)) + tc.send_msg(app_socket, 'TrainingMetric', tc.encode_floats(train_metrics)) + tc.send_msg(app_socket, 'ValidationMetric', tc.encode_floats(valid_metrics)) + + buffer=io.BytesIO() + torch.save(deepcopy_cpu(model.state_dict()), buffer) + tc.send_msg(app_socket, 'ModelState', buffer.getvalue()) + + buffer=io.BytesIO() + torch.save(deepcopy_cpu(optimizer.state_dict()), buffer) + tc.send_msg(app_socket, 'OptimizerState', buffer.getvalue()) + + tc.send_msg(app_socket, 'Trained') + + #create train_bar only after first successful training to avoid ghost progress message after an error + if train_bar is not None: + train_bar.bar_format='{desc} epoch {n_fmt} |{rate_fmt}\n\n' + else: + train_bar = tqdm(total=epochs, desc='Training...', bar_format='{desc} epoch '+str(scheduler.last_epoch)+'\n\n', initial=scheduler.last_epoch-1) + train_bar.update(1) + + if msg_type == 'StopTraining' and modules_valid: + if train_bar is not None: + train_bar.close() + train_bar=None + print("Training stopped at epoch "+str(scheduler.last_epoch-1), file=sys.stderr) + + if msg_type == 'SetInputTensors' or msg_type == 'InferTensors': + input_tensors = tc.decode_torch_tensors(msg_data) + for i, tensor in enumerate(input_tensors): + input_tensors[i]=torch.unsqueeze(tensor, 0).to(device) #add batch dimension + + if msg_type == 'InferTensors': + if model is 
not None: + with torch.set_grad_enabled(False): + model.eval() + output_tensors=model(*input_tensors) + output_tensors=[output.cpu() for output in output_tensors] + tc.send_msg(app_socket, 'InferedTensors', tc.encode_torch_tensors(output_tensors)) + + if msg_type == 'SaveWeights': + path = tc.decode_strings(msg_data)[0] + torch.save(deepcopy_cpu(model.state_dict()), path) + print("Export complete") + + if msg_type == 'SaveTorchScript': + path, mode = tc.decode_strings(msg_data) + if "torch.jit" in str(type(model)): + torch.jit.save(model, path) #already a torchscript, save as is + print("Export complete") + else: + if mode=="trace": + error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing') + else: + error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':model}, description='model scripting') + if error_msg: + print("Error exporting:", error_msg, file=sys.stderr) + else: + torch.jit.save(torchscript_model, path) + print("Export complete") + + if msg_type == 'SaveONNX': + error_msg=None + torchscript_model=model + if not "torch.jit" in str(type(model)): + error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing') + if error_msg: + print("Error exporting:", error_msg, file=sys.stderr) + else: + error_msg, torchscript_model = safe_exec(torch.onnx.export,{'model':torchscript_model, 'args':input_tensors, 'f':tc.decode_strings(msg_data)[0], 'opset_version':12}) + if error_msg: + print("Error exporting:", error_msg, file=sys.stderr) + else: + print("Export complete") + + if msg_type == 'Exit': + break + diff --git a/torchstudio/parametersplot.py b/torchstudio/parametersplot.py index 70366ca..dd8930e 100644 --- a/torchstudio/parametersplot.py +++ b/torchstudio/parametersplot.py @@ -1,170 +1,170 @@ -import torchstudio.tcpcodec as tc -import inspect -import sys -import os - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.ticker import MaxNLocator -import PIL - -def sorted(l,reverse=False): - floats=True - for x in l: - try: - float(x) - except: - floats=False - break - l.sort(key=float if floats else None,reverse=reverse) - return l - -#inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib -def plot_parameters(size, dpi, - parameters=[], #parameters is a list of parameters - values=[], #values is a list of list containing string values - order=[]): #sorting order for each parameter(1 or -1) - """Parameters Plot - - Usage: - Click: invert parameter sorting order - """ - #set up matplotlib renderer, style, figure and axis - mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/ - plt.style.use('dark_background') - plt.rcParams.update({'font.size': 7}) - - if len(parameters)<2: - parameters=['Name', 'Validation\nMetric'] - -# parameters=['Name', 'feature_channels', 'depth', 'Metric Value'] -# values=[['Model 1','32','3','.95'],['Model 2','24','4','.9'],['Model 3','16','3','.98'],['Model 4','16','3']] - - if len(order)1 else .5 for j in range(len(param_values[i]))]) - ax.set_yticklabels(param_values[i]) - #first parameter is the model name, keep the set_ticks - axes[0].yaxis.set_tick_params(width=1) - #last parameter is the metric, let the colorbar do the metric - axes[-1].yaxis.set_ticks_position('none') - axes[-1].set_yticklabels([]) - axes[-1].spines['left'].set_visible(False) - - #set the 
colorbar for the metric - if param_values[-1]: - max_metric=min_metric=float(param_values[-1][0]) - for metric_value in param_values[-1]: - min_metric=min(min_metric,float(metric_value)) - max_metric=max(max_metric,float(metric_value)) - else: - max_metric=min_metric=0 - - cmap = plt.get_cmap('viridis') # 'viridis' or 'rainbow' - sc = host.scatter([0,0], [0,0], s=[0,0], c=[min_metric, max_metric], cmap=cmap) - cbar = fig.colorbar(sc, ax=axes[-1], pad=0) - cbar.outline.set_visible(False) -# cbar.set_ticks([]) - - #set horizontal axe settings - host.set_xlim(0, len(parameters) - 1) - host.set_xticks(range(len(parameters))) - host.set_xticklabels(parameters) - host.tick_params(axis='x', which='major', pad=7) - host.spines['right'].set_visible(False) - host.xaxis.tick_top() - - - - from matplotlib.path import Path - import matplotlib.patches as patches - import numpy as np - for tokens in values: - values_num=[] - for i, token in enumerate(tokens): - if i1 else .5) - else: - values_num.append((float(token)-min_metric)/(max_metric-min_metric) if len(param_values[i])>1 and max_metric>min_metric else .5) - - # create bezier curves - # for each axis, there will a control vertex at the point itself, one at 1/3rd towards the previous and one - # at one third towards the next axis; the first and last axis have one less control vertex - # x-coordinate of the control vertices: at each integer (for the axes) and two inbetween - # y-coordinate: repeat every point three times, except the first and last only twice - verts = list(zip([x for x in np.linspace(0, len(values_num) - 1, len(values_num) * 3 - 2, endpoint=True)], - np.repeat(values_num, 3)[1:-1])) - # for x,y in verts: host.plot(x, y, 'go') # to show the control points of the beziers - codes = [Path.MOVETO] + [Path.CURVE4 for _ in range(len(verts) - 1)] - path = Path(verts, codes) - patch = patches.PathPatch(path, facecolor='none', lw=1, edgecolor=cmap(values_num[-1]) if len(values_num)==len(parameters) else (0.33, 0.33, 0.33), zorder=values_num[-1] if len(values_num)==len(parameters) else -1) - host.add_patch(patch) - - plt.tight_layout(pad=0) - - canvas = plt.get_current_fig_manager().canvas - canvas.draw() - img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb()) - plt.close() - return img - - -resolution = (256,256, 96) - -parameters=[] -values=[] -order=[] - - -app_socket = tc.connect() -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - - if msg_type == 'RequestDocumentation': - tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_parameters.__doc__))) - - if msg_type == 'SetResolution': - resolution = tc.decode_ints(msg_data) - - if msg_type == 'SetParameters': - parameters=tc.decode_strings(msg_data) - - if msg_type == 'ClearValues': - values = [] - if msg_type == 'AppendValues': - values.append(tc.decode_strings(msg_data)) - - if msg_type == 'SetOrder': - order=tc.decode_ints(msg_data) - - if msg_type == 'Render': - if resolution[0]>0 and resolution[1]>0: - img=plot_parameters(resolution[0:2],resolution[2],parameters,values,order) - tc.send_msg(app_socket, 'ImageData', tc.encode_image(img)) - - if msg_type == 'Exit': - break +import torchstudio.tcpcodec as tc +import inspect +import sys +import os + +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.ticker import MaxNLocator +import PIL + +def sorted(l,reverse=False): + floats=True + for x in l: + try: + float(x) + except: + floats=False + break + l.sort(key=float if floats else None,reverse=reverse) + 
return l + +#inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib +def plot_parameters(size, dpi, + parameters=[], #parameters is a list of parameters + values=[], #values is a list of list containing string values + order=[]): #sorting order for each parameter(1 or -1) + """Parameters Plot + + Usage: + Click: invert parameter sorting order + """ + #set up matplotlib renderer, style, figure and axis + mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/ + plt.style.use('dark_background') + plt.rcParams.update({'font.size': 7}) + + if len(parameters)<2: + parameters=['Name', 'Validation\nMetric'] + +# parameters=['Name', 'feature_channels', 'depth', 'Metric Value'] +# values=[['Model 1','32','3','.95'],['Model 2','24','4','.9'],['Model 3','16','3','.98'],['Model 4','16','3']] + + if len(order)1 else .5 for j in range(len(param_values[i]))]) + ax.set_yticklabels(param_values[i]) + #first parameter is the model name, keep the set_ticks + axes[0].yaxis.set_tick_params(width=1) + #last parameter is the metric, let the colorbar do the metric + axes[-1].yaxis.set_ticks_position('none') + axes[-1].set_yticklabels([]) + axes[-1].spines['left'].set_visible(False) + + #set the colorbar for the metric + if param_values[-1]: + max_metric=min_metric=float(param_values[-1][0]) + for metric_value in param_values[-1]: + min_metric=min(min_metric,float(metric_value)) + max_metric=max(max_metric,float(metric_value)) + else: + max_metric=min_metric=0 + + cmap = plt.get_cmap('viridis') # 'viridis' or 'rainbow' + sc = host.scatter([0,0], [0,0], s=[0,0], c=[min_metric, max_metric], cmap=cmap) + cbar = fig.colorbar(sc, ax=axes[-1], pad=0) + cbar.outline.set_visible(False) +# cbar.set_ticks([]) + + #set horizontal axe settings + host.set_xlim(0, len(parameters) - 1) + host.set_xticks(range(len(parameters))) + host.set_xticklabels(parameters) + host.tick_params(axis='x', which='major', pad=7) + host.spines['right'].set_visible(False) + host.xaxis.tick_top() + + + + from matplotlib.path import Path + import matplotlib.patches as patches + import numpy as np + for tokens in values: + values_num=[] + for i, token in enumerate(tokens): + if i1 else .5) + else: + values_num.append((float(token)-min_metric)/(max_metric-min_metric) if len(param_values[i])>1 and max_metric>min_metric else .5) + + # create bezier curves + # for each axis, there will a control vertex at the point itself, one at 1/3rd towards the previous and one + # at one third towards the next axis; the first and last axis have one less control vertex + # x-coordinate of the control vertices: at each integer (for the axes) and two inbetween + # y-coordinate: repeat every point three times, except the first and last only twice + verts = list(zip([x for x in np.linspace(0, len(values_num) - 1, len(values_num) * 3 - 2, endpoint=True)], + np.repeat(values_num, 3)[1:-1])) + # for x,y in verts: host.plot(x, y, 'go') # to show the control points of the beziers + codes = [Path.MOVETO] + [Path.CURVE4 for _ in range(len(verts) - 1)] + path = Path(verts, codes) + patch = patches.PathPatch(path, facecolor='none', lw=1, edgecolor=cmap(values_num[-1]) if len(values_num)==len(parameters) else (0.33, 0.33, 0.33), zorder=values_num[-1] if len(values_num)==len(parameters) else -1) + host.add_patch(patch) + + plt.tight_layout(pad=0) + + canvas = plt.get_current_fig_manager().canvas + canvas.draw() + img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb()) + plt.close() 
+ return img + + +resolution = (256,256, 96) + +parameters=[] +values=[] +order=[] + + +app_socket = tc.connect() +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + + if msg_type == 'RequestDocumentation': + tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_parameters.__doc__))) + + if msg_type == 'SetResolution': + resolution = tc.decode_ints(msg_data) + + if msg_type == 'SetParameters': + parameters=tc.decode_strings(msg_data) + + if msg_type == 'ClearValues': + values = [] + if msg_type == 'AppendValues': + values.append(tc.decode_strings(msg_data)) + + if msg_type == 'SetOrder': + order=tc.decode_ints(msg_data) + + if msg_type == 'Render': + if resolution[0]>0 and resolution[1]>0: + img=plot_parameters(resolution[0:2],resolution[2],parameters,values,order) + tc.send_msg(app_socket, 'ImageData', tc.encode_image(img)) + + if msg_type == 'Exit': + break diff --git a/torchstudio/pythoninstall.cmd b/torchstudio/pythoninstall.cmd index 2267b56..3ba66ef 100644 --- a/torchstudio/pythoninstall.cmd +++ b/torchstudio/pythoninstall.cmd @@ -1,239 +1,239 @@ -# 2>nul&@goto :BATCH - -#BASH - -SCRIPTDIR=$(cd "$(dirname "$0")"; pwd) -( -cd "$SCRIPTDIR" -cd .. - -pythonpath="$(pwd)/python" -channel="pytorch" -cuda="" -packages="" -uninstall="" -while [ "$1" != "" ]; do - if [ $1 == "--path" ]; then - shift; pythonpath=$1 - elif [ $1 == "--channel" ]; then - shift; channel=$1 - elif [ $1 == "--cuda" ]; then - cuda="--cuda" - elif [ $1 == "--package" ]; then - shift; packages+="--package $1" - elif [ $1 == "--uninstall" ]; then - uninstall="1" - fi - shift -done - -if [ ! -z "$uninstall" ]; then - echo "Uninstalling python environment..." - rm -f *.sh - rm -f *.tmp - rm -f -r "$pythonpath" - if [ $? != 0 ]; then - echo "" 1>&2 - echo "Error while uninstalling. Check write permissions." 1>&2 - exit 1 - else - echo "" - echo "Uninstall complete." - exit 0 - fi -fi - -if [[ $OSTYPE == "linux"* ]]; then - echo "Downloading, installing and setting up a linux python environment" -elif [[ $OSTYPE == "darwin"* ]]; then - echo "Downloading, installing and setting up a macOS python environment" -else - echo "Error: unsupported OS ($OSTYPE)" 1>&2 - exit 1 -fi - -if [ ! -z "$cuda" ]; then - echo "This may take up to 16 minutes depending on your download speed, and up to 16 GB." -else - echo "This may take up to 5 minutes depending on your download speed, and up to 5 GB." -fi - -echo "" -if [[ $OSTYPE == "linux"* ]]; then - file=Miniconda3-latest-Linux-x86_64.sh - rm -f "$file.tmp" - if [ -f "$file" ]; then - echo "Python installer ($file) already downloaded" - else - echo "Downloading Python installer ($file)..." - wget --show-progress --progress=bar:force:noscroll --no-check-certificate https://repo.anaconda.com/miniconda/$file -O "$file.tmp" - if [ $? != 0 ]; then - rm -f "$file.tmp" - echo "" 1>&2 - echo "Error while downloading. Make sure port 80 is open." 1>&2 - exit 1 - else - mv "$file.tmp" "$file" - fi - fi -elif [[ $OSTYPE == "darwin"* ]]; then - if [ "$(uname -m)" == "arm64" ]; then - file=Miniconda3-latest-MacOSX-arm64.sh - else - file=Miniconda3-latest-MacOSX-x86_64.sh - fi - rm -f "$file.tmp" - if [ -f "$file" ]; then - echo "Python installer $file already downloaded" - else - echo "Downloading Python installer $file..." - curl --insecure https://repo.anaconda.com/miniconda/$file -o "$file.tmp" - if [ $? != 0 ]; then - rm -f "$file.tmp" - echo "" 1>&2 - echo "Error while downloading. Make sure port 80 is open." 
1>&2 - exit 1 - else - mv "$file.tmp" "$file" - fi - fi -fi - -echo "" -if [ -f "python.tmp" ]; then - rm -f python.tmp - rm -f -r "$pythonpath" -fi -if [ -d "$pythonpath" ]; then - echo "Python already installed in $pythonpath" -else - echo "Installing Python in $pythonpath..." - echo "" > python.tmp - bash "$(pwd)/$file" -b -f -p "$pythonpath" - rm -f python.tmp - if [ $? != 0 ]; then - rm -f -r "$pythonpath" - echo "" 1>&2 - echo "Error while installing. Make sure you have write permissions." 1>&2 - exit 1 - fi -fi - -PATH="$PATH;$pythonpath/bin" -"$pythonpath/bin/python" -u -B -X utf8 -m torchstudio.pythoninstall --channel $channel $cuda $packages -if [ $? != 0 ]; then - echo "" 1>&2 - echo "Error while installing packages" 1>&2 - exit 1 -fi - -echo "" -echo "Installation complete." -) -exit - - -:BATCH - -@echo off -setlocal -cd /D "%~dp0" -cd .. - -set pythonpath=%cd%\python -set channel=pytorch -set cuda= -set packages= -set uninstall= -:args -if "%~1" == "--path" ( - set pythonpath=%~2 - shift -) else if "%~1" == "--channel" ( - set channel=%~2 - shift -) else if "%~1" == "--cuda" ( - set cuda=--cuda -) else if "%~1" == "--package" ( - set packages=%packages% --package %~2 - shift -) else if "%~1" == "--uninstall" ( - set uninstall=1 -) else if "%~1" == "" ( - goto endargs -) -shift -goto args -:endargs - -if DEFINED uninstall ( - echo Uninstalling python environment... - del *.exe 2>nul - del *.tmp 2>nul - rmdir /s /q "%pythonpath%" 2>nul - if ERRORLEVEL 1 ( - echo. 1>&2 - echo Error while uninstalling. Check write permissions. 1>&2 - exit /B 1 - ) else ( - echo. - echo Uninstall complete. - exit /B 0 - ) -) - -echo Downloading, installing and setting up a windows python environment -if DEFINED cuda ( - echo This may take up to 16 minutes depending on your download speed, and up to 16 GB. -) else ( - echo This may take up to 5 minutes depending on your download speed, and up to 5 GB. -) - -echo. -set file=Miniconda3-latest-Windows-x86_64.exe -del %file%.tmp 2>nul -if EXIST "%file%" ( - echo Python installer %file% already downloaded -) else ( - echo Downloading Python installer %file%... - curl --insecure https://repo.anaconda.com/miniconda/%file% -o %file%.tmp - if ERRORLEVEL 1 ( - del %file%.tmp 2>nul - echo. 1>&2 - echo Error while downloading. Make sure port 80 is open. 1>&2 - exit /B 1 - ) else ( - ren %file%.tmp %file% - ) -) - -echo. -if EXIST "python.tmp" ( - del python.tmp 2>nul - rmdir /s /q "%pythonpath%" 2>nul -) -if EXIST "%pythonpath%" ( - echo Python already installed in %pythonpath% -) else ( - echo Installing Python in %pythonpath%... - echo. > python.tmp - %file% /S /D=%pythonpath% - del python.tmp 2>nul - if ERRORLEVEL 1 ( - rmdir /s /q "%pythonpath%" 2>nul - echo. 1>&2 - echo Error while installing. Make sure you have write permissions. 1>&2 - exit /B 1 - ) -) - -set PATH=%PATH%;%pythonpath%;%pythonpath%\Library\mingw-w64\bin;%pythonpath%\Library\bin -"%pythonpath%\python" -u -B -X utf8 -m torchstudio.pythoninstall --channel %channel% %cuda% %packages% -if ERRORLEVEL 1 ( - echo. 1>&2 - echo Error while installing packages 1>&2 - exit /B 1 -) - -echo. -echo Installation complete. +# 2>nul&@goto :BATCH + +#BASH + +SCRIPTDIR=$(cd "$(dirname "$0")"; pwd) +( +cd "$SCRIPTDIR" +cd .. 
+ +pythonpath="$(pwd)/python" +channel="pytorch" +cuda="" +packages="" +uninstall="" +while [ "$1" != "" ]; do + if [ $1 == "--path" ]; then + shift; pythonpath=$1 + elif [ $1 == "--channel" ]; then + shift; channel=$1 + elif [ $1 == "--cuda" ]; then + cuda="--cuda" + elif [ $1 == "--package" ]; then + shift; packages+="--package $1" + elif [ $1 == "--uninstall" ]; then + uninstall="1" + fi + shift +done + +if [ ! -z "$uninstall" ]; then + echo "Uninstalling python environment..." + rm -f *.sh + rm -f *.tmp + rm -f -r "$pythonpath" + if [ $? != 0 ]; then + echo "" 1>&2 + echo "Error while uninstalling. Check write permissions." 1>&2 + exit 1 + else + echo "" + echo "Uninstall complete." + exit 0 + fi +fi + +if [[ $OSTYPE == "linux"* ]]; then + echo "Downloading, installing and setting up a linux python environment" +elif [[ $OSTYPE == "darwin"* ]]; then + echo "Downloading, installing and setting up a macOS python environment" +else + echo "Error: unsupported OS ($OSTYPE)" 1>&2 + exit 1 +fi + +if [ ! -z "$cuda" ]; then + echo "This may take up to 16 minutes depending on your download speed, and up to 16 GB." +else + echo "This may take up to 5 minutes depending on your download speed, and up to 5 GB." +fi + +echo "" +if [[ $OSTYPE == "linux"* ]]; then + file=Miniconda3-latest-Linux-x86_64.sh + rm -f "$file.tmp" + if [ -f "$file" ]; then + echo "Python installer ($file) already downloaded" + else + echo "Downloading Python installer ($file)..." + wget --show-progress --progress=bar:force:noscroll --no-check-certificate https://repo.anaconda.com/miniconda/$file -O "$file.tmp" + if [ $? != 0 ]; then + rm -f "$file.tmp" + echo "" 1>&2 + echo "Error while downloading. Make sure port 80 is open." 1>&2 + exit 1 + else + mv "$file.tmp" "$file" + fi + fi +elif [[ $OSTYPE == "darwin"* ]]; then + if [ "$(uname -m)" == "arm64" ]; then + file=Miniconda3-latest-MacOSX-arm64.sh + else + file=Miniconda3-latest-MacOSX-x86_64.sh + fi + rm -f "$file.tmp" + if [ -f "$file" ]; then + echo "Python installer $file already downloaded" + else + echo "Downloading Python installer $file..." + curl --insecure https://repo.anaconda.com/miniconda/$file -o "$file.tmp" + if [ $? != 0 ]; then + rm -f "$file.tmp" + echo "" 1>&2 + echo "Error while downloading. Make sure port 80 is open." 1>&2 + exit 1 + else + mv "$file.tmp" "$file" + fi + fi +fi + +echo "" +if [ -f "python.tmp" ]; then + rm -f python.tmp + rm -f -r "$pythonpath" +fi +if [ -d "$pythonpath" ]; then + echo "Python already installed in $pythonpath" +else + echo "Installing Python in $pythonpath..." + echo "" > python.tmp + bash "$(pwd)/$file" -b -f -p "$pythonpath" + rm -f python.tmp + if [ $? != 0 ]; then + rm -f -r "$pythonpath" + echo "" 1>&2 + echo "Error while installing. Make sure you have write permissions." 1>&2 + exit 1 + fi +fi + +PATH="$PATH;$pythonpath/bin" +"$pythonpath/bin/python" -u -B -X utf8 -m torchstudio.pythoninstall --channel $channel $cuda $packages +if [ $? != 0 ]; then + echo "" 1>&2 + echo "Error while installing packages" 1>&2 + exit 1 +fi + +echo "" +echo "Installation complete." +) +exit + + +:BATCH + +@echo off +setlocal +cd /D "%~dp0" +cd .. 
+ +set pythonpath=%cd%\python +set channel=pytorch +set cuda= +set packages= +set uninstall= +:args +if "%~1" == "--path" ( + set pythonpath=%~2 + shift +) else if "%~1" == "--channel" ( + set channel=%~2 + shift +) else if "%~1" == "--cuda" ( + set cuda=--cuda +) else if "%~1" == "--package" ( + set packages=%packages% --package %~2 + shift +) else if "%~1" == "--uninstall" ( + set uninstall=1 +) else if "%~1" == "" ( + goto endargs +) +shift +goto args +:endargs + +if DEFINED uninstall ( + echo Uninstalling python environment... + del *.exe 2>nul + del *.tmp 2>nul + rmdir /s /q "%pythonpath%" 2>nul + if ERRORLEVEL 1 ( + echo. 1>&2 + echo Error while uninstalling. Check write permissions. 1>&2 + exit /B 1 + ) else ( + echo. + echo Uninstall complete. + exit /B 0 + ) +) + +echo Downloading, installing and setting up a windows python environment +if DEFINED cuda ( + echo This may take up to 16 minutes depending on your download speed, and up to 16 GB. +) else ( + echo This may take up to 5 minutes depending on your download speed, and up to 5 GB. +) + +echo. +set file=Miniconda3-latest-Windows-x86_64.exe +del %file%.tmp 2>nul +if EXIST "%file%" ( + echo Python installer %file% already downloaded +) else ( + echo Downloading Python installer %file%... + curl --insecure https://repo.anaconda.com/miniconda/%file% -o %file%.tmp + if ERRORLEVEL 1 ( + del %file%.tmp 2>nul + echo. 1>&2 + echo Error while downloading. Make sure port 80 is open. 1>&2 + exit /B 1 + ) else ( + ren %file%.tmp %file% + ) +) + +echo. +if EXIST "python.tmp" ( + del python.tmp 2>nul + rmdir /s /q "%pythonpath%" 2>nul +) +if EXIST "%pythonpath%" ( + echo Python already installed in %pythonpath% +) else ( + echo Installing Python in %pythonpath%... + echo. > python.tmp + %file% /S /D=%pythonpath% + del python.tmp 2>nul + if ERRORLEVEL 1 ( + rmdir /s /q "%pythonpath%" 2>nul + echo. 1>&2 + echo Error while installing. Make sure you have write permissions. 1>&2 + exit /B 1 + ) +) + +set PATH=%PATH%;%pythonpath%;%pythonpath%\Library\mingw-w64\bin;%pythonpath%\Library\bin +"%pythonpath%\python" -u -B -X utf8 -m torchstudio.pythoninstall --channel %channel% %cuda% %packages% +if ERRORLEVEL 1 ( + echo. 1>&2 + echo Error while installing packages 1>&2 + exit /B 1 +) + +echo. +echo Installation complete. 
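The installer above is a single polyglot script: the opening "# 2>nul&@goto :BATCH" line is a plain comment to Bash, which runs the shell branch, while cmd.exe jumps to the :BATCH label, so one file covers Linux, macOS and Windows. A hypothetical driver showing how it might be launched from Python with the flags the script itself parses (--path, --channel, --cuda, --package, --uninstall); the script location and the extra package name below are placeholders, not part of the patch:

import subprocess
import sys
from pathlib import Path

# Placeholder path; adjust to wherever pythoninstall.cmd actually lives.
script = Path("torchstudio") / "pythoninstall.cmd"

# Flags taken from the script's own argument loop; values are illustrative.
args = [
    "--path", str(Path.cwd() / "python"),   # target environment directory
    "--channel", "pytorch",                 # conda channel to install from
    "--cuda",                               # request the CUDA build
    "--package", "datasets",                # extra package (flag is repeatable)
]

# cmd.exe handles the .cmd file directly on Windows; bash interprets it elsewhere.
cmd = [str(script), *args] if sys.platform == "win32" else ["bash", str(script), *args]
subprocess.run(cmd, check=True)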
diff --git a/torchstudio/pythonparse.py b/torchstudio/pythonparse.py index e369276..9a64fe2 100644 --- a/torchstudio/pythonparse.py +++ b/torchstudio/pythonparse.py @@ -1,415 +1,414 @@ -#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490 -import os -os.environ['KMP_DUPLICATE_LIB_OK']='True' - -import importlib -import inspect, sys -import ast -import re -from typing import Dict, List -from os import listdir -from os.path import isfile, join -import torchstudio.tcpcodec as tc -from torchstudio.modules import safe_exec - -def gather_parameters(node): - params=[] - for param in inspect.signature(node).parameters.values(): - #name - if param.kind==param.VAR_POSITIONAL: - params.append("*"+param.name) - elif param.kind==param.VAR_KEYWORD: - params.append("**"+param.name) - else: - params.append(param.name) - #annotation - if param.annotation == param.empty: - params.append('') - else: - params.append(param.annotation.__name__ if isinstance(param.annotation, type) else repr(param.annotation)) - #default value - if param.default == param.empty: - params.append('') - elif inspect.isclass(param.default) or inspect.isfunction(param.default): - params.append(param.default.__module__+'.'+param.default.__name__) - else: - value=repr(param.default) - if "","") - params.append(value) - return params - -def gather_objects(module): - objects=[] - for name, obj in inspect.getmembers(module): -# print("INSPECTING ", name) - if ((inspect.isclass(obj) and hasattr(obj, '__mro__') and ("torch.nn.modules.module.Module" in str(obj.__mro__) or "torch.utils.data.dataset.Dataset" in str(obj.__mro__))) or inspect.isfunction(obj)): #filter unwanted torch objects - object={} - object['type']='class' if inspect.isclass(obj) else 'function' - object['name']=name - if obj.__doc__ is not None: - object['doc']=inspect.cleandoc(obj.__doc__) - - # autofill class members when requested - doc=object['doc'] - newstring = '' - start = 0 - for m in re.finditer(".. autoclass:: [a-zA-Z0-9_.]+\n :members:", doc): - end, newstart = m.span() - newstring += doc[start:end] - class_name = re.findall(".. 
autoclass:: ([a-zA-Z0-9_.]+)\n :members:", m.group(0)) - rep = class_name[0]+":\n" - sub_error_message, submodule = safe_exec(importlib.import_module,(class_name[0].rpartition('.')[0],)) - if submodule is not None: - for member in dir(vars(submodule)[class_name[0].rpartition('.')[-1]]): - if not member.startswith('_'): - rep+=' '+member+'\n' - newstring += rep - start = newstart - newstring += doc[start:] - object['doc']=newstring.replace(" :noindex:","") - else: - object['doc']=name+(' class' if inspect.isclass(obj) else ' function') - if hasattr(obj,'__getitem__') and obj.__getitem__.__doc__ is not None: - itemdoc=inspect.cleandoc(obj.__getitem__.__doc__) - if 'Returns:' in itemdoc: - object['doc']+='\n\n'+itemdoc[itemdoc.find('Returns:'):] - object['params']=gather_parameters(obj.__init__ if inspect.isclass(obj) else obj) - if inspect.isclass(obj): - object['params']=object['params'][3:] #remove self parameter - object['code']='' - objects.append(object) - return objects - - -def parse_parameters(node): - #prepare defaults to be in sync with arguments - defaults=[] - for d in node.args.defaults: - defaults.append(ast.get_source_segment(code,d)) - for d in range(len(node.args.args)-len(node.args.defaults)): - defaults.insert(0,"") - #scan through arguments - params=[] - for i,a in enumerate(node.args.args): - params.append(a.arg) - if a.annotation: - params.append(ast.get_source_segment(code,a.annotation)) - else: - params.append("") - params.append(defaults[i]) - #add *args, if applicable - if node.args.vararg: - params.append("*"+node.args.vararg.arg) - if node.args.vararg.annotation: - params.append(ast.get_source_segment(code,node.args.vararg.annotation)) - else: - params.append("") - params.append("") #no default value - #add **kwargs, if applicable - if node.args.kwarg: - params.append("**"+node.args.kwarg.arg) - if node.args.kwarg.annotation: - params.append(ast.get_source_segment(code,node.args.kwarg.annotation)) - else: - params.append("") - params.append("") #no default value - return params - -def parse_objects(module:ast.Module): - objects=[] - for node in module.body: - if isinstance(node, ast.FunctionDef): - object={} - object['code']=ast.get_source_segment(code,node) - object['type']='function' - object['name']=node.name - object['doc']=ast.get_docstring(node) if ast.get_docstring(node) else "" - object['params']=parse_parameters(node) - objects.append(object) - if isinstance(node, ast.ClassDef): - object={} - object['code']=ast.get_source_segment(code,node) - object['type']='class' - object['name']=node.name - object['doc']=ast.get_docstring(node) if ast.get_docstring(node) else "" - object['params']=[] - for subnode in node.body: - if isinstance(subnode, ast.FunctionDef) and subnode.name=="__init__": - object['params']=parse_parameters(subnode) - object['params']=object['params'][3:] #remove self parameter - objects.append(object) - return objects - -def filter_parent_objects(objects:List[Dict]) -> List[Dict]: - parent_objects=[] - for object in objects: - unique=True - for subobject in objects: - name=object['name'] - if subobject['name']!=name: - if re.search('[ =+]'+name+'[ ]*\(', subobject['code']): - unique=False - if unique: - parent_objects.append(object) - return parent_objects - - -generated_class="""\ -import typing -import pathlib -import torch -import torch.nn as nn -import torch.nn.functional as F -import {0} -from {0} import transforms - -class {1}({2}): - \"\"\"{3}\"\"\" - def __init__({4}): - super().__init__({5}) -""" - -generated_function="""\ -import 
typing -import pathlib -import torch -import torch.nn as nn -import torch.nn.functional as F -import {0} -from {0} import transforms - -def {1}({2}): - \"\"\"{3}\"\"\" - {4}={5}({6}) - return {4} -""" - -def generate_code(path,object): - #write an inherited code code for each object - name=object['name'] - params=object['params'] - if object['type']=='class': - return generated_class.format( - path.split('.')[0], - name, path+'.'+name, - object['doc'], - ', '.join(['self']+[params[i]+(': ' if params[i+1] else'')+params[i+1]+(' = ' if params[i+2] else'')+params[i+2] for i in range(0,len(params),3)]), - ', '.join([params[i] for i in range(0,len(params),3)]) - ) - else: - return generated_function.format( - path.split('.')[0], - name, ', '.join([params[i]+(': ' if params[i+1] else'')+params[i+1]+(' = ' if params[i+2] else'')+params[i+2] for i in range(0,len(params),3)]), - object['doc'], - path.split('.')[1][:-1], path+'.'+name, ', '.join([params[i] for i in range(0,len(params),3)]) - ) - -def patch_parameters(path, name, params): - patched_params=[] - if 'datasets' in path: - for state in ['train','valid']: - for i in range(0,len(params)-1,3): - #patch root for all modules - if params[i]=='root' and not params[i+2]: - params[i+2]="'"+data_path+"'" - - #patch download for all modules - if params[i]=="download" and params[i+2]=="False": - params[i+2]="True" - - #patch transform for vision modules - if 'torchvision' in path and (params[i]=="transform" or params[i]=="target_transform"): - params[i+2]="transforms.Compose([])" - - #patch train/val for specific modules - if state=='valid': - if params[i]=="train" and params[i+2]=="True": - params[i+2]="False" - if 'torchvision' in path: - if (name=="Cityscapes" or name=="ImageNet") and params[i]=="split": - params[i+2]="'val'" - if (name=="STL10" or name=="SVHN") and params[i]=="split": - params[i+2]="'test'" - if (name=="CelebA") and params[i]=="split": - params[i+2]="'valid'" - if (name=="Places365") and params[i]=="split": - params[i+2]="'val'" - if (name=="VOCDetection" or name=="VOCSegmentation") and params[i]=="image_set": - params[i+2]="'val'" - - patched_params.append(params.copy()) - else: - patched_params.append(params) - return patched_params - -custom_classes={} -custom_classes['Custom Dataset']="""\ -import torch -from torch.utils.data import Dataset - -class MyDataset(Dataset): - def __init__(self): - super().__init__() - - def __len__(self): - pass - - def __getitem__(self, idx): - pass -""" -custom_classes['Custom Renderer']="""\ -from torchstudio.modules import Renderer -import numpy as np -import matplotlib as mpl -import matplotlib.pyplot as plt -import PIL - -class MyRenderer(Renderer): - def __init__(self): - super().__init__() - - def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]): - pass -""" -custom_classes['Custom Analyzer']="""\ -from torchstudio.modules import Analyzer -from typing import List -import numpy as np -import matplotlib as mpl -import matplotlib.pyplot as plt -import PIL - -class MyAnalyzer(Analyzer): - def __init__(self): - super().__init__() - - def start_analysis(self, num_training_samples: int, num_validation_samples: int, input_tensors_id: List[int], output_tensors_id: List[int], labels: List[str]): - pass - - def analyze_sample(self, sample: List[np.array], training_sample: bool): - pass - - def finish_analysis(self): - pass - - def generate_report(self, size: Tuple[int, int], dpi: int): - pass -""" -custom_classes['Custom 
Model']="""\ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class MyModel(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - pass -""" -custom_classes['Custom Loss']="""\ -import torch.nn as nn - -class MyLoss(nn.Modules._Loss): - def __init__(self): - super().__init__() - - def forward(self, x): - pass -""" -custom_classes['Custom Metric']="""\ -from torchstudio.modules import Metric -import torch.nn.functional as F - -class MyMetric(Metric): - def __init__(self): - pass - - def update(self, preds, target): - pass - - def compute(self): - pass - - def reset(self): - pass -""" -custom_classes['Custom Optimizer']="""\ -import torch.optim as optim - -class MyOptimizer(optim.Optimizer): - def __init__(self, params): - super().__init__(params) -""" -custom_classes['Custom Scheduler']="""\ -import torch.optim as optim - -class MyScheduler(optim._LRScheduler): - def __init__(self, optimizer): - super().__init__(optimizer) -""" - -def scan_folder(path): - path=path.replace('.','/') - codes=[] - for filename in listdir(path): - if isfile(join(path, filename)): - with open(join(path, filename), "r") as file: - codes.append(file.read()) - return codes - - -app_socket = tc.connect() -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - objects=[] - if msg_type == 'SetDataDir': - data_path=tc.decode_strings(msg_data)[0] - - if msg_type == 'Parse': #parse code or path, return a list of objects (class and functions) with their names, doc, parameters, doc and code - decoded=tc.decode_strings(msg_data) - path=decoded[0] - if path in custom_classes: - decoded.append(custom_classes[path]) - if 'torchstudio' in path: - decoded+=scan_folder(path) - if len(decoded)>1: - #parse code chunks - for code in decoded[1:]: - error_msg, module = safe_exec(ast.parse,(code,)) - if error_msg is None and module is not None: - objects_batch=parse_objects(module) - objects_batch=filter_parent_objects(objects_batch) #only keep parent objets - for i in range(len(objects_batch)): - objects_batch[i]['code']=code #set whole source code for each object, as we don't know the dependencies - objects.extend(objects_batch) - else: - print("Error parsing code:", error_msg, "\n", file=sys.stderr) - else: - #parse module - error_msg, module = safe_exec(importlib.import_module,(path,)) - if error_msg is None and module is not None: - objects=gather_objects(module) - for i, object in enumerate(objects): - objects[i]['code']=generate_code(path,object) #generate inherited source code - else: - print("Error parsing module:", error_msg, "\n", file=sys.stderr) - - tc.send_msg(app_socket, 'ObjectsBegin', tc.encode_strings(path)) - for object in objects: - patched_params = patch_parameters(path,object['name'],object['params']) - for params in patched_params: - tc.send_msg(app_socket, 'Object', tc.encode_strings([path,object['code'],object['type'],object['name'],object['doc']]+params)) - tc.send_msg(app_socket, 'ObjectsEnd', tc.encode_strings(path)) - - if msg_type == 'RequestDefinitionName': #return default definition name - tab=tc.decode_strings(msg_data)[0] - if tab=='dataset': - tc.send_msg(app_socket, 'SetDefinitionName', tc.encode_strings(['torchvision.datasets','MNIST'])) - if tab=='model': - tc.send_msg(app_socket, 'SetDefinitionName', tc.encode_strings(['torchstudio.models','MNISTClassifier'])) - - if msg_type == 'Exit': - break +#workaround until Pytorch 1.12.1 is released: https://github.com/pytorch/pytorch/issues/78490 +import os +os.environ['KMP_DUPLICATE_LIB_OK']='True' 
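Both gather_parameters (introspection on an imported module) and parse_parameters (AST walk over source code) defined below flatten each argument into a name / annotation / default triplet, which is why generate_code later steps through the parameter list three items at a time. A small sketch of the resulting layout on a made-up function, for illustration only:

import inspect

def example(x: int, scale: float = 1.0, *args, **kwargs):
    pass

params = []
for p in inspect.signature(example).parameters.values():
    # name, with the *, ** markers used by gather_parameters
    if p.kind == p.VAR_POSITIONAL:
        params.append("*" + p.name)
    elif p.kind == p.VAR_KEYWORD:
        params.append("**" + p.name)
    else:
        params.append(p.name)
    # annotation, or an empty string when absent
    params.append("" if p.annotation is p.empty else getattr(p.annotation, "__name__", repr(p.annotation)))
    # default value, or an empty string when absent
    params.append("" if p.default is p.empty else repr(p.default))

print(params)
# ['x', 'int', '', 'scale', 'float', '1.0', '*args', '', '', '**kwargs', '', '']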
+ +import importlib +import inspect, sys +import ast +import re +from typing import Dict, List +from os import listdir +from os.path import isfile, join +import torchstudio.tcpcodec as tc +from torchstudio.modules import safe_exec + +def gather_parameters(node): + params=[] + for param in inspect.signature(node).parameters.values(): + #name + if param.kind==param.VAR_POSITIONAL: + params.append("*"+param.name) + elif param.kind==param.VAR_KEYWORD: + params.append("**"+param.name) + else: + params.append(param.name) + #annotation + if param.annotation == param.empty: + params.append('') + else: + params.append(param.annotation.__name__ if isinstance(param.annotation, type) else repr(param.annotation)) + #default value + if param.default == param.empty: + params.append('') + elif inspect.isclass(param.default) or inspect.isfunction(param.default): + params.append(param.default.__module__+'.'+param.default.__name__) + else: + value=repr(param.default) + if "","") + params.append(value) + return params + +def gather_objects(module): + objects=[] + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and hasattr(obj, '__mro__') and ("torch.nn.modules.module.Module" in str(obj.__mro__) or "torch.utils.data.dataset.Dataset" in str(obj.__mro__))) or (inspect.isfunction(obj) and "return" in obj.__annotations__ and inspect.isclass(obj.__annotations__["return"]) and "torch.nn.modules.module.Module" in str(obj.__annotations__["return"].__mro__)): #filter unwanted torch objects + object={} + object['type']='class' if inspect.isclass(obj) else 'function' + object['name']=name + if obj.__doc__ is not None: + object['doc']=inspect.cleandoc(obj.__doc__) + + # autofill class members when requested + doc=object['doc'] + newstring = '' + start = 0 + for m in re.finditer(".. autoclass:: [a-zA-Z0-9_.]+\n :members:", doc): + end, newstart = m.span() + newstring += doc[start:end] + class_name = re.findall(".. 
autoclass:: ([a-zA-Z0-9_.]+)\n :members:", m.group(0)) + rep = class_name[0]+":\n" + sub_error_message, submodule = safe_exec(importlib.import_module,(class_name[0].rpartition('.')[0],)) + if submodule is not None: + for member in dir(vars(submodule)[class_name[0].rpartition('.')[-1]]): + if not member.startswith('_'): + rep+=' '+member+'\n' + newstring += rep + start = newstart + newstring += doc[start:] + object['doc']=newstring.replace(" :noindex:","") + else: + object['doc']=name+(' class' if inspect.isclass(obj) else ' function') + if hasattr(obj,'__getitem__') and obj.__getitem__.__doc__ is not None: + itemdoc=inspect.cleandoc(obj.__getitem__.__doc__) + if 'Returns:' in itemdoc: + object['doc']+='\n\n'+itemdoc[itemdoc.find('Returns:'):] + object['params']=gather_parameters(obj.__init__ if inspect.isclass(obj) else obj) + if inspect.isclass(obj): + object['params']=object['params'][3:] #remove self parameter + object['code']='' + objects.append(object) + return objects + + +def parse_parameters(node): + #prepare defaults to be in sync with arguments + defaults=[] + for d in node.args.defaults: + defaults.append(ast.get_source_segment(code,d)) + for d in range(len(node.args.args)-len(node.args.defaults)): + defaults.insert(0,"") + #scan through arguments + params=[] + for i,a in enumerate(node.args.args): + params.append(a.arg) + if a.annotation: + params.append(ast.get_source_segment(code,a.annotation)) + else: + params.append("") + params.append(defaults[i]) + #add *args, if applicable + if node.args.vararg: + params.append("*"+node.args.vararg.arg) + if node.args.vararg.annotation: + params.append(ast.get_source_segment(code,node.args.vararg.annotation)) + else: + params.append("") + params.append("") #no default value + #add **kwargs, if applicable + if node.args.kwarg: + params.append("**"+node.args.kwarg.arg) + if node.args.kwarg.annotation: + params.append(ast.get_source_segment(code,node.args.kwarg.annotation)) + else: + params.append("") + params.append("") #no default value + return params + +def parse_objects(module:ast.Module): + objects=[] + for node in module.body: + if isinstance(node, ast.FunctionDef): + object={} + object['code']=ast.get_source_segment(code,node) + object['type']='function' + object['name']=node.name + object['doc']=ast.get_docstring(node) if ast.get_docstring(node) else "" + object['params']=parse_parameters(node) + objects.append(object) + if isinstance(node, ast.ClassDef): + object={} + object['code']=ast.get_source_segment(code,node) + object['type']='class' + object['name']=node.name + object['doc']=ast.get_docstring(node) if ast.get_docstring(node) else "" + object['params']=[] + for subnode in node.body: + if isinstance(subnode, ast.FunctionDef) and subnode.name=="__init__": + object['params']=parse_parameters(subnode) + object['params']=object['params'][3:] #remove self parameter + objects.append(object) + return objects + +def filter_parent_objects(objects:List[Dict]) -> List[Dict]: + parent_objects=[] + for object in objects: + unique=True + for subobject in objects: + name=object['name'] + if subobject['name']!=name: + if re.search('[ =+]'+name+'[ ]*\(', subobject['code']): + unique=False + if unique: + parent_objects.append(object) + return parent_objects + + +generated_class="""\ +import typing +import pathlib +import torch +import torch.nn as nn +import torch.nn.functional as F +import {0} +from {0} import transforms + +class {1}({2}): + \"\"\"{3}\"\"\" + def __init__({4}): + super().__init__({5}) +""" + +generated_function="""\ +import 
typing +import pathlib +import torch +import torch.nn as nn +import torch.nn.functional as F +import {0} +from {0} import transforms + +def {1}({2}): + \"\"\"{3}\"\"\" + {4}={5}({6}) + return {4} +""" + +def generate_code(path,object): + #write an inherited code code for each object + name=object['name'] + params=object['params'] + if object['type']=='class': + return generated_class.format( + path.split('.')[0], + name, path+'.'+name, + object['doc'], + ', '.join(['self']+[params[i]+(': ' if params[i+1] else'')+params[i+1]+(' = ' if params[i+2] else'')+params[i+2] for i in range(0,len(params),3)]), + ', '.join([params[i] for i in range(0,len(params),3)]) + ) + else: + return generated_function.format( + path.split('.')[0], + name, ', '.join([params[i]+(': ' if params[i+1] else'')+params[i+1]+(' = ' if params[i+2] else'')+params[i+2] for i in range(0,len(params),3)]), + object['doc'], + path.split('.')[1][:-1], path+'.'+name, ', '.join([params[i] for i in range(0,len(params),3)]) + ) + +def patch_parameters(path, name, params): + patched_params=[] + if 'datasets' in path: + for state in ['train','valid']: + for i in range(0,len(params)-1,3): + #patch root for all modules + if params[i]=='root' and not params[i+2]: + params[i+2]="'"+data_path+"'" + + #patch download for all modules + if params[i]=="download" and params[i+2]=="False": + params[i+2]="True" + + #patch transform for vision modules + if 'torchvision' in path and (params[i]=="transform" or params[i]=="target_transform"): + params[i+2]="transforms.Compose([])" + + #patch train/val for specific modules + if state=='valid': + if params[i]=="train" and params[i+2]=="True": + params[i+2]="False" + if 'torchvision' in path: + if (name=="Cityscapes" or name=="ImageNet") and params[i]=="split": + params[i+2]="'val'" + if (name=="STL10" or name=="SVHN") and params[i]=="split": + params[i+2]="'test'" + if (name=="CelebA") and params[i]=="split": + params[i+2]="'valid'" + if (name=="Places365") and params[i]=="split": + params[i+2]="'val'" + if (name=="VOCDetection" or name=="VOCSegmentation") and params[i]=="image_set": + params[i+2]="'val'" + + patched_params.append(params.copy()) + else: + patched_params.append(params) + return patched_params + +custom_classes={} +custom_classes['Custom Dataset']="""\ +import torch +from torch.utils.data import Dataset + +class MyDataset(Dataset): + def __init__(self): + super().__init__() + + def __len__(self): + pass + + def __getitem__(self, idx): + pass +""" +custom_classes['Custom Renderer']="""\ +from torchstudio.modules import Renderer +import numpy as np +import matplotlib as mpl +import matplotlib.pyplot as plt +import PIL + +class MyRenderer(Renderer): + def __init__(self): + super().__init__() + + def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]): + pass +""" +custom_classes['Custom Analyzer']="""\ +from torchstudio.modules import Analyzer +from typing import List +import numpy as np +import matplotlib as mpl +import matplotlib.pyplot as plt +import PIL + +class MyAnalyzer(Analyzer): + def __init__(self): + super().__init__() + + def start_analysis(self, num_training_samples: int, num_validation_samples: int, input_tensors_id: List[int], output_tensors_id: List[int], labels: List[str]): + pass + + def analyze_sample(self, sample: List[np.array], training_sample: bool): + pass + + def finish_analysis(self): + pass + + def generate_report(self, size: Tuple[int, int], dpi: int): + pass +""" +custom_classes['Custom 
Model']="""\ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + pass +""" +custom_classes['Custom Loss']="""\ +import torch.nn as nn + +class MyLoss(nn.Modules._Loss): + def __init__(self): + super().__init__() + + def forward(self, x): + pass +""" +custom_classes['Custom Metric']="""\ +from torchstudio.modules import Metric +import torch.nn.functional as F + +class MyMetric(Metric): + def __init__(self): + pass + + def update(self, preds, target): + pass + + def compute(self): + pass + + def reset(self): + pass +""" +custom_classes['Custom Optimizer']="""\ +import torch.optim as optim + +class MyOptimizer(optim.Optimizer): + def __init__(self, params): + super().__init__(params) +""" +custom_classes['Custom Scheduler']="""\ +import torch.optim as optim + +class MyScheduler(optim._LRScheduler): + def __init__(self, optimizer): + super().__init__(optimizer) +""" + +def scan_folder(path): + path=path.replace('.','/') + codes=[] + for filename in listdir(path): + if isfile(join(path, filename)): + with open(join(path, filename), "r") as file: + codes.append(file.read()) + return codes + + +app_socket = tc.connect() +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + objects=[] + if msg_type == 'SetDataDir': + data_path=tc.decode_strings(msg_data)[0] + + if msg_type == 'Parse': #parse code or path, return a list of objects (class and functions) with their names, doc, parameters, doc and code + decoded=tc.decode_strings(msg_data) + path=decoded[0] + if path in custom_classes: + decoded.append(custom_classes[path]) + if 'torchstudio' in path: + decoded+=scan_folder(path) + if len(decoded)>1: + #parse code chunks + for code in decoded[1:]: + error_msg, module = safe_exec(ast.parse,(code,)) + if error_msg is None and module is not None: + objects_batch=parse_objects(module) + objects_batch=filter_parent_objects(objects_batch) #only keep parent objets + for i in range(len(objects_batch)): + objects_batch[i]['code']=code #set whole source code for each object, as we don't know the dependencies + objects.extend(objects_batch) + else: + print("Error parsing code:", error_msg, "\n", file=sys.stderr) + else: + #parse module + error_msg, module = safe_exec(importlib.import_module,(path,)) + if error_msg is None and module is not None: + objects=gather_objects(module) + for i, object in enumerate(objects): + objects[i]['code']=generate_code(path,object) #generate inherited source code + else: + print("Error parsing module:", error_msg, "\n", file=sys.stderr) + + tc.send_msg(app_socket, 'ObjectsBegin', tc.encode_strings(path)) + for object in objects: + patched_params = patch_parameters(path,object['name'],object['params']) + for params in patched_params: + tc.send_msg(app_socket, 'Object', tc.encode_strings([path,object['code'],object['type'],object['name'],object['doc']]+params)) + tc.send_msg(app_socket, 'ObjectsEnd', tc.encode_strings(path)) + + if msg_type == 'RequestDefinitionName': #return default definition name + tab=tc.decode_strings(msg_data)[0] + if tab=='dataset': + tc.send_msg(app_socket, 'SetDefinitionName', tc.encode_strings(['torchvision.datasets','MNIST'])) + if tab=='model': + tc.send_msg(app_socket, 'SetDefinitionName', tc.encode_strings(['torchstudio.models','MNISTClassifier'])) + + if msg_type == 'Exit': + break diff --git a/torchstudio/sshtunnel.py b/torchstudio/sshtunnel.py index 59bc2a6..5a9c551 100644 --- a/torchstudio/sshtunnel.py +++ 
b/torchstudio/sshtunnel.py @@ -1,337 +1,340 @@ -import time -import sys -import os -import io - -# Port forwarding from https://github.com/skyleronken/sshrat/blob/master/tunnels.py -# improved with dynamic local port allocation feedback for reverse tunnel with a null local port -import threading -import socket -import selectors -import time -import socketserver -import paramiko -import hashlib - -class Tunnel(): - - def __init__(self, ssh_session, tun_type, lhost, lport, dhost, dport): - self.tun_type = tun_type - self.lhost = lhost - self.lport = lport - self.dhost = dhost - self.dport = dport - - # create tunnel here - if self.tun_type == ForwardTunnel: - self.tunnel = ForwardTunnel(ssh_session, self.lhost, self.lport, self.dhost, self.dport) - elif self.tun_type == ReverseTunnel: - self.tunnel = ReverseTunnel(ssh_session, self.lhost, self.lport, self.dhost, self.dport) - self.lport = self.tunnel.lport #in case of dynamic allocation (lport=0) - - def to_str(self): - if self.tun_type == ForwardTunnel: - return f"{self.lhost}:{self.lport} --> {self.dhost}:{self.dport}" - else: - return f"{self.dhost}:{self.dport} <-- {self.lhost}:{self.lport}" - - def stop(self): - self.tunnel.stop() - -class ReverseTunnel(): - - def __init__(self, ssh_session, lhost, lport, dhost, dport): - self.session = ssh_session - self.lhost = lhost - self.lport = lport - self.dhost = dhost - self.dport = dport - - self.transport = ssh_session.get_transport() - - self.reverse_forward_tunnel(lhost, lport, dhost, dport, self.transport) - self.handlers = [] - - def stop(self): - self.transport.cancel_port_forward(self.lhost, self.lport) - for thr in self.handlers: - thr.stop() - - def handler(self, rev_socket, origin, laddress): - rev_handler = ReverseTunnelHandler(rev_socket, self.dhost, self.dport, self.lhost, self.lport) - rev_handler.setDaemon(True) - rev_handler.start() - self.handlers.append(rev_handler) - - def reverse_forward_tunnel(self, lhost, lport, dhost, dport, transport): - try: - self.lport=transport.request_port_forward(lhost, lport, handler=self.handler) - except Exception as e: - raise e - -class ReverseTunnelHandler(threading.Thread): - - def __init__(self, rev_socket, dhost, dport, lhost, lport): - - threading.Thread.__init__(self) - - self.rev_socket = rev_socket - self.dhost = dhost - self.dport = dport - self.lhost = lhost - self.lport = lport - - self.dst_socket = socket.socket() - try: - self.dst_socket.connect((self.dhost, self.dport)) - except Exception as e: - raise e - - self.keepalive = True - - def _read_from_rev(self, dst, rev): - self._transfer_data(src_socket=rev,dest_socket=dst) - - def _read_from_dest(self, dst, rev): - self._transfer_data(src_socket=dst,dest_socket=rev) - - def _transfer_data(self,src_socket,dest_socket): - dest_socket.setblocking(True) - data = src_socket.recv(1024) - - if len(data): - try: - dest_socket.send(data) - except Exception as e: - print(f"ssh error: {type(e).__name__}", file=sys.stderr) - - def stop(self): - self.rev_socket.shutdown(2) - self.dst_socket.shutdown(2) - self.rev_socket.close() - self.dst_socket.close() - self.keepalive = False - - def run(self): - selector = selectors.DefaultSelector() - - selector.register(fileobj=self.rev_socket,events=selectors.EVENT_READ,data=self._read_from_rev) - selector.register(fileobj=self.dst_socket,events=selectors.EVENT_READ,data=self._read_from_dest) - - while self.keepalive: - events = selector.select(5) - if len(events) > 0: - for key, _ in events: - callback = key.data - try: - 
callback(dst=self.dst_socket,rev=self.rev_socket) - except Exception as e: - print(f"ssh error: {type(e).__name__}", file=sys.stderr) - time.sleep(0) - - - -# credits to paramiko-tunnel -class ForwardTunnel(socketserver.ThreadingTCPServer): - daemon_threads = True - allow_reuse_address = True - - def __init__(self, ssh_session, lhost, lport, dhost, dport): - self.session = ssh_session - self.lhost = lhost - self.lport = lport - self.dhost = dhost - self.dport = dport - - super().__init__( - server_address=(lhost, lport), - RequestHandlerClass=ForwardTunnelHandler, - bind_and_activate=True, - ) - - self.baddr, self.bport = self.server_address - self.thread = threading.Thread( - target=self.serve_forever, - daemon=True, - ) - - self.start() - - def start(self): - self.thread.start() - - def stop(self): - self.shutdown() - self.server_close() - - def __enter__(self): - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.shutdown() - -class ForwardTunnelHandler(socketserver.BaseRequestHandler): - sz_buf = 1024 - - def __init__(self, request, cli_addr, server): - self.selector = selectors.DefaultSelector() - self.channel = None - super().__init__(request, cli_addr, server) - - def _read_from_client(self, sock, mask): - self._transfer_data(src_socket=sock, dest_socket=self.channel) - - def _read_from_channel(self, sock, mask): - self._transfer_data(src_socket=sock,dest_socket=self.request) - - def _transfer_data(self,src_socket,dest_socket): - src_socket.setblocking(True) - data = src_socket.recv(self.sz_buf) - - if len(data): - try: - dest_socket.send(data) - except BrokenPipeError: - self.finish() - - def handle(self): - peer_name = self.request.getpeername() - try: - self.channel = self.server.session.get_transport().open_channel( - kind='direct-tcpip', - dest_addr=(self.server.dhost,self.server.dport,), - src_addr=peer_name, - ) - except Exception as error: - msg = f'Connection failed to {self.server.dhost}:{self.server.dport}' - raise Exception(msg) - - else: - self.selector.register(fileobj=self.channel,events=selectors.EVENT_READ,data=self._read_from_channel) - self.selector.register(fileobj=self.request,events=selectors.EVENT_READ,data=self._read_from_client) - - if self.channel is None: - self.finish() - raise Exception(f'SSH Server rejected request to {self.server.dhost}:{self.server.dport}') - - while True: - events = self.selector.select() - for key, mask in events: - callback = key.data - callback(sock=key.fileobj,mask=mask) - if self.server._BaseServer__is_shut_down.is_set(): - self.finish() - time.sleep(0) - - def finish(self): - if self.channel is not None: - self.channel.shutdown(how=2) - self.channel.close() - self.request.shutdown(2) - self.request.close() - -### - - - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--sshaddress", help="server address", type=str, default=None) - parser.add_argument("--sshport", help="ssh server port", type=int, default=22) - parser.add_argument("--username", help="ssh server username", type=str, default=None) - parser.add_argument("--password", help="ssh server password", type=str, default=None) - parser.add_argument("--keyfile", help="ssh server key file", type=str, default=None) - parser.add_argument('--clean', help="clean all files", action='store_true', default=False) - parser.add_argument("--command", help="command to execute or run python scripts", type=str, default=None) - parser.add_argument("--script", help="python script to be launched on the 
server", type=str, default=None) - parser.add_argument("--address", help="address to which the script must connect", type=str, default=None) - parser.add_argument("--port", help="port to which the script must connect", type=int, default=None) - args, other_args = parser.parse_known_args() - - sshclient = paramiko.SSHClient() - sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - print("Connecting to remote server...", file=sys.stderr) - try: - sshclient.connect(hostname=args.sshaddress, port=args.sshport, username=args.username, password=args.password, pkey=paramiko.RSAKey.from_private_key_file(args.keyfile) if args.keyfile else None, timeout=5) - except: - print("Error: could not connect to remote server", file=sys.stderr) - exit(1) - - if args.clean: - print("Cleaning...", file=sys.stderr) - stdin, stdout, stderr = sshclient.exec_command('rm -r -f TorchStudio') - exit_status = stdout.channel.recv_exit_status() - stdin, stdout, stderr = sshclient.exec_command('rmdir /s /q TorchStudio') - exit_status = stdout.channel.recv_exit_status() - sshclient.close() - print("Cleaning complete") - exit(0) - - #copy root scripts to the remote server if necessary - print("Updating remote scripts...", file=sys.stderr) - local_scripts_hash = hashlib.md5() - for entry in os.scandir('torchstudio'): - if entry.is_file(): - with open('torchstudio/'+entry.name, 'rb') as f: - local_scripts_hash.update(f.read()) - local_scripts_hash = local_scripts_hash.digest() - sftp = paramiko.SFTPClient.from_transport(sshclient.get_transport()) - remote_scripts_hash=io.BytesIO() - try: - sftp.getfo('TorchStudio/torchstudio/.md5', remote_scripts_hash) - except: - pass - if remote_scripts_hash.getvalue()!=local_scripts_hash: - try: - sftp.mkdir('TorchStudio') - except: - pass - try: - sftp.mkdir('TorchStudio/torchstudio') - except: - pass - try: - for entry in os.scandir('torchstudio'): - if entry.is_file(): - sftp.put('torchstudio/'+entry.name, 'TorchStudio/torchstudio/'+entry.name) - if entry.name.endswith('.cmd'): - sftp.chmod('TorchStudio/torchstudio/'+entry.name, 0o0777) - except: - print("Error: could not update remote scripts", file=sys.stderr) - exit(1) - new_scripts_hash=io.BytesIO(local_scripts_hash) - sftp.putfo(new_scripts_hash, 'TorchStudio/torchstudio/.md5') - sftp.close() - - if args.address: - other_args=["--address", args.address]+other_args - - if args.port: - #setup remote port forwarding - print("Forwarding ports...", file=sys.stderr) - reverse_tunnel = Tunnel(sshclient, ReverseTunnel, 'localhost', 0, args.address if args.address else 'localhost', args.port) #remote address, remote port, local address, local port - other_args=["--port", str(reverse_tunnel.lport)]+other_args - - if args.command: - if args.script: - print("Launching remote script...", file=sys.stderr) - stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+args.command+" -u -X utf8 -m "+' '.join([args.script]+other_args)) - else: - print("Executing remote command...", file=sys.stderr) - stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+' '.join([args.command]+other_args)) - while True: - time.sleep(.1) - if stdout.channel.recv_ready(): - sys.stdout.write(str(stdout.channel.recv(1024),'utf-8')) - if stdout.channel.recv_stderr_ready(): - sys.stderr.write(str(stdout.channel.recv_stderr(1024),'utf-8')) - if stdout.channel.exit_status_ready(): - break - else: - print("Error: no python command set. 
Define a command or refresh to install a python environment.", file=sys.stderr) - - sshclient.close() - +import time +import sys +import os +import io + +# Port forwarding from https://github.com/skyleronken/sshrat/blob/master/tunnels.py +# improved with dynamic local port allocation feedback for reverse tunnel with a null local port +import threading +import socket +import selectors +import time +import socketserver +import paramiko +import hashlib + +class Tunnel(): + + def __init__(self, ssh_session, tun_type, lhost, lport, dhost, dport): + self.tun_type = tun_type + self.lhost = lhost + self.lport = lport + self.dhost = dhost + self.dport = dport + + # create tunnel here + if self.tun_type == ForwardTunnel: + self.tunnel = ForwardTunnel(ssh_session, self.lhost, self.lport, self.dhost, self.dport) + elif self.tun_type == ReverseTunnel: + self.tunnel = ReverseTunnel(ssh_session, self.lhost, self.lport, self.dhost, self.dport) + self.lport = self.tunnel.lport #in case of dynamic allocation (lport=0) + + def to_str(self): + if self.tun_type == ForwardTunnel: + return f"{self.lhost}:{self.lport} --> {self.dhost}:{self.dport}" + else: + return f"{self.dhost}:{self.dport} <-- {self.lhost}:{self.lport}" + + def stop(self): + self.tunnel.stop() + +class ReverseTunnel(): + + def __init__(self, ssh_session, lhost, lport, dhost, dport): + self.session = ssh_session + self.lhost = lhost + self.lport = lport + self.dhost = dhost + self.dport = dport + + self.transport = ssh_session.get_transport() + + self.reverse_forward_tunnel(lhost, lport, dhost, dport, self.transport) + self.handlers = [] + + def stop(self): + self.transport.cancel_port_forward(self.lhost, self.lport) + for thr in self.handlers: + thr.stop() + + def handler(self, rev_socket, origin, laddress): + rev_handler = ReverseTunnelHandler(rev_socket, self.dhost, self.dport, self.lhost, self.lport) + rev_handler.setDaemon(True) + rev_handler.start() + self.handlers.append(rev_handler) + + def reverse_forward_tunnel(self, lhost, lport, dhost, dport, transport): + try: + self.lport=transport.request_port_forward(lhost, lport, handler=self.handler) + except Exception as e: + raise e + +class ReverseTunnelHandler(threading.Thread): + + def __init__(self, rev_socket, dhost, dport, lhost, lport): + + threading.Thread.__init__(self) + + self.rev_socket = rev_socket + self.dhost = dhost + self.dport = dport + self.lhost = lhost + self.lport = lport + + self.dst_socket = socket.socket() + try: + self.dst_socket.connect((self.dhost, self.dport)) + except Exception as e: + raise e + + self.keepalive = True + + def _read_from_rev(self, dst, rev): + self._transfer_data(src_socket=rev,dest_socket=dst) + + def _read_from_dest(self, dst, rev): + self._transfer_data(src_socket=dst,dest_socket=rev) + + def _transfer_data(self,src_socket,dest_socket): + dest_socket.setblocking(True) + data = src_socket.recv(1024) + + if len(data): + try: + dest_socket.send(data) + except Exception as e: + print(f"ssh error: {type(e).__name__}", file=sys.stderr) + + def stop(self): + self.rev_socket.shutdown(2) + self.dst_socket.shutdown(2) + self.rev_socket.close() + self.dst_socket.close() + self.keepalive = False + + def run(self): + selector = selectors.DefaultSelector() + + selector.register(fileobj=self.rev_socket,events=selectors.EVENT_READ,data=self._read_from_rev) + selector.register(fileobj=self.dst_socket,events=selectors.EVENT_READ,data=self._read_from_dest) + + while self.keepalive: + events = selector.select(5) + if len(events) > 0: + for key, _ in events: + 
callback = key.data + try: + callback(dst=self.dst_socket,rev=self.rev_socket) + except Exception as e: + print(f"ssh error: {type(e).__name__}", file=sys.stderr) + time.sleep(0) + + + +# credits to paramiko-tunnel +class ForwardTunnel(socketserver.ThreadingTCPServer): + daemon_threads = True + allow_reuse_address = True + + def __init__(self, ssh_session, lhost, lport, dhost, dport): + self.session = ssh_session + self.lhost = lhost + self.lport = lport + self.dhost = dhost + self.dport = dport + + super().__init__( + server_address=(lhost, lport), + RequestHandlerClass=ForwardTunnelHandler, + bind_and_activate=True, + ) + + self.baddr, self.bport = self.server_address + self.thread = threading.Thread( + target=self.serve_forever, + daemon=True, + ) + + self.start() + + def start(self): + self.thread.start() + + def stop(self): + self.shutdown() + self.server_close() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + +class ForwardTunnelHandler(socketserver.BaseRequestHandler): + sz_buf = 1024 + + def __init__(self, request, cli_addr, server): + self.selector = selectors.DefaultSelector() + self.channel = None + super().__init__(request, cli_addr, server) + + def _read_from_client(self, sock, mask): + self._transfer_data(src_socket=sock, dest_socket=self.channel) + + def _read_from_channel(self, sock, mask): + self._transfer_data(src_socket=sock,dest_socket=self.request) + + def _transfer_data(self,src_socket,dest_socket): + src_socket.setblocking(True) + data = src_socket.recv(self.sz_buf) + + if len(data): + try: + dest_socket.send(data) + except BrokenPipeError: + self.finish() + + def handle(self): + peer_name = self.request.getpeername() + try: + self.channel = self.server.session.get_transport().open_channel( + kind='direct-tcpip', + dest_addr=(self.server.dhost,self.server.dport,), + src_addr=peer_name, + ) + except Exception as error: + msg = f'Connection failed to {self.server.dhost}:{self.server.dport}' + raise Exception(msg) + + else: + self.selector.register(fileobj=self.channel,events=selectors.EVENT_READ,data=self._read_from_channel) + self.selector.register(fileobj=self.request,events=selectors.EVENT_READ,data=self._read_from_client) + + if self.channel is None: + self.finish() + raise Exception(f'SSH Server rejected request to {self.server.dhost}:{self.server.dport}') + + while True: + events = self.selector.select() + for key, mask in events: + callback = key.data + callback(sock=key.fileobj,mask=mask) + if self.server._BaseServer__is_shut_down.is_set(): + self.finish() + time.sleep(0) + + def finish(self): + if self.channel is not None: + self.channel.shutdown(how=2) + self.channel.close() + self.request.shutdown(2) + self.request.close() + +### + + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--sshaddress", help="server address", type=str, default=None) + parser.add_argument("--sshport", help="ssh server port", type=int, default=22) + parser.add_argument("--username", help="ssh server username", type=str, default=None) + parser.add_argument("--password", help="ssh server password", type=str, default=None) + parser.add_argument("--keyfile", help="ssh server key file", type=str, default=None) + parser.add_argument('--clean', help="clean all files", action='store_true', default=False) + parser.add_argument("--command", help="command to execute or run python scripts", type=str, default=None) + parser.add_argument("--script", help="python 
script to be launched on the server", type=str, default=None) + parser.add_argument("--address", help="address to which the script must connect", type=str, default=None) + parser.add_argument("--port", help="port to which the script must connect", type=int, default=None) + args, other_args = parser.parse_known_args() + + sshclient = paramiko.SSHClient() + sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + print("Connecting to remote server...", file=sys.stderr) + try: + sshclient.connect(hostname=args.sshaddress, port=args.sshport, username=args.username, password=args.password, pkey=paramiko.RSAKey.from_private_key_file(args.keyfile) if args.keyfile else None, timeout=5) + except: + print("Error: could not connect to remote server", file=sys.stderr) + exit(1) + + if args.clean: + print("Cleaning...", file=sys.stderr) + stdin, stdout, stderr = sshclient.exec_command('rm -r -f TorchStudio') + exit_status = stdout.channel.recv_exit_status() + stdin, stdout, stderr = sshclient.exec_command('rmdir /s /q TorchStudio') + exit_status = stdout.channel.recv_exit_status() + sshclient.close() + print("Cleaning complete") + exit(0) + + #copy root scripts to the remote server if necessary + print("Updating remote scripts...", file=sys.stderr) + local_scripts_hash = hashlib.md5() + for entry in os.scandir('torchstudio'): + if entry.is_file(): + with open('torchstudio/'+entry.name, 'rb') as f: + local_scripts_hash.update(f.read()) + local_scripts_hash = local_scripts_hash.digest() + sftp = paramiko.SFTPClient.from_transport(sshclient.get_transport()) + remote_scripts_hash=io.BytesIO() + try: + sftp.getfo('TorchStudio/torchstudio/.md5', remote_scripts_hash) + except: + pass + if remote_scripts_hash.getvalue()!=local_scripts_hash: + try: + sftp.mkdir('TorchStudio') + except: + pass + try: + sftp.mkdir('TorchStudio/torchstudio') + except: + pass + try: + for entry in os.scandir('torchstudio'): + if entry.is_file(): + sftp.put('torchstudio/'+entry.name, 'TorchStudio/torchstudio/'+entry.name) + if entry.name.endswith('.cmd'): + sftp.chmod('TorchStudio/torchstudio/'+entry.name, 0o0777) + except: + print("Error: could not update remote scripts", file=sys.stderr) + exit(1) + new_scripts_hash=io.BytesIO(local_scripts_hash) + sftp.putfo(new_scripts_hash, 'TorchStudio/torchstudio/.md5') + sftp.close() + + if args.address: + other_args=["--address", args.address]+other_args + + if args.port: + #setup remote port forwarding + print("Forwarding ports...", file=sys.stderr) + reverse_tunnel = Tunnel(sshclient, ReverseTunnel, 'localhost', 0, args.address if args.address else 'localhost', args.port) #remote address, remote port, local address, local port + other_args=["--port", str(reverse_tunnel.lport)]+other_args + + if args.command: + if args.script: + print("Launching remote script...", file=sys.stderr) + stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+args.command+" -u -X utf8 -m "+' '.join([args.script]+other_args)) + else: + print("Executing remote command...", file=sys.stderr) + stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+' '.join([args.command]+other_args)) + while True: + time.sleep(.1) + if stdout.channel.recv_stderr_ready(): + sys.stderr.write(str(stdout.channel.recv_stderr(1024).replace(b'\r\n',b'\n'),'utf-8')) + if stdout.channel.recv_ready(): + sys.stdout.write(str(stdout.channel.recv(1024).replace(b'\r\n',b'\n'),'utf-8')) + if stdout.channel.exit_status_ready(): + break + else: + if args.script: + print("Error: no python environment set.", 
file=sys.stderr) + else: + print("Error: no command set.", file=sys.stderr) + + sshclient.close() + diff --git a/torchstudio/tensorrender.py b/torchstudio/tensorrender.py index 510a9e4..15703e9 100644 --- a/torchstudio/tensorrender.py +++ b/torchstudio/tensorrender.py @@ -1,70 +1,70 @@ -import torchstudio.tcpcodec as tc -from torchstudio.modules import safe_exec -import inspect -import sys -import os - -title = '' -tensor = None -resolution = (256,256, 96) -shift = (0,0,0,0) -scale = (1,1,1,1) -input_tensors = [] -target_tensor = None -labels = [] - -app_socket = tc.connect() -while True: - msg_type, msg_data = tc.recv_msg(app_socket) - - if msg_type == 'SetRendererCode': - error_msg, renderer_env = safe_exec(tc.decode_strings(msg_data)[0],description='renderer definition') - if error_msg is not None or 'renderer' not in renderer_env: - print("Unknown renderer definition error" if error_msg is None else error_msg, file=sys.stderr) - else: - tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(renderer_env['renderer'].__doc__) if renderer_env['renderer'].__doc__ is not None else "")) - - if msg_type == 'Clear': - tensor = None - input_tensors = [] - target_tensor = None - - if msg_type == 'SetTitle': - title = tc.decode_strings(msg_data)[0] - - if msg_type == 'TensorData': - tensor = tc.decode_numpy_tensors(msg_data)[0] - - if msg_type == 'SetResolution': - resolution = tc.decode_ints(msg_data) - - if msg_type == 'SetShift': - shift = tc.decode_floats(msg_data) - if msg_type == 'SetScale': - scale = tc.decode_floats(msg_data) - - if msg_type == 'SetInputTensors': - input_tensors = tc.decode_numpy_tensors(msg_data) - - if msg_type == 'SetTargetTensors': - target_tensors = tc.decode_numpy_tensors(msg_data) - if target_tensors: - target_tensor=target_tensors[0] - else: - target_tensor=None - - if msg_type == 'SetLabels': - labels = tc.decode_strings(msg_data) - - if msg_type == 'Render': - if 'renderer' in renderer_env and tensor is not None and resolution[0]>0 and resolution[1]>0: - error_msg, img = safe_exec(renderer_env['renderer'].render, (title, tensor,resolution[0:2],resolution[2],shift,scale,input_tensors,target_tensor,labels), description='renderer definition') - if error_msg is not None: - print(error_msg, file=sys.stderr) - if img is None: - tc.send_msg(app_socket, 'ImageError') - else: - tc.send_msg(app_socket, 'ImageData', tc.encode_image(img)) - - if msg_type == 'Exit': - break +import torchstudio.tcpcodec as tc +from torchstudio.modules import safe_exec +import inspect +import sys +import os + +title = '' +tensor = None +resolution = (256,256, 96) +shift = (0,0,0,0) +scale = (1,1,1,1) +input_tensors = [] +target_tensor = None +labels = [] + +app_socket = tc.connect() +while True: + msg_type, msg_data = tc.recv_msg(app_socket) + + if msg_type == 'SetRendererCode': + error_msg, renderer_env = safe_exec(tc.decode_strings(msg_data)[0],description='renderer definition') + if error_msg is not None or 'renderer' not in renderer_env: + print("Unknown renderer definition error" if error_msg is None else error_msg, file=sys.stderr) + else: + tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(renderer_env['renderer'].__doc__) if renderer_env['renderer'].__doc__ is not None else "")) + + if msg_type == 'Clear': + tensor = None + input_tensors = [] + target_tensor = None + + if msg_type == 'SetTitle': + title = tc.decode_strings(msg_data)[0] + + if msg_type == 'TensorData': + tensor = tc.decode_numpy_tensors(msg_data)[0] + + if msg_type == 
'SetResolution': + resolution = tc.decode_ints(msg_data) + + if msg_type == 'SetShift': + shift = tc.decode_floats(msg_data) + if msg_type == 'SetScale': + scale = tc.decode_floats(msg_data) + + if msg_type == 'SetInputTensors': + input_tensors = tc.decode_numpy_tensors(msg_data) + + if msg_type == 'SetTargetTensors': + target_tensors = tc.decode_numpy_tensors(msg_data) + if target_tensors: + target_tensor=target_tensors[0] + else: + target_tensor=None + + if msg_type == 'SetLabels': + labels = tc.decode_strings(msg_data) + + if msg_type == 'Render': + if 'renderer' in renderer_env and tensor is not None and resolution[0]>0 and resolution[1]>0: + error_msg, img = safe_exec(renderer_env['renderer'].render, (title, tensor,resolution[0:2],resolution[2],shift,scale,input_tensors,target_tensor,labels), description='renderer definition') + if error_msg is not None: + print(error_msg, file=sys.stderr) + if img is None: + tc.send_msg(app_socket, 'ImageError') + else: + tc.send_msg(app_socket, 'ImageData', tc.encode_image(img)) + + if msg_type == 'Exit': + break
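
The parameter lists built by gather_parameters() and parse_parameters() above are flat (name, annotation, default) string triplets, which is why generate_code() and patch_parameters() step through them with range(0, len(params), 3). Below is a minimal sketch of that convention; the demo() function and its signature are invented purely for illustration and are not part of the patch.

import inspect

def demo(x, n: int = 3, *args, scale: float = 1.0, **kwargs):  #hypothetical example function
    pass

#flatten the signature into (name, annotation, default) triplets, mirroring gather_parameters()
params = []
for p in inspect.signature(demo).parameters.values():
    name = ("*" if p.kind == p.VAR_POSITIONAL else "**" if p.kind == p.VAR_KEYWORD else "") + p.name
    annotation = "" if p.annotation is p.empty else getattr(p.annotation, "__name__", repr(p.annotation))
    default = "" if p.default is p.empty else repr(p.default)
    params += [name, annotation, default]

#rebuild a signature string the way generate_code() joins the triplets
signature = ', '.join(params[i] + (': ' if params[i+1] else '') + params[i+1] + (' = ' if params[i+2] else '') + params[i+2] for i in range(0, len(params), 3))
print(signature)  #prints: x, n: int = 3, *args, scale: float = 1.0, **kwargs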