Sensible extensions for exposing torch in Julia.
This package aims to provide the Tensor
type, which offloads all computations to PyTorch.
Note:
- Needs a machine with a CUDA GPU (CUDA 10.1 or above).
- Lazy artifacts will be needed for the package to function without a GPU.
To add the package, from the Julia REPL, enter the Pkg prompt by typing ]
and execute the following:
pkg> add Torch
Or use Julia's package manager, Pkg, directly:
julia> using Pkg; Pkg.add("Torch");
using Metalhead, Metalhead.Flux, Torch
using Torch: torch
resnet = ResNet()
We can move our object over to Torch via a simple call to torch:
tresnet = resnet.layers |> torch
Or, if we need more control over which device is used, we can create the tensor explicitly:
ip = rand(Float32, 224, 224, 3, 1) # An RGB Image
tip = tensor(ip, dev = 0) # 0 => GPU:0 in Torch
cpu_tensor = tensor(ip, dev = -1) # -1 => CPU:0
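The input above is a four-dimensional array. As a minimal sketch in base Julia only (no Torch or GPU required), the following illustrates the (width, height, channels, batch) layout that Flux convolution layers expect. The second image, other, is a hypothetical addition for illustration and is not part of the original example.

```julia
# Plain Julia only: nothing here touches Torch or the GPU.
# Flux convolution layers expect image batches laid out as
# (width, height, channels, batch).
ip = rand(Float32, 224, 224, 3, 1)   # one 224x224 RGB image
width, height, channels, batchsize = size(ip)

# Hypothetical second image, concatenated along the batch dimension (dim 4).
other = rand(Float32, 224, 224, 3, 1)
batch = cat(ip, other; dims = 4)

println(size(batch))   # (224, 224, 3, 2)
```

An array built this way can then be handed to tensor in the same way as ip.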
Calling into the model is done via the usual Flux mechanism.
tresnet(tip);
We can take gradients using Zygote as well:
gs = gradient(x -> sum(tresnet(x)), tip);
# Or
ps = Flux.params(tresnet);
gs = gradient(ps) do
  sum(tresnet(tip))
end
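The two gradient calls return different kinds of results. The sketch below shows how each might be read back, assuming Zygote is installed. To keep it runnable on a CPU without Torch, a small weight matrix W stands in for tresnet; W, x, and model are hypothetical names used only for this illustration.

```julia
using Zygote  # assumption: Zygote is available in the environment

# Stand-in for the model: a small weight matrix instead of `tresnet`.
W = Float32[1 2 3; 4 5 6]
x = Float32[1, 1, 1]
model(v) = W * v

# Explicit form: gradient with respect to the input.
# The result is a tuple with one entry per argument.
gs_input = gradient(v -> sum(model(v)), x)
dx = gs_input[1]          # same shape as x

# Implicit form: gradient with respect to tracked parameters.
# The result is indexed by the parameter array itself.
ps = Zygote.Params([W])
gs_params = gradient(() -> sum(model(x)), ps)
dW = gs_params[W]         # same shape as W
```

Here dx holds the column sums of W, and dW has every row equal to x. With the Torch-backed model, the same indexing pattern would apply to gs from either form above.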
Please feel free to report any issues you encounter on the issue tracker. I would also appreciate contributions through PRs toward corrections, increased coverage, docs, etc. Testing currently runs on Linux, but that can be expanded as the need arises.
This package takes a lot of inspiration from existing projects of this kind, notably ocaml-torch for generating the wrappers.