I got past that by using a newer version of torch (2.5.0) and transformers (4.43.3). Using Flair version 0.13.1 or 0.14.0 gives me the following issue when training a model ...
Traceback (most recent call last):
  File "/Users/xxxxxx/Dev/flairNLP/train-models/train_model.py", line 165, in <module>
    main()
  File "/Users/xxxxxx/Dev/flairNLP/train-models/train_model.py", line 138, in main
    trainer.train(
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/flair/trainers/trainer.py", line 200, in train
    return self.train_custom(**local_variables, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/flair/trainers/trainer.py", line 600, in train_custom
    with torch.autocast(device_type=flair.device.type, enabled=use_amp):
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 230, in __init__
    dtype = torch.get_autocast_dtype(device_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: unsupported scalarType
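The traceback fails while constructing torch.autocast, before any Flair-specific work happens, so it may help to probe that call on its own. The following is a minimal diagnostic sketch, not a fix: the helper names pick_device_type and autocast_usable are hypothetical, and the guarded import is only there so the script still loads on a machine without torch. It assumes that building the context with enabled=False is enough to trigger the same check, which is what the trainer.py line in the traceback suggests.

```python
# Diagnostic sketch (assumption: helper names are illustrative, not Flair or torch API).
try:
    import torch
except ImportError:  # keep the script loadable when torch is not installed
    torch = None


def pick_device_type() -> str:
    """Return "mps" when PyTorch reports the backend as available, else "cpu"."""
    if torch is None:
        return "cpu"
    mps = getattr(torch.backends, "mps", None)  # absent on very old torch builds
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


def autocast_usable(device_type: str) -> bool:
    """Report whether torch.autocast can be constructed for this device type."""
    if torch is None:
        return False
    try:
        # Mirrors the failing call in the traceback, with autocasting disabled.
        with torch.autocast(device_type=device_type, enabled=False):
            pass
    except RuntimeError:
        return False
    return True


if __name__ == "__main__":
    device_type = pick_device_type()
    print("selected device type:", device_type)
    print("autocast usable:", autocast_usable(device_type))
```

If autocast_usable("mps") comes back False on a given install, the trainer will probably raise the same RuntimeError no matter how use_amp is set, because the context manager is constructed either way. In that case, setting flair.device = torch.device("cpu") before training is one possible workaround until a torch and Flair combination that accepts mps in autocast is installed.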
Describe the bug
When setting flair.device to mps, the following error is thrown during training:
To Reproduce
Expected behavior
Torch's mps support should be usable via flair.
Logs and Stack traces
No response
Screenshots
No response
Additional Context
No response
Environment
Versions:
Flair: 0.13.1
PyTorch: 2.3.1
Transformers: 4.42.4
GPU: False