Skip to content

Commit

Permalink
0.9.11
Browse files — browse the repository at this point in the history
  • Loading branch information
divideconcept authored Oct 7, 2022
1 parent bde996c commit 6815d55
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 42 deletions.
6 changes: 3 additions & 3 deletions torchstudio/datasetload.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __getitem__(self, id):

sshclient = paramiko.SSHClient()
sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5)
sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=10)
worker_socket = socket.socket()
worker_socket.bind(('localhost', 0))
freeport=worker_socket.getsockname()[1]
Expand All @@ -240,7 +240,7 @@ def __getitem__(self, id):
port=freeport

try:
worker_socket = tc.connect((address,port))
worker_socket = tc.connect((address,port),timeout=10)
num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0)
tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples))
tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id))
Expand Down Expand Up @@ -273,7 +273,7 @@ def __getitem__(self, id):

if sshaddress and sshport and username:
try:
forward_tunnel.stop()
del forward_tunnel
except:
pass
try:
Expand Down
38 changes: 31 additions & 7 deletions torchstudio/modeltrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def deepcopy_cpu(value):
device = torch.device(device_id)
pin_memory = True if 'cuda' in device_id else False

if msg_type == 'SetMode':
print("Setting mode...\n", file=sys.stderr)
mode=tc.decode_strings(msg_data)[0]

if msg_type == 'SetTorchScriptModel' and modules_valid:
if msg_data:
print("Setting torchscript model...\n", file=sys.stderr)
Expand Down Expand Up @@ -224,9 +228,19 @@ def deepcopy_cpu(value):
valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

