From 8819abe252579b610cb48c5ad83d51b45a90ddba Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 7 Aug 2024 06:34:51 -0400 Subject: [PATCH] Add `insertdims` (#830) --- README.md | 3 +++ src/Compat.jl | 31 +++++++++++++++++++++++++++++++ test/runtests.jl | 31 +++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/README.md b/README.md index a3ecd95a5..39eb6dd10 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,8 @@ changes in `julia`. ## Supported features +* `insertdims(D; dims)` is the opposite of `dropdims` ([#45793]) (since Compat 4.16.0) + * `Compat.Fix{N}` which fixes an argument at the `N`th position ([#54653]) (since Compat 4.16.0) * `chopprefix(s, prefix)` and `chopsuffix(s, suffix)` ([#40995]) (since Compat 4.15.0) @@ -190,6 +192,7 @@ Note that you should specify the correct minimum version for `Compat` in the [#43852]: https://github.com/JuliaLang/julia/issues/43852 [#45052]: https://github.com/JuliaLang/julia/issues/45052 [#45607]: https://github.com/JuliaLang/julia/issues/45607 +[#45793]: https://github.com/JuliaLang/julia/issues/45793 [#47354]: https://github.com/JuliaLang/julia/issues/47354 [#47679]: https://github.com/JuliaLang/julia/pull/47679 [#48038]: https://github.com/JuliaLang/julia/issues/48038 diff --git a/src/Compat.jl b/src/Compat.jl index 77d96a6eb..1f5260121 100644 --- a/src/Compat.jl +++ b/src/Compat.jl @@ -1122,6 +1122,37 @@ if VERSION < v"1.8.0-DEV.1016" export chopprefix, chopsuffix end +if VERSION < v"1.12.0-DEV.974" # contrib/commit-name.sh 2635dea + + insertdims(A; dims) = _insertdims(A, dims) + + function _insertdims(A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {T, N, M} + for i in eachindex(dims) + 1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1")) + dims[i] ≤ N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added")) + for j = 1:i-1 + dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique")) + end + end + + # acc is a tuple, where the first entry is the final shape + # the second entry off acc is a counter for the axes of A + inds = Base._foldoneto((acc, i) -> + i ∈ dims + ? ((acc[1]..., Base.OneTo(1)), acc[2]) + : ((acc[1]..., axes(A, acc[2])), acc[2] + 1), + ((), 1), Val(N+M)) + new_shape = inds[1] + return reshape(A, new_shape) + end + + _insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),)) + + export insertdims +else + using Base: insertdims, _insertdims +end + # https://github.com/JuliaLang/julia/pull/54653: add Fix @static if !isdefined(Base, :Fix) # VERSION < v"1.12.0-DEV.981" @static if !isdefined(Base, :_stable_typeof) diff --git a/test/runtests.jl b/test/runtests.jl index 4d2982148..e526f0128 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -908,6 +908,37 @@ end end end +# https://github.com/JuliaLang/julia/pull/45793 +@testset "insertdims" begin + a = rand(8, 7) + @test @inferred(insertdims(a, dims=1)) == @inferred(insertdims(a, dims=(1,))) == reshape(a, (1, 8, 7)) + @test @inferred(insertdims(a, dims=3)) == @inferred(insertdims(a, dims=(3,))) == reshape(a, (8, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 1, 7)) + @test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 1, 1, 8, 7)) + @test @inferred(insertdims(a, dims=(1, 4))) == reshape(a, (1, 8, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 3, 5))) == reshape(a, (1, 8, 1, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 2, 4, 6))) == reshape(a, (1, 1, 8, 1, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 3, 4, 6))) == reshape(a, (1, 8, 1, 1, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 4, 6, 3))) == reshape(a, (1, 8, 1, 1, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 3, 5, 6))) == reshape(a, (1, 8, 1, 7, 1, 1)) + + @test_throws ArgumentError insertdims(a, dims=(1, 1, 2, 3)) + @test_throws ArgumentError insertdims(a, dims=(1, 2, 2, 3)) + @test_throws ArgumentError insertdims(a, dims=(1, 2, 3, 3)) + @test_throws UndefKeywordError insertdims(a) + @test_throws ArgumentError insertdims(a, dims=0) + @test_throws ArgumentError insertdims(a, dims=(1, 2, 1)) + @test_throws ArgumentError insertdims(a, dims=4) + @test_throws ArgumentError insertdims(a, dims=6) + + # insertdims and dropdims are inverses + b = rand(1,1,1,5,1,1,7) + for dims in [1, (1,), 2, (2,), 3, (3,), (1,3), (1,2,3), (1,2), (1,3,5), (1,2,5,6), (1,3,5,6), (1,3,5,6), (1,6,5,3)] + @test dropdims(insertdims(a; dims); dims) == a + @test insertdims(dropdims(b; dims); dims) == b + end +end + # https://github.com/JuliaLang/julia/pull/54653: add Fix @testset "Fix" begin function test_fix1(Fix1=Compat.Fix1)