From 1804e9689b6e6630602534e306972cda1a248f89 Mon Sep 17 00:00:00 2001 From: Robin Lobel Date: Mon, 2 Jan 2023 12:38:29 +0100 Subject: [PATCH] 0.9.14 --- torchstudio/datasetanalyze.py | 28 ++-- torchstudio/datasetload.py | 74 +++++++---- torchstudio/datasets/genericloader.py | 168 +++++++++++++++++++++++ torchstudio/metrics/accuracy.py | 2 +- torchstudio/metricsplot.py | 29 +++- torchstudio/modelbuild.py | 2 +- torchstudio/modeltrain.py | 184 +++++++++++++++++--------- torchstudio/parametersplot.py | 16 ++- torchstudio/pythoncheck.py | 18 ++- torchstudio/pythoninstall.cmd | 118 ++++++----------- torchstudio/pythoninstall.py | 9 ++ torchstudio/pythonparse.py | 2 +- torchstudio/renderers/bitmap.py | 21 ++- torchstudio/renderers/signal.py | 2 +- torchstudio/renderers/spectrogram.py | 21 ++- torchstudio/renderers/volume.py | 21 ++- torchstudio/sshtunnel.py | 38 ++++-- torchstudio/tcpcodec.py | 28 ++-- torchstudio/tensorrender.py | 2 + 19 files changed, 535 insertions(+), 248 deletions(-) create mode 100644 torchstudio/datasets/genericloader.py diff --git a/torchstudio/datasetanalyze.py b/torchstudio/datasetanalyze.py index 6b9239e..22efa7c 100644 --- a/torchstudio/datasetanalyze.py +++ b/torchstudio/datasetanalyze.py @@ -27,23 +27,23 @@ print("Analyzing...\n", file=sys.stderr) analysis_server, address = tc.generate_server() + tc.send_msg(app_socket, 'ServerRequestingDataset', tc.encode_strings(address)) + + dataset_socket=tc.start_server(analysis_server) + + tc.send_msg(dataset_socket, 'RequestMetaInfos') if analyzer_env['analyzer'].train is None: - request_msg='AnalysisServerRequestingAllSamples' + request_msg='RequestAllSamples' elif analyzer_env['analyzer'].train==True: - request_msg='AnalysisServerRequestingTrainingSamples' + request_msg='RequestTrainingSamples' elif analyzer_env['analyzer'].train==False: - request_msg='AnalysisServerRequestingValidationSamples' - tc.send_msg(app_socket, request_msg, tc.encode_strings(address)) - dataset_socket=tc.start_server(analysis_server) + request_msg='RequestValidationSamples' + tc.send_msg(dataset_socket, request_msg, tc.encode_strings(address)) while True: dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket) - if dataset_msg_type == 'NumSamples': - num_samples=tc.decode_ints(dataset_msg_data)[0] - pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters - if dataset_msg_type == 'InputTensorsID': input_tensors_id=tc.decode_ints(dataset_msg_data) @@ -53,6 +53,10 @@ if dataset_msg_type == 'Labels': labels=tc.decode_strings(dataset_msg_data) + if dataset_msg_type == 'NumSamples': + num_samples=tc.decode_ints(dataset_msg_data)[0] + pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters + if dataset_msg_type == 'StartSending': error_msg, return_value = safe_exec(analyzer_env['analyzer'].start_analysis, (num_samples, input_tensors_id, output_tensors_id, labels), description='analyzer definition') if error_msg is not None: @@ -85,7 +89,7 @@ if dataset_msg_type == 'DoneSending': pbar.close() error_msg, return_value = safe_exec(analyzer_env['analyzer'].finish_analysis, description='analyzer definition') - tc.send_msg(dataset_socket, 'DoneReceiving') + tc.send_msg(dataset_socket, 'DisconnectFromWorkerServer') dataset_socket.close() analysis_server.close() if error_msg is not None: @@ -106,12 +110,14 @@ if msg_type == 'RequestAnalysisReport': resolution = tc.decode_ints(msg_data) - if 'analyzer' in analyzer_env: + if 'analyzer' in analyzer_env and resolution[0]>0 and resolution[1]>0: error_msg, return_value = safe_exec(analyzer_env['analyzer'].generate_report, (resolution[0:2],resolution[2]), description='analyzer definition') if error_msg is not None: print(error_msg, file=sys.stderr) if return_value is not None: tc.send_msg(app_socket, 'ReportImage', tc.encode_image(return_value)) + else: + tc.send_msg(app_socket, 'ReportImage') if msg_type == 'Exit': break diff --git a/torchstudio/datasetload.py b/torchstudio/datasetload.py index ba9ab46..1654b3d 100644 --- a/torchstudio/datasetload.py +++ b/torchstudio/datasetload.py @@ -16,6 +16,7 @@ import time from collections.abc import Iterable from tqdm.auto import tqdm +import hashlib #monkey patch ssl to fix ssl certificate fail when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error import ssl @@ -207,9 +208,7 @@ def __getitem__(self, id): if msg_type == 'OutputTensorsID': output_tensors_id = tc.decode_ints(msg_data) - if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples': - train_set=True if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendAllSamples' else False - valid_set=True if msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples' else False + if msg_type == 'ConnectToWorkerServer': name, sshaddress, sshport, username, password, keydata, address, port = tc.decode_strings(msg_data) port=int(port) @@ -241,30 +240,49 @@ def __getitem__(self, id): try: worker_socket = tc.connect((address,port),timeout=10) - num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0) - tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples)) - tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id)) - tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id)) - tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes)) - - tc.send_msg(worker_socket, 'StartSending') - with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar: - if train_set: - meta_dataset.train() - for i in range(len(meta_dataset)): - tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i])) - pbar.update(1) - if valid_set: - meta_dataset.valid() - for i in range(len(meta_dataset)): - tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i])) - pbar.update(1) - - tc.send_msg(worker_socket, 'DoneSending') - train_msg_type, train_msg_data = tc.recv_msg(worker_socket) - if train_msg_type == 'DoneReceiving': - worker_socket.close() - print('Samples transfer to '+name+' completed') + while True: + worker_msg_type, worker_msg_data = tc.recv_msg(worker_socket) + + if worker_msg_type == 'RequestMetaInfos': + tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id)) + tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id)) + tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes)) + + if worker_msg_type == 'RequestHash': + dataset_hash = hashlib.md5() + dataset_hash.update(int(len(meta_dataset.train())).to_bytes(4, 'little')) + if len(meta_dataset)>0: + dataset_hash.update(tc.encode_torch_tensors(meta_dataset[0])) + dataset_hash.update(int(len(meta_dataset.valid())).to_bytes(4, 'little')) + if len(meta_dataset)>0: + dataset_hash.update(tc.encode_torch_tensors(meta_dataset[0])) + tc.send_msg(worker_socket, 'DatasetHash', dataset_hash.digest()) + + if worker_msg_type == 'RequestTrainingSamples' or worker_msg_type == 'RequestValidationSamples' or worker_msg_type == 'RequestAllSamples': + train_set=True if worker_msg_type == 'RequestTrainingSamples' or worker_msg_type == 'RequestAllSamples' else False + valid_set=True if worker_msg_type == 'RequestValidationSamples' or worker_msg_type == 'RequestAllSamples' else False + num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0) + tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples)) + + tc.send_msg(worker_socket, 'StartSending') + with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar: + if train_set: + meta_dataset.train() + for i in range(len(meta_dataset)): + tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i])) + pbar.update(1) + if valid_set: + meta_dataset.valid() + for i in range(len(meta_dataset)): + tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i])) + pbar.update(1) + + tc.send_msg(worker_socket, 'DoneSending') + + if worker_msg_type == 'DisconnectFromWorkerServer': + worker_socket.close() + print('Samples transfer to '+name+' completed') + break except: if sshaddress and sshport and username: @@ -277,7 +295,7 @@ def __getitem__(self, id): except: pass try: - sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong + sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DisconnectFromWorkerServer ping pong except: pass diff --git a/torchstudio/datasets/genericloader.py b/torchstudio/datasets/genericloader.py new file mode 100644 index 0000000..90c53f8 --- /dev/null +++ b/torchstudio/datasets/genericloader.py @@ -0,0 +1,168 @@ +import os +import torch +from torch.utils.data import Dataset +from PIL import Image +import torchvision +import torchaudio +import numpy as np +import sys + +class GenericLoader(Dataset): + """A generic dataset loader. + Suitable for classification, segmentation and regression datasets. + Supports image, audio, and numpy array files. + + Args: + path (str): + path to the dataset + + classification (bool): + True: classification dataset (single class prediction: class1, class2, ...) + False: segmentation or regression dataset (multiple components: input, target, ...) + + separator (str or None): + '/': folders will be used to determine classes or components + (classes: class1/1.ext, class1/2.ext, class2/1.ext, class2/2.ext, ...) + (components: inputs/1.ext, inputs/2.ext, targets/1.ext, targets/2.ext, ...) + + '_' or other separator: file name parts will be used to determine classes or components + (classes: class1_1.ext, class1_2.ext, class2_1.ext, class2_2.ext, ...) + (components: 1_input.ext, 1_output.ext, 2_input.ext, 2_output.ext, ...) + + '' or None: file names or their content will be used to determine components + (one sample per folder: 1/input.ext, 1/output.ext, 2/input.ext, 2/output.ext, ...) + (samples in one folder: 1.ext, 2.ext, ...) + + extensions (str): + file extension to filters (such as: .jpg, .jpeg, .png, .mp3, .wav, .npy, .npz) + + transforms (list): + list of transforms to apply to the different components of each sample (use None is some components need no transform) + (ie: [torchvision.transforms.Compose([transforms.Resize(64)]), torchaudio.transforms.Spectrogram()]) + """ + + def __init__(self, path:str='', classification:bool=True, separator:str='/', extensions:str='.jpg, .jpeg, .png, .mp3, .wav, .npy, .npz', transforms=[]): + exts = tuple(extensions.replace(' ','').split(',')) + paths = [] + self.samples = [] + self.classes = [] + self.transforms = transforms + if not os.path.exists(path): + print("Path not found.", file=sys.stderr) + return + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith(exts): + paths.append(os.path.join(root, file).replace('\\','/')) + paths=sorted(paths) + if not paths: + print("No files found.", file=sys.stderr) + return + self.classification=classification + if classification: + if separator == '/': + for path in paths: + class_name=path.split('/')[-2] + if class_name not in self.classes: + self.classes.append(class_name) + self.samples.append([path, self.classes.index(class_name)]) + elif separator: + for path in paths: + class_name = path.split('/')[-1].split(separator)[0] + if class_name not in self.classes: + self.classes.append(class_name) + self.samples.append([path, self.classes.index(class_name)]) + else: + print("You need a separator with classication datasets", file=sys.stderr) + return + else: + samples_index = dict() + if separator == '/': + for path in paths: + components_name=path.split('/')[-2] + sample_name = path.split('/')[-1].split('.')[-2] + if sample_name not in samples_index: + samples_index[sample_name] = len(self.samples) + self.samples.append([]) + self.samples[samples_index[sample_name]].append(path) + elif separator: + for path in paths: + components_name = path.split('.')[-2].split(separator)[-1] + sample_name = path.split('/')[-1].split(separator)[0] + if sample_name not in samples_index: + samples_index[sample_name] = len(self.samples) + self.samples.append([]) + self.samples[samples_index[sample_name]].append(path) + else: + single_folder=True + file_root=path[:path.rfind("/")] + for path in paths: + if not path.startswith(file_root): + single_folder=False + break + if single_folder: + for path in paths: + sample_name = path.split('/')[-1].split('.')[-2] + if sample_name not in samples_index: + samples_index[sample_name] = len(self.samples) + self.samples.append([]) + self.samples[samples_index[sample_name]].append(path) + else: + for path in paths: + components_name = path.split('/')[-1].split('.')[-2] + sample_name = path.split('/')[-2] + if sample_name not in samples_index: + samples_index[sample_name] = len(self.samples) + self.samples.append([]) + self.samples[samples_index[sample_name]].append(path) + + def to_tensors(self, path:str): + if path.endswith('.jpg') or path.endswith('.jpeg') or path.endswith('.png'): + img=Image.open(path) + if img.getpalette(): + return [torch.from_numpy(np.array(img, dtype=np.uint8))] + else: + trans=torchvision.transforms.ToTensor() + return [trans(img)] + + if path.endswith('.mp3') or path.endswith('.wav'): + waveform, sample_rate = torchaudio.load(path) + return [waveform] + + if path.endswith('.npy') or path.endswith('.npz'): + arrays = np.load(path) + if type(arrays) == dict: + tensors = [] + for array in arrays: + tensors.append(torch.from_numpy(arrays[array])) + return tensors + else: + return [torch.from_numpy(arrays)] + + def __len__(self): + return len(self.samples) + + def __getitem__(self, id): + """ + Returns: + A tuple of tensors. + """ + + if id < 0 or id >= len(self): + raise IndexError + + components = [] + for component in self.samples[id]: + if type(component) is str: + components.extend(self.to_tensors(component)) + else: + components.extend([torch.tensor(component)]) + + if self.transforms: + if type(self.transforms) is not list and type(self.transforms) is not tuple: + self.transforms = [self.transforms] + for i, transform in enumerate(self.transforms): + if i < len(components) and transform is not None: + components[i] = transform(components[i]) + + return tuple(components) diff --git a/torchstudio/metrics/accuracy.py b/torchstudio/metrics/accuracy.py index 228981c..59c2661 100644 --- a/torchstudio/metrics/accuracy.py +++ b/torchstudio/metrics/accuracy.py @@ -9,7 +9,7 @@ class Accuracy(Metric): threshold: error threshold below which predictions are considered accurate (not used in multiclass) normalize: if set to True, normalize predictions with sigmoid or softmax before calculating the accuracy """ - def __init__(self, threshold: float = 0.1, normalize: bool = False): + def __init__(self, threshold: float = 0.01, normalize: bool = False): self.threshold = threshold self.normalize = normalize self.reset() diff --git a/torchstudio/metricsplot.py b/torchstudio/metricsplot.py index 6c29630..7f4eb01 100644 --- a/torchstudio/metricsplot.py +++ b/torchstudio/metricsplot.py @@ -10,7 +10,8 @@ def plot_metrics(prefix, size, dpi, samples=100, labels=[], loss=[], loss_colors=[], loss_shift=(0,0), loss_scale=(1,1), - metric=[], metric_colors=[], metric_shift=(0,0), metric_scale=(1,1)): + metric=[], metric_colors=[], metric_shift=(0,0), metric_scale=(1,1), + best=[]): """Metrics Plot Usage: @@ -59,12 +60,21 @@ def plot_metrics(prefix, size, dpi, samples=100, labels=[], loss_ymax=loss_ymax/loss_scale[1] ax1.axis(xmin=loss_xmin,xmax=loss_xmax,ymin=loss_ymin,ymax=loss_ymax) + def forward(x): + return x**(1/2) + def inverse(x): + return x**2 + ax1.set_yscale('function', functions=(forward, inverse)) ax1.spines['top'].set_visible(False) ax1.spines['right'].set_visible(False) ax1.spines['left'].set_color('#707070') ax1.spines['bottom'].set_color('#707070') for i in range(len(loss)): - ax1.plot(loss[i],label=str(i) if i>=len(labels) else labels[i],color=loss_colors[i%len(loss_colors)]) + ax1.plot(loss[i], label=str(i) if i>=len(labels) else labels[i], color=loss_colors[i%len(loss_colors)]) + for i in range(len(best)): + if best[i]>=0: + ax1.plot(best[i], loss[i][best[i]], color=loss_colors[i%len(loss_colors)], marker='o', markersize=3) + ax1.plot(best[i], loss[i][best[i]], color=(1, 1, 1, 0.5), marker='o', markersize=3) if labels and loss and loss[0]: ax1.legend(bbox_to_anchor=(1, 1), loc='upper right', ncol=1, prop={'size': 8}) ax1.grid(color = '#303030') @@ -108,7 +118,11 @@ def plot_metrics(prefix, size, dpi, samples=100, labels=[], ax2.spines['left'].set_color('#707070') ax2.spines['bottom'].set_color('#707070') for i in range(len(metric)): - ax2.plot(metric[i],color=metric_colors[i%len(metric_colors)]) + ax2.plot(metric[i], color=metric_colors[i%len(metric_colors)]) + for i in range(len(best)): + if best[i]>=0: + ax2.plot(best[i], metric[i][best[i]], color=metric_colors[i%len(loss_colors)], marker='o', markersize=3) + ax2.plot(best[i], metric[i][best[i]], color=(1, 1, 1, 0.5), marker='o', markersize=3) ax2.grid(color = '#303030') ax2.xaxis.set_major_locator(MaxNLocator(nbins='auto', integer=True)) @@ -137,6 +151,8 @@ def plot_metrics(prefix, size, dpi, samples=100, labels=[], metric_shift = (0,0) metric_scale = (1,1) +best=[] + app_socket = tc.connect() while True: msg_type, msg_data = tc.recv_msg(app_socket) @@ -176,10 +192,15 @@ def plot_metrics(prefix, size, dpi, samples=100, labels=[], if msg_type == 'SetMetricScale': metric_scale = tc.decode_floats(msg_data) + if msg_type == 'SetBest': + best = tc.decode_ints(msg_data) + if msg_type == 'Render': if resolution[0]>0 and resolution[1]>0: - img=plot_metrics(prefix,resolution[0:2],resolution[2],samples,labels,loss,loss_colors,loss_shift,loss_scale,metric,metric_colors,metric_shift,metric_scale) + img=plot_metrics(prefix,resolution[0:2],resolution[2],samples,labels,loss,loss_colors,loss_shift,loss_scale,metric,metric_colors,metric_shift,metric_scale,best) tc.send_msg(app_socket, 'ImageData', tc.encode_image(img)) + else: + tc.send_msg(app_socket, 'ImageError') if msg_type == 'Exit': break diff --git a/torchstudio/modelbuild.py b/torchstudio/modelbuild.py index 5f04c3e..4cf88bf 100644 --- a/torchstudio/modelbuild.py +++ b/torchstudio/modelbuild.py @@ -238,7 +238,7 @@ def level_trace(root): for tensor in output_tensors: metric.append("Accuracy") - tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([128,0,100,1,1])) + tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([64,0,100,20])) tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['Adam','Step'])) if msg_type == 'Exit': diff --git a/torchstudio/modeltrain.py b/torchstudio/modeltrain.py index 27b93f1..87cb84f 100644 --- a/torchstudio/modeltrain.py +++ b/torchstudio/modeltrain.py @@ -13,26 +13,51 @@ import os import sys import io -import tempfile from tqdm.auto import tqdm from collections.abc import Iterable +import threading class CachedDataset(Dataset): - def __init__(self, disk_cache=False): - self.reset(disk_cache) - - def add_sample(self, data): - if self.disk_cache: - file=tempfile.TemporaryFile(prefix='torchstudio.'+str(len(self.index))+'.') #guaranteed to be deleted on win/mac/linux: https://bugs.python.org/issue4928 - file.write(data) - self.index.append(file) - else: - self.index.append(tc.decode_torch_tensors(data)) - - def reset(self, disk_cache=False): + def __init__(self, train=True, hash=None): self.index = [] - self.disk_cache=disk_cache + self.cache=None + if hash: + self.filename='cache/dataset-'+('training' if train==True else 'validation') + if os.path.exists(self.filename): + self.cache = open(self.filename, 'rb') + cached_hash=self.cache.read(16) + if cached_hash==hash: + size=self.cache.read(4) + while size: + data=self.cache.read(int.from_bytes(size, 'little')) + self.index.append(tc.decode_torch_tensors(data)) + size=self.cache.read(4) + self.cache.close() + return + else: + self.cache.close() + os.remove(self.filename) + if os.path.exists(self.filename+'.tmp'): + os.remove(self.filename+'.tmp') + if not os.path.exists('cache'): + os.mkdir('cache') + self.cache = open(self.filename+'.tmp', 'wb') + self.cache.write(hash) + + def add_sample(self, data=None): + if data: + if self.cache: + self.cache.write(len(data).to_bytes(4, 'little')) + self.cache.write(data) + self.index.append(tc.decode_torch_tensors(data)) + else: + if self.cache: + self.cache.close() + try: + os.rename(self.filename+'.tmp', self.filename) + except: + pass def __len__(self): return len(self.index) @@ -40,14 +65,7 @@ def __len__(self): def __getitem__(self, id): if id<0 or id>=len(self): raise IndexError - - if self.disk_cache: - file=self.index[id] - file.seek(0) - sample=tc.decode_torch_tensors(file.read()) - else: - sample=self.index[id] - return sample + return self.index[id] def deepcopy_cpu(value): if isinstance(value, torch.Tensor): @@ -62,11 +80,14 @@ def deepcopy_cpu(value): modules_valid=True -train_dataset = CachedDataset() -valid_dataset = CachedDataset() +train_dataset = CachedDataset(True) +valid_dataset = CachedDataset(False) train_bar = None model = None +sender_thread = None + +cache = None app_socket = tc.connect() print("Training script connected\n", file=sys.stderr) @@ -83,6 +104,10 @@ def deepcopy_cpu(value): print("Setting mode...\n", file=sys.stderr) mode=tc.decode_strings(msg_data)[0] + if msg_type == 'SetCache': + print("Setting cache...\n", file=sys.stderr) + cache = True if tc.decode_ints(msg_data)[0]==1 else False + if msg_type == 'SetTorchScriptModel' and modules_valid: if msg_data: print("Setting torchscript model...\n", file=sys.stderr) @@ -156,10 +181,11 @@ def deepcopy_cpu(value): scheduler = scheduler_env['scheduler'] if msg_type == 'SetHyperParametersValues' and modules_valid: #set other hyperparameters values - batch_size, shuffle, epochs, early_stop, restore_best = tc.decode_ints(msg_data) + batch_size, shuffle, epochs, early_stop = tc.decode_ints(msg_data) shuffle=True if shuffle==1 else False - early_stop=True if early_stop==1 else False - restore_best=True if restore_best==1 else False + + if msg_type == 'SetBestMetrics': + best_train_loss, best_valid_loss, best_train_metric, best_valid_metric = tc.decode_floats(msg_data) if msg_type == 'StartTrainingServer' and modules_valid: print("Caching...\n", file=sys.stderr) @@ -189,24 +215,41 @@ def deepcopy_cpu(value): reverse_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ReverseTunnel, 'localhost', 0, 'localhost', int(address[1])) address[1]=str(reverse_tunnel.lport) - tc.send_msg(app_socket, 'TrainingServerRequestingAllSamples', tc.encode_strings(address)) + tc.send_msg(app_socket, 'ServerRequestingDataset', tc.encode_strings(address)) + dataset_socket=tc.start_server(training_server) - train_dataset.reset() - valid_dataset.reset() + + tc.send_msg(dataset_socket, 'RequestMetaInfos') + tc.send_msg(dataset_socket, 'RequestHash') while True: dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket) - if dataset_msg_type == 'NumSamples': - num_samples=tc.decode_ints(dataset_msg_data)[0] - pbar=tqdm(total=num_samples, desc='Caching...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters - if dataset_msg_type == 'InputTensorsID' and modules_valid: input_tensors_id = tc.decode_ints(dataset_msg_data) if dataset_msg_type == 'OutputTensorsID' and modules_valid: output_tensors_id = tc.decode_ints(dataset_msg_data) + if dataset_msg_type == 'DatasetHash': + train_dataset=CachedDataset(True) + valid_dataset=CachedDataset(False) + if cache: + train_dataset=CachedDataset(True, dataset_msg_data) + valid_dataset=CachedDataset(False, dataset_msg_data) + if len(train_dataset)==0 and len(valid_dataset)==0: + tc.send_msg(dataset_socket, 'RequestAllSamples', tc.encode_strings(address)) + elif len(train_dataset)==0: + tc.send_msg(dataset_socket, 'RequestTrainingSamples', tc.encode_strings(address)) + elif len(valid_dataset)==0: + tc.send_msg(dataset_socket, 'RequestValidationSamples', tc.encode_strings(address)) + else: + break + + if dataset_msg_type == 'NumSamples': + num_samples=tc.decode_ints(dataset_msg_data)[0] + pbar=tqdm(total=num_samples, desc='Caching...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters + if dataset_msg_type == 'TrainingSample': train_dataset.add_sample(dataset_msg_data) pbar.update(1) @@ -216,14 +259,17 @@ def deepcopy_cpu(value): pbar.update(1) if dataset_msg_type == 'DoneSending': + train_dataset.add_sample() + valid_dataset.add_sample() pbar.close() - tc.send_msg(dataset_socket, 'DoneReceiving') - dataset_socket.close() - training_server.close() - if sshaddress and sshport and username: - sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong break + tc.send_msg(dataset_socket, 'DisconnectFromWorkerServer') + dataset_socket.close() + training_server.close() + if sshaddress and sshport and username: + sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DisconnectFromWorkerServer ping pong + train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory) valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size, shuffle=False, pin_memory=pin_memory) @@ -280,10 +326,10 @@ def deepcopy_cpu(value): metric.update(output, target) train_loss = train_loss/len(train_dataset) - train_metrics = 0 + train_metric = 0 for metric in metrics: - train_metrics = train_metrics+metric.compute().item() - train_metrics/=len(metrics) + train_metric = train_metric+metric.compute().item() + train_metric/=len(metrics) scheduler.step() #validation @@ -307,26 +353,41 @@ def deepcopy_cpu(value): metric.update(output, target) valid_loss = valid_loss/len(valid_dataset) - valid_metrics = 0 + valid_metric = 0 for metric in metrics: - valid_metrics = valid_metrics+metric.compute().item() - valid_metrics/=len(metrics) + valid_metric = valid_metric+metric.compute().item() + valid_metric/=len(metrics) + + #threaded (async) results sending, so to send last metrics and best weights when available while calculating new ones + if sender_thread: + sender_thread.join() + + metrics_values=[train_loss, valid_loss, train_metric, valid_metric] + + model_state_buffer=None + optimizer_state_buffer=None - tc.send_msg(app_socket, 'TrainingLoss', tc.encode_floats(train_loss)) - tc.send_msg(app_socket, 'ValidationLoss', tc.encode_floats(valid_loss)) - tc.send_msg(app_socket, 'TrainingMetric', tc.encode_floats(train_metrics)) - tc.send_msg(app_socket, 'ValidationMetric', tc.encode_floats(valid_metrics)) + if valid_metric>best_valid_metric or (valid_metric==best_valid_metric and valid_loss1 else .5 for j in range(len(param_values[i]))]) ax.set_yticklabels(param_values[i]) - #first parameter is the model name, keep the set_ticks + #first parameter is the model name, keep the set_ticks and color names axes[0].yaxis.set_tick_params(width=1) + for i, tick in enumerate(axes[0].yaxis.get_major_ticks()): + tick.tick1line.set_markeredgecolor(colors[len(colors)-i-1]) + tick.tick1line.set_markeredgewidth(4) + + #last parameter is the metric, let the colorbar do the metric axes[-1].yaxis.set_ticks_position('none') axes[-1].set_yticklabels([]) @@ -136,6 +142,7 @@ def plot_parameters(size, dpi, resolution = (256,256, 96) parameters=[] +colors=[] values=[] order=[] @@ -153,6 +160,9 @@ def plot_parameters(size, dpi, if msg_type == 'SetParameters': parameters=tc.decode_strings(msg_data) + if msg_type == 'SetColors': + colors=tc.decode_strings(msg_data) + if msg_type == 'ClearValues': values = [] if msg_type == 'AppendValues': @@ -163,8 +173,10 @@ def plot_parameters(size, dpi, if msg_type == 'Render': if resolution[0]>0 and resolution[1]>0: - img=plot_parameters(resolution[0:2],resolution[2],parameters,values,order) + img=plot_parameters(resolution[0:2],resolution[2],parameters,colors,values,order) tc.send_msg(app_socket, 'ImageData', tc.encode_image(img)) + else: + tc.send_msg(app_socket, 'ImageError') if msg_type == 'Exit': break diff --git a/torchstudio/pythoncheck.py b/torchstudio/pythoncheck.py index aa84cac..7253036 100644 --- a/torchstudio/pythoncheck.py +++ b/torchstudio/pythoncheck.py @@ -76,8 +76,17 @@ print("Listing devices...\n", file=sys.stderr) devices = {} - devices['cpu'] = {'name': 'CPU', 'pin_memory': False, 'modes': ['FP32']} + devices['cpu'] = {'name': 'CPU', 'modes': ['FP32']} + + cuda_names = {} for i in range(torch.cuda.device_count()): + name=torch.cuda.get_device_name(i) + if name in cuda_names: + cuda_names[name]+=1 + name+=" "+str(cuda_names[name]) + else: + cuda_names[name]=1 + modes = ['FP32'] #same as torch.cuda.is_bf16_supported() but compatible with PyTorch<1.10, and not limited to current cuda device only cu_vers = torch.version.cuda @@ -90,10 +99,13 @@ modes+=['TF32','FP16','BF16'] if compute_capability==7: #RTX 2000 modes+=['FP16'] - devices['cuda:'+str(i)] = {'name': torch.cuda.get_device_name(i), 'pin_memory': True, 'modes': modes} + + devices['cuda:'+str(i)] = {'name': name, 'modes': modes} + if pytorch_version>=(1,12,0): if torch.backends.mps.is_available(): - devices['mps'] = {'name': 'Metal', 'pin_memory': False, 'modes': ['FP32']} + devices['mps'] = {'name': 'Metal', 'modes': ['FP32']} + #other possible devices: #'hpu' (https://docs.habana.ai/en/latest/PyTorch_User_Guide/PyTorch_User_Guide.html) #'dml' (https://docs.microsoft.com/en-us/windows/ai/directml/gpu-pytorch-windows) diff --git a/torchstudio/pythoninstall.cmd b/torchstudio/pythoninstall.cmd index e2ce990..a643817 100644 --- a/torchstudio/pythoninstall.cmd +++ b/torchstudio/pythoninstall.cmd @@ -30,7 +30,6 @@ done if [ ! -z "$uninstall" ]; then echo "Uninstalling python environment..." rm -f *.sh - rm -f *.tmp rm -f -r "$pythonpath" if [ $? != 0 ]; then echo "" 1>&2 @@ -61,20 +60,14 @@ fi echo "" if [[ $OSTYPE == "linux"* ]]; then file=Miniconda3-latest-Linux-x86_64.sh - rm -f "$file.tmp" - if [ -f "$file" ]; then - echo "Python installer ($file) already downloaded" - else - echo "Downloading Python installer ($file)..." - wget --show-progress --progress=bar:force:noscroll --no-check-certificate https://repo.anaconda.com/miniconda/$file -O "$file.tmp" - if [ $? != 0 ]; then - rm -f "$file.tmp" - echo "" 1>&2 - echo "Error while downloading. Make sure port 80 is open." 1>&2 - exit 1 - else - mv "$file.tmp" "$file" - fi + rm -f "$file" + echo "Downloading Python installer ($file)..." + wget --show-progress --progress=bar:force:noscroll --no-check-certificate https://repo.anaconda.com/miniconda/$file -O "$file" + if [ $? != 0 ]; then + rm -f "$file" + echo "" 1>&2 + echo "Error while downloading. Make sure port 80 is open." 1>&2 + exit 1 fi elif [[ $OSTYPE == "darwin"* ]]; then if [ "$(uname -m)" == "arm64" ]; then @@ -82,42 +75,29 @@ elif [[ $OSTYPE == "darwin"* ]]; then else file=Miniconda3-latest-MacOSX-x86_64.sh fi - rm -f "$file.tmp" - if [ -f "$file" ]; then - echo "Python installer $file already downloaded" - else - echo "Downloading Python installer $file..." - curl --insecure https://repo.anaconda.com/miniconda/$file -o "$file.tmp" - if [ $? != 0 ]; then - rm -f "$file.tmp" - echo "" 1>&2 - echo "Error while downloading. Make sure port 80 is open." 1>&2 - exit 1 - else - mv "$file.tmp" "$file" - fi + rm -f "$file" + echo "Downloading Python installer $file..." + curl --insecure https://repo.anaconda.com/miniconda/$file -o "$file" + if [ $? != 0 ]; then + rm -f "$file" + echo "" 1>&2 + echo "Error while downloading. Make sure port 80 is open." 1>&2 + exit 1 fi fi echo "" -if [ -f "python.tmp" ]; then - rm -f python.tmp +rm -f -r "$pythonpath" +echo "Installing Python in $pythonpath..." +bash "$(pwd)/$file" -b -f -p "$pythonpath" +if [ $? != 0 ]; then + rm -f "$file" rm -f -r "$pythonpath" + echo "" 1>&2 + echo "Error while installing. Make sure you have write permissions." 1>&2 + exit 1 fi -if [ -d "$pythonpath" ]; then - echo "Python already installed in $pythonpath" -else - echo "Installing Python in $pythonpath..." - echo "" > python.tmp - bash "$(pwd)/$file" -b -f -p "$pythonpath" - rm -f python.tmp - if [ $? != 0 ]; then - rm -f -r "$pythonpath" - echo "" 1>&2 - echo "Error while installing. Make sure you have write permissions." 1>&2 - exit 1 - fi -fi +rm -f "$file" PATH="$PATH;$pythonpath/bin" "$pythonpath/bin/python" -u -B -X utf8 -m torchstudio.pythoninstall --channel $channel $cuda $packages @@ -169,7 +149,6 @@ goto args if DEFINED uninstall ( echo Uninstalling python environment... del *.exe 2>nul - del *.tmp 2>nul rmdir /s /q "%pythonpath%" 2>nul if ERRORLEVEL 1 ( echo. 1>&2 @@ -191,41 +170,28 @@ if DEFINED cuda ( echo. set file=Miniconda3-latest-Windows-x86_64.exe -del %file%.tmp 2>nul -if EXIST "%file%" ( - echo Python installer %file% already downloaded -) else ( - echo Downloading Python installer %file%... - curl --insecure https://repo.anaconda.com/miniconda/%file% -o %file%.tmp - if ERRORLEVEL 1 ( - del %file%.tmp 2>nul - echo. 1>&2 - echo Error while downloading. Make sure port 80 is open. 1>&2 - exit /B 1 - ) else ( - ren %file%.tmp %file% - ) +del %file% 2>nul +echo Downloading Python installer %file%... +curl --insecure https://repo.anaconda.com/miniconda/%file% -o %file% +if ERRORLEVEL 1 ( + del %file% 2>nul + echo. 1>&2 + echo Error while downloading. Make sure port 80 is open. 1>&2 + exit /B 1 ) echo. -if EXIST "python.tmp" ( - del python.tmp 2>nul +rmdir /s /q "%pythonpath%" 2>nul +echo Installing Python in %pythonpath%... +%file% /S /D=%pythonpath% +if ERRORLEVEL 1 ( + del %file% 2>nul rmdir /s /q "%pythonpath%" 2>nul + echo. 1>&2 + echo Error while installing. Make sure you have write permissions. 1>&2 + exit /B 1 ) -if EXIST "%pythonpath%" ( - echo Python already installed in %pythonpath% -) else ( - echo Installing Python in %pythonpath%... - echo. > python.tmp - %file% /S /D=%pythonpath% - del python.tmp 2>nul - if ERRORLEVEL 1 ( - rmdir /s /q "%pythonpath%" 2>nul - echo. 1>&2 - echo Error while installing. Make sure you have write permissions. 1>&2 - exit /B 1 - ) -) +del %file% 2>nul set PATH=%PATH%;%pythonpath%;%pythonpath%\Library\mingw-w64\bin;%pythonpath%\Library\bin;%pythonpath%\bin "%pythonpath%\python" -u -B -X utf8 -m torchstudio.pythoninstall --channel %channel% %cuda% %packages% diff --git a/torchstudio/pythoninstall.py b/torchstudio/pythoninstall.py index 0e688c7..e082121 100644 --- a/torchstudio/pythoninstall.py +++ b/torchstudio/pythoninstall.py @@ -14,6 +14,15 @@ import conda.cli.python_api as Conda +#increase rows (from default 20 when no terminal is found) to display all parallel packages downloads at once +from tqdm import tqdm +init_source=tqdm.__init__ +def init_patch(self, **kwargs): + kwargs['ncols']=80 + kwargs['nrows']=80 + init_source(self, **kwargs) +tqdm.__init__=init_patch + if not args.package: #https://edcarp.github.io/introduction-to-conda-for-data-scientists/03-using-packages-and-channels/index.html#alternative-syntax-for-installing-packages-from-specific-channels conda_install=f"{args.channel}::pytorch {args.channel}::torchvision {args.channel}::torchaudio {args.channel}::torchtext" diff --git a/torchstudio/pythonparse.py b/torchstudio/pythonparse.py index 9a64fe2..c1d6630 100644 --- a/torchstudio/pythonparse.py +++ b/torchstudio/pythonparse.py @@ -353,7 +353,7 @@ def __init__(self, optimizer): def scan_folder(path): path=path.replace('.','/') codes=[] - for filename in listdir(path): + for filename in sorted(listdir(path)): if isfile(join(path, filename)): with open(join(path, filename), "r") as file: codes.append(file.read()) diff --git a/torchstudio/renderers/bitmap.py b/torchstudio/renderers/bitmap.py index 26d705e..91cf9a2 100644 --- a/torchstudio/renderers/bitmap.py +++ b/torchstudio/renderers/bitmap.py @@ -18,11 +18,13 @@ class Bitmap(Renderer): Args: colormap (str): Colormap to be used for single channel bitmaps. Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis' + colors: List of colors for each channel for multi channels bitmaps (looped if necessary) rotate (int): Number of time to rotate the bitmap by 90 degree (counter-clockwise) """ - def __init__(self, colormap='inferno', rotate=0): + def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0): super().__init__() self.colormap=colormap + self.colors=colors self.rotate=rotate def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]): @@ -32,21 +34,18 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp return None #flatten - if tensor.shape[0]==2: #2 channels, pad with a third channel - zero = np.zeros((1,tensor.shape[1], tensor.shape[2])) - tensor = np.concatenate((tensor,zero),0) - if tensor.shape[0]>3: #more than 3 channels, add additional channels into the first 3 - for i in range(3,tensor.shape[0]): - tensor[[i%3]]+=tensor[[i]] - if i%6>=3: #add R G B R G B to RG GB BR R G B - tensor[[(i+1)%3]]+=tensor[[i]] - tensor=tensor[[0,1,2]] + if tensor.shape[0]>1: + zero = np.zeros((3,tensor.shape[1], tensor.shape[2])) + for i in range(tensor.shape[0]): + color=np.array(mpl.colors.to_rgb(self.colors[i%len(self.colors)])).reshape(3,1,1) + zero+=tensor[[i]]*color + tensor=zero if self.rotate>0: tensor=np.rot90(tensor, self.rotate, axes=(1, 2)) #apply brightness, gamma and conversion to uint8, then transform CHW to HWC - tensor = np.multiply(np.clip(np.power(tensor*scale[0],1/scale[3]),0,1),255).astype(np.uint8) + tensor = np.multiply(np.clip(np.power(np.clip(tensor*scale[0],0,1),1/scale[3]),0,1),255).astype(np.uint8) tensor = tensor.transpose((1, 2, 0)) #set up matplotlib renderer, style, figure and axis diff --git a/torchstudio/renderers/signal.py b/torchstudio/renderers/signal.py index 661959f..843d92e 100644 --- a/torchstudio/renderers/signal.py +++ b/torchstudio/renderers/signal.py @@ -21,7 +21,7 @@ class Signal(Renderer): 'fixed': use the min and max values defined by the user 'fit': use the min and max values of the signal 'auto': fit when the values are beyond the user-defined min and max values - colors: List of colors for each channel + colors: List of colors for each channel (looped if necessary) grid: Display grid legend: Display legend with more than one channel """ diff --git a/torchstudio/renderers/spectrogram.py b/torchstudio/renderers/spectrogram.py index 6f73586..b7cdc07 100644 --- a/torchstudio/renderers/spectrogram.py +++ b/torchstudio/renderers/spectrogram.py @@ -16,13 +16,15 @@ class Spectrogram(Renderer): Ctrl/Cmd Scroll: adjust gamma Args: - colormap (str): Colormap to be used for single channel bitmaps. + colormap (str): Colormap to be used for single channel spectrograms. Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis' + colors: List of colors for each channel for multi channels spectrograms (looped if necessary) rotate (int): Number of time to rotate the bitmap by 90 degree (counter-clockwise) """ - def __init__(self, colormap='inferno', rotate=0): + def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0): super().__init__() self.colormap=colormap + self.colors=colors self.rotate=rotate def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]): @@ -42,15 +44,12 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp tensor=np.absolute(tensor[::2]+1j*tensor[1::2]) #flatten - if tensor.shape[0]==2: #2 channels, pad with a third channel - zero = np.zeros((1,tensor.shape[1], tensor.shape[2])) - tensor = np.concatenate((tensor,zero),0) - if tensor.shape[0]>3: #more than 3 channels, add additional channels into the first 3 - for i in range(3,tensor.shape[0]): - tensor[[i%3]]+=tensor[[i]] - if i%6>=3: #add R G B R G B to RG GB BR R G B - tensor[[(i+1)%3]]+=tensor[[i]] - tensor=tensor[[0,1,2]] + if tensor.shape[0]>1: + zero = np.zeros((3,tensor.shape[1], tensor.shape[2])) + for i in range(tensor.shape[0]): + color=np.array(mpl.colors.to_rgb(self.colors[i%len(self.colors)])).reshape(3,1,1) + zero+=tensor[[i]]*color + tensor=zero if self.rotate>0: tensor=np.rot90(tensor, self.rotate, axes=(1, 2)) diff --git a/torchstudio/renderers/volume.py b/torchstudio/renderers/volume.py index 1aa83f6..849c673 100644 --- a/torchstudio/renderers/volume.py +++ b/torchstudio/renderers/volume.py @@ -19,11 +19,13 @@ class Volume(Renderer): Args: colormap (str): Colormap to be used for single channel volumes. Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis' + colors: List of colors for each channel for multi channels volumes (looped if necessary) rotate (int): Number of time to rotate the bitmap by 90 degree (counter-clockwise) """ - def __init__(self, colormap='inferno', rotate=0): + def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0): super().__init__() self.colormap=colormap + self.colors=colors self.rotate=rotate def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]): @@ -39,21 +41,18 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp tensor = tensor[:,depth] #flatten - if tensor.shape[0]==2: #2 channels, pad with a third channel - zero = np.zeros((1,tensor.shape[1], tensor.shape[2])) - tensor = np.concatenate((tensor,zero),0) - if tensor.shape[0]>3: #more than 3 channels, add additional channels into the first 3 - for i in range(3,tensor.shape[0]): - tensor[[i%3]]+=tensor[[i]] - if i%6>=3: #add R G B R G B to RG GB BR R G B - tensor[[(i+1)%3]]+=tensor[[i]] - tensor=tensor[[0,1,2]] + if tensor.shape[0]>1: + zero = np.zeros((3,tensor.shape[1], tensor.shape[2])) + for i in range(tensor.shape[0]): + color=np.array(mpl.colors.to_rgb(self.colors[i%len(self.colors)])).reshape(3,1,1) + zero+=tensor[[i]]*color + tensor=zero if self.rotate>0: tensor=np.rot90(tensor, self.rotate, axes=(1, 2)) #apply luminosity and conversion to uint8, then transform CHW to HWC - tensor = np.multiply(np.clip(np.power(tensor*scale[0],1/scale[3]),0,1),255).astype(np.uint8) + tensor = np.multiply(np.clip(np.power(np.clip(tensor*scale[0],0,1),1/scale[3]),0,1),255).astype(np.uint8) tensor = tensor.transpose((1, 2, 0)) #set up matplotlib renderer, style, figure and axis diff --git a/torchstudio/sshtunnel.py b/torchstudio/sshtunnel.py index 32244a1..ac2e7ec 100644 --- a/torchstudio/sshtunnel.py +++ b/torchstudio/sshtunnel.py @@ -62,7 +62,7 @@ def stop(self): def handler(self, rev_socket, origin, laddress): rev_handler = ReverseTunnelHandler(rev_socket, self.dhost, self.dport, self.lhost, self.lport) - rev_handler.setDaemon(True) + rev_handler.daemon=True rev_handler.start() self.handlers.append(rev_handler) @@ -246,7 +246,7 @@ def finish(self): parser.add_argument("--username", help="ssh server username", type=str, default=None) parser.add_argument("--password", help="ssh server password", type=str, default=None) parser.add_argument("--keyfile", help="ssh server key file", type=str, default=None) - parser.add_argument('--clean', help="clean all files", action='store_true', default=False) + parser.add_argument('--clean', help="cleaning level (0: cache, 1: environment, 2: all)", type=int, default=None) parser.add_argument("--command", help="command to execute or run python scripts", type=str, default=None) parser.add_argument("--script", help="python script to be launched on the server", type=str, default=None) parser.add_argument("--address", help="address to which the script must connect", type=str, default=None) @@ -263,12 +263,25 @@ def finish(self): print("Error: could not connect to remote server", file=sys.stderr) exit(1) - if args.clean: - print("Cleaning...", file=sys.stderr) - stdin, stdout, stderr = sshclient.exec_command('rm -r -f TorchStudio') - exit_status = stdout.channel.recv_exit_status() - stdin, stdout, stderr = sshclient.exec_command('rmdir /s /q TorchStudio') - exit_status = stdout.channel.recv_exit_status() + if args.clean is not None: + if args.clean==0: + print("Cleaning TorchStudio cache...", file=sys.stderr) + stdin, stdout, stderr = sshclient.exec_command('rm -r -f TorchStudio/cache') + exit_status = stdout.channel.recv_exit_status() + stdin, stdout, stderr = sshclient.exec_command('rmdir /s /q TorchStudio\cache') + exit_status = stdout.channel.recv_exit_status() + if args.clean==1: + print("Deleting TorchStudio environment...", file=sys.stderr) + stdin, stdout, stderr = sshclient.exec_command('rm -r -f TorchStudio/python') + exit_status = stdout.channel.recv_exit_status() + stdin, stdout, stderr = sshclient.exec_command('rmdir /s /q TorchStudio\python') + exit_status = stdout.channel.recv_exit_status() + if args.clean==2: + print("Deleting all TorchStudio files...", file=sys.stderr) + stdin, stdout, stderr = sshclient.exec_command('rm -r -f TorchStudio') + exit_status = stdout.channel.recv_exit_status() + stdin, stdout, stderr = sshclient.exec_command('rmdir /s /q TorchStudio') + exit_status = stdout.channel.recv_exit_status() sshclient.close() print("Cleaning complete") exit(0) @@ -321,6 +334,11 @@ def finish(self): if args.command: if args.script: print("Launching remote script...", file=sys.stderr) + if '\\python' in args.command: #python on Windows, add path variables + python_root=args.command[:args.command.rindex('\\python')] + if '&&' in python_root: + python_root=python_root[python_root.rindex('&&')+2:] + args.command='set PATH=%PATH%;'+python_root+';'+python_root+'\\Library\\mingw-w64\\bin;'+python_root+'\\Library\\bin;'+python_root+'\\bin&&'+args.command stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+args.command+" -u -X utf8 -m "+' '.join([args.script]+other_args)) else: print("Executing remote command...", file=sys.stderr) @@ -329,10 +347,10 @@ def finish(self): while not stdout.channel.exit_status_ready(): time.sleep(.01) #lower CPU usage if stdout.channel.recv_stderr_ready(): - sys.stderr.buffer.write(stdout.channel.recv_stderr(1024).replace(b'\r\n',b'\n')) + sys.stderr.buffer.write(stdout.channel.recv_stderr(8192)) time.sleep(.01) #for stdout/stderr sync if stdout.channel.recv_ready(): - sys.stdout.buffer.write(stdout.channel.recv(1024).replace(b'\r\n',b'\n')) + sys.stdout.buffer.write(stdout.channel.recv(8192)) time.sleep(.01) #for stdout/stderr sync else: if args.script: diff --git a/torchstudio/tcpcodec.py b/torchstudio/tcpcodec.py index f0013fb..2625d4a 100644 --- a/torchstudio/tcpcodec.py +++ b/torchstudio/tcpcodec.py @@ -43,24 +43,18 @@ def connect(server_address=None, timeout=0): return sock def send_msg(sock, data_type, data = bytearray()): - def sendall(sock, data): - while len(data) >0: - try: - ret = sock.send(data[:1048576]) #1MB chunks - except: - print("Lost connection (send timeout)", file=sys.stderr) - exit() - if ret == 0: - print("Lost connection (send null)", file=sys.stderr) - exit() - else: - data=data[ret:] - type_bytes=bytes(data_type, 'utf-8') type_size=len(type_bytes) - msg = struct.pack(f'