if msg_type == 'StartTraining' and modules_valid:
scaler = None
if 'cuda' in device_id:
#https://pytorch.org/docs/stable/notes/cuda.html
torch.backends.cuda.matmul.allow_tf32 = True if mode=='TF32' else False
torch.backends.cudnn.allow_tf32 = True
if mode=='FP16':
scaler = torch.cuda.amp.GradScaler()
if mode=='BF16':
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" #https://discuss.pytorch.org/t/bfloat16-has-worse-performance-than-float16-for-conv2d/154373
print("Training... epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr)

if msg_type == 'TrainOneEpoch' and modules_valid:

#training
model.train()
train_loss = 0
Expand All @@ -237,13 +251,23 @@ def deepcopy_cpu(value):
inputs = [tensors[i].to(device) for i in input_tensors_id]
targets = [tensors[i].to(device) for i in output_tensors_id]
optimizer.zero_grad()
outputs = model(*inputs)
outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
loss = 0
for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
loss = loss + criterion(output, target)
loss.backward()
optimizer.step()

with torch.autocast(device_type='cuda' if 'cuda' in device_id else 'cpu', dtype=torch.bfloat16 if mode=='BF16' else torch.float16, enabled=True if '16' in mode else False):
outputs = model(*inputs)
outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
loss = 0
for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
loss = loss + criterion(output, target)

if scaler:
# Accumulates scaled gradients.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()

train_loss += loss.item() * inputs[0].size(0)

with torch.set_grad_enabled(False):
Expand Down
46 changes: 27 additions & 19 deletions torchstudio/pythoncheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
args, unknown = parser.parse_known_args()

#check python version first
python_version=(sys.version_info.major,sys.version_info.minor)
min_python_version=(3,7) if args.remote else (3,8) #3.7 required for ordered dicts and stdout/stderr utf8 encoding, 3.8 required for python parsing
python_version=(sys.version_info.major,sys.version_info.minor,sys.version_info.micro)
min_python_version=(3,7,0) if args.remote else (3,8,0) #3.7 required for ordered dicts and stdout/stderr utf8 encoding, 3.8 required for python parsing
if python_version<min_python_version:
print("Error: Python "+str(min_python_version[0])+"."+str(min_python_version[1])+" minimum is required.", file=sys.stderr)
print("This environment has Python "+str(python_version[0])+"."+str(python_version[1])+".", file=sys.stderr)
print("Error: Python "+'.'.join((str(i) for i in min_python_version))+" minimum is required.", file=sys.stderr)
print("This environment has Python "+'.'.join((str(i) for i in python_version))+".", file=sys.stderr)
exit(1)

print("Checking required packages...\n", file=sys.stderr)
Expand All @@ -36,15 +36,15 @@
if module is None:
missing_modules.append(module_check)
elif module_check=='torch':
if python_version<(3,8):
if python_version<(3,8,0):
from importlib_metadata import version
else:
from importlib.metadata import version
pytorch_version=tuple(int(i) for i in version('torch').split('.')[:2])
min_pytorch_version=(1,9) #1.9 required for torch.package support, 1.10 preferred for stable torch.fx and profile-directed typing in torchscript
pytorch_version=tuple(int(i) if i.isdigit() else 0 for i in version('torch').split('.')[:3])
min_pytorch_version=(1,9,0) #1.9 required for torch.package support, 1.10 preferred for stable torch.fx and profile-directed typing in torchscript
if pytorch_version<min_pytorch_version:
print("Error: PyTorch "+str(min_pytorch_version[0])+"."+str(min_pytorch_version[1])+" minimum is required.", file=sys.stderr)
print("This environment has PyTorch "+str(pytorch_version[0])+"."+str(pytorch_version[1])+".", file=sys.stderr)
print("Error: PyTorch "+'.'.join((str(i) for i in min_pytorch_version))+" minimum is required.", file=sys.stderr)
print("This environment has PyTorch "+'.'.join((str(i) for i in pytorch_version))+".", file=sys.stderr)
exit(1)

if len(missing_modules)>0:
Expand Down Expand Up @@ -72,24 +72,32 @@

#finally, list available devices
print("Loading PyTorch...\n", file=sys.stderr)

import torch

print("Listing devices...\n", file=sys.stderr)

devices = {}
devices['cpu'] = {'name': 'CPU', 'pin_memory': False}
devices['cpu'] = {'name': 'CPU', 'pin_memory': False, 'modes': ['FP32']}
for i in range(torch.cuda.device_count()):
devices['cuda:'+str(i)] = {'name': torch.cuda.get_device_name(i), 'pin_memory': True}
if pytorch_version>=(1,12):
modes = ['FP32']
#same as torch.cuda.is_bf16_supported() but compatible with PyTorch<1.10, and not limited to current cuda device only
cu_vers = torch.version.cuda
if cu_vers is not None:
cuda_maj_decide = int(cu_vers.split('.')[0]) >= 11
else:
cuda_maj_decide = False
compute_capability=torch.cuda.get_device_properties(torch.cuda.device(i)).major #https://developer.nvidia.com/cuda-gpus
if compute_capability>=8 and cuda_maj_decide: #RTX 3000 and higher
modes+=['TF32','FP16','BF16']
if compute_capability==7: #RTX 2000
modes+=['FP16']
devices['cuda:'+str(i)] = {'name': torch.cuda.get_device_name(i), 'pin_memory': True, 'modes': modes}
if pytorch_version>=(1,12,0):
if torch.backends.mps.is_available():
devices['mps'] = {'name': 'Metal Acceleration', 'pin_memory': False}
devices['mps'] = {'name': 'Metal', 'pin_memory': False, 'modes': ['FP32']}
#other possible devices:
#'hpu' (https://docs.habana.ai/en/latest/PyTorch_User_Guide/PyTorch_User_Guide.html)
#'dml' (https://docs.microsoft.com/en-us/windows/ai/directml/gpu-pytorch-windows)
devices_string_list=[]
for id in devices:
devices_string_list.append(devices[id]['name']+" ("+id+")")
print(("Online and functional " if args.remote else "Functional ")+"("+platform.platform()+", Python "+str(python_version[0])+"."+str(python_version[1])+", PyTorch "+str(pytorch_version[0])+"."+str(pytorch_version[1])+", Devices: "+", ".join(devices_string_list)+")");


devices_string_list.append(id+' "'+devices[id]['name']+'" ('+'/'.join(devices[id]['modes'])+')')
print("Ready ("+platform.platform()+", Python "+'.'.join((str(i) for i in python_version))+", PyTorch "+'.'.join((str(i) for i in pytorch_version))+", Devices: "+", ".join(devices_string_list)+")");
18 changes: 11 additions & 7 deletions torchstudio/sshtunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import io

# Port forwarding from https://github.com/skyleronken/sshrat/blob/master/tunnels.py
# improved with dynamic local port allocation feedback for reverse tunnel with a null local port
# improved with:
# dynamic local port allocation feedback for reverse tunnel with a null local port
#    blocking connections to avoid connection loss with poor cloud servers
# more explicit error messages
import threading
import socket
import selectors
Expand Down Expand Up @@ -322,14 +325,15 @@ def finish(self):
else:
print("Executing remote command...", file=sys.stderr)
stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+' '.join([args.command]+other_args))
while True:
time.sleep(.1)

