diff --git a/torchstudio/datasets/randomgenerator.py b/torchstudio/datasets/randomgenerator.py index fbb7b9b..9d4845e 100644 --- a/torchstudio/datasets/randomgenerator.py +++ b/torchstudio/datasets/randomgenerator.py @@ -10,7 +10,7 @@ class RandomGenerator(Dataset): Size of the dataset (number of samples) tensors: A list of tuples defining tensor properties: shape, type, range - All properties are optionals. Defaults are null, float, [0,1] + All properties are optionals. Defaults are null, torch.float, [0,1] """ def __init__(self, size:int=256, tensors=[(3,64,64), (int,[0,9])]): @@ -29,12 +29,12 @@ def __getitem__(self, idx): sample = [] for properties in self.tensors: shape=[] - dtype=float + dtype=torch.float drange=[0,1] for property in properties: if type(property)==int: shape.append(property) - elif inspect.isclass(property): + elif type(property)==type or type(property)==torch.dtype: dtype=property elif type(property) is list: drange=property diff --git a/torchstudio/modeltrain.py b/torchstudio/modeltrain.py index 80ec6ae..3755d24 100644 --- a/torchstudio/modeltrain.py +++ b/torchstudio/modeltrain.py @@ -237,6 +237,11 @@ def deepcopy_cpu(value): scaler = torch.cuda.amp.GradScaler() if mode=='BF16': os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" #https://discuss.pytorch.org/t/bfloat16-has-worse-performance-than-float16-for-conv2d/154373 + train_type=None + if mode=='FP16': + train_type=torch.float16 + if mode=='BF16': + train_type=torch.bfloat16 print("Training... epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr) if msg_type == 'TrainOneEpoch' and modules_valid: @@ -252,7 +257,7 @@ def deepcopy_cpu(value): targets = [tensors[i].to(device) for i in output_tensors_id] optimizer.zero_grad() - with torch.autocast(device_type='cuda' if 'cuda' in device_id else 'cpu', dtype=torch.bfloat16 if mode=='BF16' else torch.float16, enabled=True if '16' in mode else False): + with torch.autocast(device_type='cuda' if 'cuda' in device_id else 'cpu', dtype=train_type, enabled=True if train_type else False): outputs = model(*inputs) outputs = outputs if type(outputs) is not torch.Tensor else [outputs] loss = 0