Overhaul of ResNet API #174
Some perks of the new API:
0.7.2:
julia> model = ResNet(50);
julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
BenchmarkTools.Trial: 1 sample with 1 evaluation.
Single result which took 6.698 s (87.06% GC) to evaluate,
with a memory estimate of 2.46 GiB, over 47810 allocations.
julia> model = ResNet(18);
julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
BenchmarkTools.Trial: 2 samples with 1 evaluation.
Range (min … max): 2.576 s … 2.580 s ┊ GC (min … max): 87.60% … 87.65%
Time (median): 2.578 s ┊ GC (median): 87.63%
Time (mean ± σ): 2.578 s ± 2.770 ms ┊ GC (mean ± σ): 87.63% ± 0.03%
█ █
█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
2.58 s Histogram: frequency by time 2.58 s <
Memory estimate: 1.01 GiB, allocs estimate: 19594.
This PR:
julia> model = ResNet(50);
julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
BenchmarkTools.Trial: 1 sample with 1 evaluation.
Single result which took 5.644 s (85.62% GC) to evaluate,
with a memory estimate of 2.50 GiB, over 45095 allocations.
julia> model = ResNet(18);
julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
BenchmarkTools.Trial: 13 samples with 1 evaluation.
Range (min … max): 338.901 ms … 612.421 ms ┊ GC (min … max): 4.01% … 46.50%
Time (median): 345.959 ms ┊ GC (median): 5.21%
Time (mean ± σ): 416.913 ms ± 90.533 ms ┊ GC (mean ± σ): 21.52% ± 16.19%
█▄ ▁
██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▆▆▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
339 ms Histogram: frequency by time 612 ms <
Memory estimate: 1.03 GiB, allocs estimate: 17275.
Julia version info:
julia> versioninfo()
Julia Version 1.9.0-DEV.840
Commit 68d62ab3d3 (2022-06-22 21:39 UTC)
Platform Info:
OS: macOS (arm64-apple-darwin21.5.0)
CPU: 8 × Apple M1
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.5 (ORCJIT, apple-m1)
Threads: 4 on 4 virtual cores
Environment:
JULIA_NUM_THREADS = 4
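To put the numbers above side by side, here is a small sketch that computes the speedup from the reported times. The names `times` and `speedup` are made up for illustration, and the ResNet-50 figures come from single-sample runs, so that ratio is only indicative.

```julia
# Median (ResNet-18) or single-sample (ResNet-50) gradient times reported
# in the benchmarks above, in seconds.
times = (
    resnet18 = (v072 = 2.578, pr = 0.345959),
    resnet50 = (v072 = 6.698, pr = 5.644),   # one sample each, treat with caution
)

# Ratio of the 0.7.2 time to the time measured on this PR.
speedup(t) = t.v072 / t.pr

for (name, t) in pairs(times)
    println(name, ": ", round(speedup(t); digits = 2), "x faster")
end
```

On these figures ResNet-18 comes out roughly 7.5x faster and ResNet-50 roughly 1.2x faster, with most of the ResNet-18 gain matching the drop in GC share from about 88% to about 5%.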
How is it that Zygote seems to be getting worse with passing Julia versions, though? I could've sworn it wasn't this bad a couple of weeks ago, and today it seems to be struggling to even calculate a ResNet-50 gradient?
I'm confused how the new API contributes to better gradient times (at least for ResNet where there are no new layers added, right)? A couple high level comments as you work on this:
Well to be completely honest, I'm not sure, but I have some theories that mostly revolve around how the nested
This makes sense. We could make these
👍🏽
This...might take time. The major issue is in terms of getting the model structures to overlap.
I was thinking even more declarative. Just have a function called I find these kinds of declarative interfaces are more flexible and easier to keep track of mentally. But they usually take more typing. Possibly we can merge your idea with this and allow a named tuple or
Yeah, I think the
No hurry on this. Also, the script linked in the HuggingFace model cards doesn't depend on structure. It turns the Flux model into a state dict-like dictionary then just iterates the keys together with the PyTorch state dict. It might just work for your model since the
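As a rough illustration of the "state dict-like dictionary" idea, the sketch below flattens a nested structure of parameters into a dictionary keyed by dotted paths. It is not the script from the model cards: `flatten_state` and the `toy` structure are hypothetical, and a plain NamedTuple stands in for a Flux model so the example runs without any packages.

```julia
# Hypothetical helper: flatten a nested NamedTuple/Tuple of parameters into a
# Dict keyed by dotted paths, similar in spirit to a PyTorch state dict.
function flatten_state(x, prefix = "", out = Dict{String, Any}())
    if x isa Union{NamedTuple, Tuple}
        for (k, v) in pairs(x)
            key = isempty(prefix) ? string(k) : string(prefix, ".", k)
            flatten_state(v, key, out)
        end
    else
        # Leaf (e.g. an array of weights): store it under its full path.
        out[prefix] = x
    end
    return out
end

# A toy stand-in for model parameters; not an actual Flux model.
toy = (stem = (weight = zeros(2, 2), bias = zeros(2)),
       layers = ((weight = ones(3, 2),), (weight = ones(1, 3),)))

state = flatten_state(toy)
foreach(println, sort(collect(keys(state))))
```

Once both sides are flat, the two sets of keys can be walked together without either side needing to know how the other nests its layers.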
I was trying to come up with a more declarative API, but one of the problems that we might face is in terms of documentation. Since these blocks have a lot of arguments, directing end-users to refer to the documentation for these blocks might cause some confusion. I'm a little uncertain if that's desirable. Maybe we keep the declarative API but document the
This also causes quite a bit of argument hiding (i.e. builder functions aren't explicitly accepting the arguments to be passed to the lower level ones but instead a
Don't they just accept something like
That's okay
I could, but this has the same problem - the function doesn't clearly "see" the
This is a natural conflict between designing something to be flexible vs. safe. In general, Julia code tries to be more permissive, especially at the lower level API. This is what makes it possible to smash together two totally separate packages and get a useful result without too much hacking. This approach definitely requires more care, and I find the best way to work through this is to just try and be permissive until you hit a roadblock. Usually that experience is most informative about the design space. Let me try and walk through some of that process below.
They are regulated, just not by
Either way, the user gets the same error. A similar outcome will happen for invalid keywords (with a slightly more informative error too) or for positional arguments. I would argue that getting the error for (2) is more informative because it signals that it is specifically the
Maybe there is a specific kind of error that you are expecting that isn't covered well here? We should discuss that case in more detail then. Also, remember that this is a fairly low-level portion of the API. There is an expectation that the user can read Julia errors here (i.e. not the same level as
Documenting the interface is a related but different concern. The interface should appear intuitive by itself. I would argue that many specified but restricted keywords is not intuitive. It requires reading the docstring to understand the behavior and how each one is used / when each is ignored. On the other hand, saying "arguments passed to
Here's an attempt at the docstring. Let me know what you think (and feel free to push back!). Of course, this would also require similar changes to
"""
resnet(block, layers, stem = somedefault(); nclasses = 1000, inchannels = 3, output_stride = 32,
reduce_first = 1, activation = relu,
norm_layer = BatchNorm, drop_rate = 0.0,
block_kwargs...)
Creates the layers of a ResNe(X)t model. If you are an end-user, you should probably use
[ResNet](@ref) instead and pass in the parameters you want to modify as optional parameters
there.
# Arguments:
- `block` / `block_kwargs`: The residual block to use in the model and the keyword arguments for it. See [basicblock](@ref) and [bottleneck](@ref) for
examples. This is called like `block(inplanes, outplanes; stride, block_kwargs...)`.
- `layers`: A list of integers representing the number of blocks in each stage.
- `stem`: The initial stage that operates on the input before the residual blocks. This can be any model that accepts the input and is compatible with the blocks stage. Defaults to [`somedefault`](#).
- `nclasses`: The number of output classes. The default value is 1000.
- `inchannels`: The number of input channels to the model. The default value is 3.
- `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32.
- `reduce_first`: Reduction factor for the first convolution output width of residual blocks.
The default value is 1 for all architectures except SE-Nets, where it is 2.
- `activation`: The activation function to use. The default value is `relu`.
- `norm_layer`: The normalization layer to use. The default value is `BatchNorm`.
- `drop_rate`: The rate to use for `Dropout` before the fully-connected classifier stage. The default value is 0.0.
If you are an end-user trying to tweak the ResNet model, note that there is no guarantee that
all combinations of parameters will work. In particular, tweaking `block_kwargs` is not
advised unless you know what you are doing.
"""
I think the line: "This is called like
Thank you for that writeup, it does clear some stuff up! I might need to do some homework before I get back with a response, but the two blog posts in particular might be good starting points in terms of understanding programming patterns in Julia a little better. I think most of my worry revolves around making
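For readers following the keyword-forwarding discussion, here is a minimal sketch of the call convention from the proposed docstring, `block(inplanes, outplanes; stride, block_kwargs...)`. Everything prefixed with `toy_` is hypothetical, and the blocks return NamedTuples instead of real layers so the example runs without Flux. It only illustrates that an invalid keyword is rejected by the block itself rather than by the builder.

```julia
# Toy stand-ins for residual block builders. They return NamedTuples describing
# the block rather than real layers.
toy_basicblock(inplanes, outplanes; stride = 1) =
    (kind = :basic, inplanes = inplanes, outplanes = outplanes, stride = stride)

toy_bottleneck(inplanes, outplanes; stride = 1, cardinality = 1, base_width = 64) =
    (kind = :bottleneck, inplanes = inplanes, outplanes = outplanes, stride = stride,
     cardinality = cardinality, base_width = base_width)

# The builder does not list the block's keywords; it forwards whatever it gets.
function toy_resnet(block, layers; inchannels = 3, block_kwargs...)
    blocks = []
    inplanes = inchannels
    for (stage, nblocks) in enumerate(layers), i in 1:nblocks
        outplanes = 64 * 2^(stage - 1)
        # Downsample at the first block of every stage after the first.
        stride = (i == 1 && stage > 1) ? 2 : 1
        push!(blocks, block(inplanes, outplanes; stride, block_kwargs...))
        inplanes = outplanes
    end
    return blocks
end

# Valid: `cardinality` is a keyword of the bottleneck block.
ok = toy_resnet(toy_bottleneck, [1, 1]; cardinality = 32)

# Invalid: the basic block has no `cardinality` keyword, so the call made
# inside the builder throws a MethodError naming the unsupported keyword.
err = try
    toy_resnet(toy_basicblock, [1, 1]; cardinality = 32)
catch e
    e
end
println(typeof(err))
```

The trade-off is the one debated above: the builder's signature stays short and works with any block, but the set of accepted keywords is only discoverable from the block's own documentation.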
1. Some docs 2. Basic tests for ResNet and ResNeXt now pass
Okay, I've just pushed what I think is a more declarative interface (and it does look cleaner from the user's POV). This mostly revolves around exposing two arguments at the I'm planning to rigorously document the choices of
The docs for this are missing because I wanna make sure that this interface is something that can be agreed upon before I proceed to write it up 😅 Any feedback is welcome!
1. Less keywords for the user to worry about 2. Delete `ResNeXt` just for now
`downsample_args` is actually redundant
Oh no. Did I manage to kill CI altogether somehow?
The design looks good, mostly minor changes here and there. I've been holding off on doing a full pass through all the other non-ResNet code, so I just did that and most of my comments are in those sections.
src/convnets/resnets/core.jl
Outdated
# inplanes increases by expansion after each block
inplanes = planes * expansion
# inplanes increases by expansion after each block
inplanes = planes * expansion
We need this, though. This is calculating the change in `inplanes` across blocks
Maybe I am missing something but I don't see where the output of this calculation goes? It seems unused...unless it is modifying a global which is very bad.
We were before, unfortunately. I've pushed a change. This makes `resnet_planes` return a vector instead of being a `stage_idx` based callback - the reason we need this is because `inplanes` needs the `planes` from the previous block, not the current one, so we need to have access to that information
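The dependency described here can be sketched as follows. This is not the PR's `resnet_planes`; `toy_planes`, `toy_inplanes`, and the `base` and `stem_planes` defaults are assumptions chosen for illustration. The point is only that, once the planes are available as a vector, each block's `inplanes` can be read from the previous entry instead of being tracked through mutated state.

```julia
# Hypothetical helper: output planes for every block as a flat vector,
# doubling the base width at each stage.
function toy_planes(block_repeats; base = 64)
    return [base * 2^(stage - 1)
            for (stage, n) in enumerate(block_repeats) for _ in 1:n]
end

# `inplanes` of block i is the expanded output of block i - 1; the first
# block takes the stem's output width instead.
function toy_inplanes(planes, expansion; stem_planes = 64)
    return [i == 1 ? stem_planes : planes[i - 1] * expansion
            for i in eachindex(planes)]
end

planes = toy_planes([2, 2])          # [64, 64, 128, 128]
inplanes = toy_inplanes(planes, 4)   # [64, 256, 256, 512]
println(collect(zip(inplanes, planes)))
```

With a per-stage callback, the previous block's width is not visible at the point where the next block is built, which is why a vector (or some other form of lookback) is needed.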
I've incorporated some of the docs changes, and left out the others - these will need a thorough once-over anyways, and I want to try and get those in at the same time as the devdocs and the Documenter.jl port
Co-Authored-By: Kyle Daruwalla <[email protected]>
d1d193a to 07c5c64
Also misc. formatting and cleanup
I've also added Wide ResNet now (easy enough). But the CI is weird. I think my filtering should work but the ResNet testset isn't executing at all
Bump?
Looks done to me. I just caught a couple last doc fixes and tests.
    return Chain(stages...)
end

function resnet(img_dims, stem, get_layers, block_repeats::Vector{<:Integer}, connection,
Docstring for each `resnet`?
function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu;
                                 norm_layer = BatchNorm, revnorm = false,
                                 use_norm = (true, true), stride = 1, kwargs...)
Think this needs a docstring update
04e46c0 to 72cd4a9
Great job @theabhirath! This is a HUGE improvement, so I appreciate all the time you put into it. I'm gonna let tests run to completion.
Yeah I think this is the longest PR in terms of review comments on this repo, but it thoroughly deserved the discussion 😄 Happy to see this one through
I've also now made PRs to the HuggingFace repositories for the models. Once they're accepted, I'll push the updated pretrained weights links and SHAs as well. It would be good to have all the tests enabled and all the tasks ticked off 😄
On second thoughts, might not want this to block the PR....I want to try and use the updated torchvision weights with higher accuracies - there's been some API changes so this may take a little more time |
This PR completely re-writes the current ResNet API to make it more powerful, more extensible and to reduce code duplication.
Why this PR?
While making ResNet more fully-featured, this PR will also:
Things to do
- `DropBlock` and its behaviour
- `DropPath` behaviour in detail (permissible values, calculations)
- `ResNet` interface vs lower level `resnet` interface
- `resnet` API
Other PRs to land before this one
- `Chain`s directly without broadcasting (Define activation functions taking arrays as input NNlib.jl#423).
- `rand_like` and `randn_like` in MLUtils (`rand_like` and `randn_like` JuliaML/MLUtils.jl#101)
Miscellaneous fixes
- `densenet` for `nblocks` to avoid hitting integer edge cases