while not stdout.channel.exit_status_ready():
time.sleep(.01) #lower CPU usage
if stdout.channel.recv_stderr_ready():
sys.stderr.write(str(stdout.channel.recv_stderr(1024).replace(b'\r\n',b'\n'),'utf-8'))
sys.stderr.buffer.write(stdout.channel.recv_stderr(1024).replace(b'\r\n',b'\n'))
time.sleep(.01) #for stdout/stderr sync
if stdout.channel.recv_ready():
sys.stdout.write(str(stdout.channel.recv(1024).replace(b'\r\n',b'\n'),'utf-8'))
if stdout.channel.exit_status_ready():
break
sys.stdout.buffer.write(stdout.channel.recv(1024).replace(b'\r\n',b'\n'))
time.sleep(.01) #for stdout/stderr sync
else:
if args.script:
print("Error: no python environment set.", file=sys.stderr)
Expand Down
26 changes: 20 additions & 6 deletions torchstudio/tcpcodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,55 @@ def start_server(server):
conn, addr = server.accept()
return conn

def connect(server_address=None, timeout=0):
    """Open a TCP connection to the controlling server and return the socket.

    Args:
        server_address: an (address, port) pair; port may be an int or a
            numeric string. When None, the address, port and timeout are read
            from the command line (--address/--port/--timeout), falling back to
            sys.argv presence checks.
        timeout: max seconds without incoming data before socket operations
            raise; 0 (default) leaves the socket blocking.

    Returns:
        A connected socket.socket.

    Exits the process (rather than raising) when no address is available or
    the connection fails, since callers are worker scripts with no recovery path.
    """
    if server_address is None and len(sys.argv) < 3:
        print("Missing socket address and port", file=sys.stderr)
        exit()

    if not server_address:
        # No explicit address: pull connection parameters from the command line.
        import argparse
        parser = argparse.ArgumentParser()
        parser.add_argument("--address", help="server address", type=str, default='localhost')
        parser.add_argument("--port", help="local port to which the script must connect", type=int, default=0)
        parser.add_argument("--timeout", help="max number of seconds without incoming messages before quitting", type=int, default=0)
        args, unknown = parser.parse_known_args()
        server_address = (args.address, args.port)
        timeout = args.timeout
    else:
        # Port may arrive as a string (e.g. parsed from text); normalize to int.
        server_address = (server_address[0], int(server_address[1]))

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.connect(server_address)
    except socket.error as serr:
        print("Connection error: %s" % str(serr), file=sys.stderr)
        exit()

    if timeout > 0:
        sock.settimeout(timeout)
    return sock

def send_msg(sock, type, data = bytearray()):
    """Send one framed message on sock.

    Frame layout (little-endian): 1-byte length of the type string,
    the UTF-8 type string itself, a 4-byte payload length, then the payload.

    Args:
        sock: a connected socket.
        type: short message-type string (must encode to <=255 UTF-8 bytes).
        data: optional payload bytes (never mutated, so the shared default is safe).

    Exits the process on a send failure, since callers are worker scripts
    with no recovery path once the control connection is gone.
    """
    type_bytes = bytes(type, 'utf-8')
    type_size = len(type_bytes)
    msg = struct.pack(f'<B{type_size}sI', type_size, type_bytes, len(data)) + data
    try:
        sock.sendall(msg)
    except OSError:
        # Narrowed from a bare except: a bare clause would also swallow
        # KeyboardInterrupt/SystemExit and misreport Ctrl-C as a lost connection.
        print("Lost connection", file=sys.stderr)
        exit()

def recv_msg(sock):
def recvall(sock, n):
data = bytearray()
while len(data) < n:
packet = sock.recv(n - len(data))
try:
packet = sock.recv(n - len(data))
except:
print("Lost connection", file=sys.stderr)
exit()
if len(packet)==0:
print("Lost connection", file=sys.stderr)
exit()
data.extend(packet)
return data
type_size = struct.unpack('<B', recvall(sock, 1))[0]
Expand Down

0 comments on commit 6815d55

Please sign in to comment.