Commit: 0.9.14
divideconcept authored Jan 2, 2023
1 parent a514e8f commit 1804e96
Showing 19 changed files with 535 additions and 248 deletions.
28 changes: 17 additions & 11 deletions torchstudio/datasetanalyze.py
@@ -27,23 +27,23 @@
print("Analyzing...\n", file=sys.stderr)

analysis_server, address = tc.generate_server()
tc.send_msg(app_socket, 'ServerRequestingDataset', tc.encode_strings(address))

dataset_socket=tc.start_server(analysis_server)

tc.send_msg(dataset_socket, 'RequestMetaInfos')

if analyzer_env['analyzer'].train is None:
request_msg='AnalysisServerRequestingAllSamples'
request_msg='RequestAllSamples'
elif analyzer_env['analyzer'].train==True:
request_msg='AnalysisServerRequestingTrainingSamples'
request_msg='RequestTrainingSamples'
elif analyzer_env['analyzer'].train==False:
request_msg='AnalysisServerRequestingValidationSamples'
tc.send_msg(app_socket, request_msg, tc.encode_strings(address))
dataset_socket=tc.start_server(analysis_server)
request_msg='RequestValidationSamples'
tc.send_msg(dataset_socket, request_msg, tc.encode_strings(address))

while True:
dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket)

if dataset_msg_type == 'NumSamples':
num_samples=tc.decode_ints(dataset_msg_data)[0]
pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters

if dataset_msg_type == 'InputTensorsID':
input_tensors_id=tc.decode_ints(dataset_msg_data)

@@ -53,6 +53,10 @@
if dataset_msg_type == 'Labels':
labels=tc.decode_strings(dataset_msg_data)

if dataset_msg_type == 'NumSamples':
num_samples=tc.decode_ints(dataset_msg_data)[0]
pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters

if dataset_msg_type == 'StartSending':
error_msg, return_value = safe_exec(analyzer_env['analyzer'].start_analysis, (num_samples, input_tensors_id, output_tensors_id, labels), description='analyzer definition')
if error_msg is not None:
@@ -85,7 +89,7 @@
if dataset_msg_type == 'DoneSending':
pbar.close()
error_msg, return_value = safe_exec(analyzer_env['analyzer'].finish_analysis, description='analyzer definition')
tc.send_msg(dataset_socket, 'DoneReceiving')
tc.send_msg(dataset_socket, 'DisconnectFromWorkerServer')
dataset_socket.close()
analysis_server.close()
if error_msg is not None:
@@ -106,12 +110,14 @@

if msg_type == 'RequestAnalysisReport':
resolution = tc.decode_ints(msg_data)
if 'analyzer' in analyzer_env:
if 'analyzer' in analyzer_env and resolution[0]>0 and resolution[1]>0:
error_msg, return_value = safe_exec(analyzer_env['analyzer'].generate_report, (resolution[0:2],resolution[2]), description='analyzer definition')
if error_msg is not None:
print(error_msg, file=sys.stderr)
if return_value is not None:
tc.send_msg(app_socket, 'ReportImage', tc.encode_image(return_value))
else:
tc.send_msg(app_socket, 'ReportImage')

if msg_type == 'Exit':
break
74 changes: 46 additions & 28 deletions torchstudio/datasetload.py
@@ -16,6 +16,7 @@
import time
from collections.abc import Iterable
from tqdm.auto import tqdm
import hashlib

#monkey patch ssl to fix SSL certificate verification failures when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error
import ssl
@@ -207,9 +208,7 @@ def __getitem__(self, id):
if msg_type == 'OutputTensorsID':
output_tensors_id = tc.decode_ints(msg_data)

if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples':
train_set=True if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendAllSamples' else False
valid_set=True if msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples' else False
if msg_type == 'ConnectToWorkerServer':
name, sshaddress, sshport, username, password, keydata, address, port = tc.decode_strings(msg_data)
port=int(port)

@@ -241,30 +240,49 @@

try:
worker_socket = tc.connect((address,port),timeout=10)
num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0)
tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples))
tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id))
tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id))
tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes))

tc.send_msg(worker_socket, 'StartSending')
with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar:
if train_set:
meta_dataset.train()
for i in range(len(meta_dataset)):
tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i]))
pbar.update(1)
if valid_set:
meta_dataset.valid()
for i in range(len(meta_dataset)):
tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i]))
pbar.update(1)

tc.send_msg(worker_socket, 'DoneSending')
train_msg_type, train_msg_data = tc.recv_msg(worker_socket)
if train_msg_type == 'DoneReceiving':
worker_socket.close()
print('Samples transfer to '+name+' completed')
while True:
worker_msg_type, worker_msg_data = tc.recv_msg(worker_socket)

if worker_msg_type == 'RequestMetaInfos':
tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id))
tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id))
tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes))

if worker_msg_type == 'RequestHash':
dataset_hash = hashlib.md5()
dataset_hash.update(int(len(meta_dataset.train())).to_bytes(4, 'little'))
if len(meta_dataset)>0:
dataset_hash.update(tc.encode_torch_tensors(meta_dataset[0]))
dataset_hash.update(int(len(meta_dataset.valid())).to_bytes(4, 'little'))
if len(meta_dataset)>0:
dataset_hash.update(tc.encode_torch_tensors(meta_dataset[0]))
tc.send_msg(worker_socket, 'DatasetHash', dataset_hash.digest())

if worker_msg_type == 'RequestTrainingSamples' or worker_msg_type == 'RequestValidationSamples' or worker_msg_type == 'RequestAllSamples':
train_set=True if worker_msg_type == 'RequestTrainingSamples' or worker_msg_type == 'RequestAllSamples' else False
valid_set=True if worker_msg_type == 'RequestValidationSamples' or worker_msg_type == 'RequestAllSamples' else False
num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0)
tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples))

tc.send_msg(worker_socket, 'StartSending')
with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar:
if train_set:
meta_dataset.train()
for i in range(len(meta_dataset)):
tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i]))
pbar.update(1)
if valid_set:
meta_dataset.valid()
for i in range(len(meta_dataset)):
tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i]))
pbar.update(1)

tc.send_msg(worker_socket, 'DoneSending')

if worker_msg_type == 'DisconnectFromWorkerServer':
worker_socket.close()
print('Samples transfer to '+name+' completed')
break

except:
if sshaddress and sshport and username:
@@ -277,7 +295,7 @@ def __getitem__(self, id):
except:
pass
try:
sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong
sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DisconnectFromWorkerServer ping pong
except:
pass

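For orientation, the worker side of the pull-based transfer introduced by the two diffs above can be sketched as follows. Only the message names and the `tc.send_msg`/`tc.recv_msg`/`tc.decode_ints` helpers are taken from the diffs; the overall loop is a simplified assumption, since most of the 19 changed files (including the worker-side training scripts) are not shown in this excerpt.

```python
# Hypothetical worker-side loop, pieced together from the renamed messages in this
# commit; not a verbatim excerpt from the repository.
tc.send_msg(dataset_socket, 'RequestMetaInfos')   # dataset replies with InputTensorsID, OutputTensorsID, Labels
tc.send_msg(dataset_socket, 'RequestHash')        # assumption: lets a worker with cached samples detect staleness
tc.send_msg(dataset_socket, 'RequestAllSamples')  # or RequestTrainingSamples / RequestValidationSamples

while True:
    msg_type, msg_data = tc.recv_msg(dataset_socket)
    if msg_type == 'NumSamples':
        num_samples = tc.decode_ints(msg_data)[0]
    elif msg_type in ('TrainingSample', 'ValidationSample'):
        pass  # decode and process one sample here
    elif msg_type == 'DoneSending':
        # Acknowledge before the dataset side tears down its socket (and any SSH tunnel).
        tc.send_msg(dataset_socket, 'DisconnectFromWorkerServer')
        break
```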
168 changes: 168 additions & 0 deletions torchstudio/datasets/genericloader.py
@@ -0,0 +1,168 @@
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision
import torchaudio
import numpy as np
import sys

class GenericLoader(Dataset):
"""A generic dataset loader.
Suitable for classification, segmentation and regression datasets.
Supports image, audio, and numpy array files.
Args:
path (str):
path to the dataset
classification (bool):
True: classification dataset (single class prediction: class1, class2, ...)
False: segmentation or regression dataset (multiple components: input, target, ...)
separator (str or None):
'/': folders will be used to determine classes or components
(classes: class1/1.ext, class1/2.ext, class2/1.ext, class2/2.ext, ...)
(components: inputs/1.ext, inputs/2.ext, targets/1.ext, targets/2.ext, ...)
'_' or other separator: file name parts will be used to determine classes or components
(classes: class1_1.ext, class1_2.ext, class2_1.ext, class2_2.ext, ...)
(components: 1_input.ext, 1_output.ext, 2_input.ext, 2_output.ext, ...)
'' or None: file names or their content will be used to determine components
(one sample per folder: 1/input.ext, 1/output.ext, 2/input.ext, 2/output.ext, ...)
(samples in one folder: 1.ext, 2.ext, ...)
extensions (str):
file extensions to filter for (such as: .jpg, .jpeg, .png, .mp3, .wav, .npy, .npz)
transforms (list):
list of transforms to apply to the different components of each sample (use None if some components need no transform)
(e.g. [torchvision.transforms.Compose([transforms.Resize(64)]), torchaudio.transforms.Spectrogram()])
"""

def __init__(self, path:str='', classification:bool=True, separator:str='/', extensions:str='.jpg, .jpeg, .png, .mp3, .wav, .npy, .npz', transforms=[]):
exts = tuple(extensions.replace(' ','').split(','))
paths = []
self.samples = []
self.classes = []
self.transforms = transforms
if not os.path.exists(path):
print("Path not found.", file=sys.stderr)
return
for root, dirs, files in os.walk(path):
for file in files:
if file.endswith(exts):
paths.append(os.path.join(root, file).replace('\\','/'))
paths=sorted(paths)
if not paths:
print("No files found.", file=sys.stderr)
return
self.classification=classification
if classification:
if separator == '/':
for path in paths:
class_name=path.split('/')[-2]
if class_name not in self.classes:
self.classes.append(class_name)
self.samples.append([path, self.classes.index(class_name)])
elif separator:
for path in paths:
class_name = path.split('/')[-1].split(separator)[0]
if class_name not in self.classes:
self.classes.append(class_name)
self.samples.append([path, self.classes.index(class_name)])
else:
print("You need a separator with classication datasets", file=sys.stderr)
return
else:
samples_index = dict()
if separator == '/':
for path in paths:
components_name=path.split('/')[-2]
sample_name = path.split('/')[-1].split('.')[-2]
if sample_name not in samples_index:
samples_index[sample_name] = len(self.samples)
self.samples.append([])
self.samples[samples_index[sample_name]].append(path)
elif separator:
for path in paths:
components_name = path.split('.')[-2].split(separator)[-1]
sample_name = path.split('/')[-1].split(separator)[0]
if sample_name not in samples_index:
samples_index[sample_name] = len(self.samples)
self.samples.append([])
self.samples[samples_index[sample_name]].append(path)
else:
single_folder=True
file_root=path[:path.rfind("/")]
for path in paths:
if not path.startswith(file_root):
single_folder=False
break
if single_folder:
for path in paths:
sample_name = path.split('/')[-1].split('.')[-2]
if sample_name not in samples_index:
samples_index[sample_name] = len(self.samples)
self.samples.append([])
self.samples[samples_index[sample_name]].append(path)
else:
for path in paths:
components_name = path.split('/')[-1].split('.')[-2]
sample_name = path.split('/')[-2]
if sample_name not in samples_index:
samples_index[sample_name] = len(self.samples)
self.samples.append([])
self.samples[samples_index[sample_name]].append(path)

def to_tensors(self, path:str):
if path.endswith('.jpg') or path.endswith('.jpeg') or path.endswith('.png'):
img=Image.open(path)
if img.getpalette():
return [torch.from_numpy(np.array(img, dtype=np.uint8))]
else:
trans=torchvision.transforms.ToTensor()
return [trans(img)]

if path.endswith('.mp3') or path.endswith('.wav'):
waveform, sample_rate = torchaudio.load(path)
return [waveform]

if path.endswith('.npy') or path.endswith('.npz'):
arrays = np.load(path)
if type(arrays) == dict:
tensors = []
for array in arrays:
tensors.append(torch.from_numpy(arrays[array]))
return tensors
else:
return [torch.from_numpy(arrays)]

def __len__(self):
return len(self.samples)

def __getitem__(self, id):
"""
Returns:
A tuple of tensors.
"""

if id < 0 or id >= len(self):
raise IndexError

components = []
for component in self.samples[id]:
if type(component) is str:
components.extend(self.to_tensors(component))
else:
components.extend([torch.tensor(component)])

if self.transforms:
if type(self.transforms) is not list and type(self.transforms) is not tuple:
self.transforms = [self.transforms]
for i, transform in enumerate(self.transforms):
if i < len(components) and transform is not None:
components[i] = transform(components[i])

return tuple(components)
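As a quick illustration of the new loader, here is a minimal usage sketch. The constructor arguments and the `classes` attribute come from the class above; the folder layout, file names, and chosen transform are assumptions for the example, and the import path simply mirrors the file location `torchstudio/datasets/genericloader.py`.

```python
# Minimal sketch, assuming a hypothetical image classification layout such as
# data/flowers/rose/1.jpg, data/flowers/tulip/1.jpg, ...
import torchvision.transforms as T
from torchstudio.datasets.genericloader import GenericLoader

dataset = GenericLoader(
    path='data/flowers',
    classification=True,
    separator='/',                    # class name is taken from the parent folder
    extensions='.jpg, .jpeg, .png',
    transforms=[T.Resize((64, 64))],  # applied to the first component (the image); the label is left untouched
)

print(len(dataset), dataset.classes)  # number of samples and discovered class names
image, label = dataset[0]             # __getitem__ returns a tuple of tensors
```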
2 changes: 1 addition & 1 deletion torchstudio/metrics/accuracy.py
@@ -9,7 +9,7 @@ class Accuracy(Metric):
threshold: error threshold below which predictions are considered accurate (not used in multiclass)
normalize: if set to True, normalize predictions with sigmoid or softmax before calculating the accuracy
"""
def __init__(self, threshold: float = 0.1, normalize: bool = False):
def __init__(self, threshold: float = 0.01, normalize: bool = False):
self.threshold = threshold
self.normalize = normalize
self.reset()
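The threshold semantics described in the docstring can be illustrated with a small standalone sketch; this is not the repository's actual `Metric` implementation, just the idea of counting a regression prediction as accurate when its absolute error falls below the (now smaller) default threshold.

```python
# Standalone sketch of the thresholded-accuracy idea; not TorchStudio's Metric class.
import torch

def thresholded_accuracy(preds: torch.Tensor, targets: torch.Tensor, threshold: float = 0.01) -> float:
    """Fraction of predictions whose absolute error is below `threshold`."""
    return (torch.abs(preds - targets) < threshold).float().mean().item()

print(thresholded_accuracy(torch.tensor([0.100, 0.500]),
                           torch.tensor([0.105, 0.600])))  # 0.5: only the first prediction is within 0.01
```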