Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Assert transform #249

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

Below is the list of transforms that are available in this package.

## Assert

```@docs
Assert
```

## Select

```@docs
Expand Down
2 changes: 1 addition & 1 deletion src/TableTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import Distributions: quantile, cdf
import TransformsBase: assertions, isrevertible, isinvertible
import TransformsBase: apply, revert, reapply, preprocess, inverse

include("assertions.jl")
include("tabletraits.jl")
include("distributions.jl")
include("tableselection.jl")
Expand All @@ -49,6 +48,7 @@ export
reapply,

# built-in
Assert,
Select,
Reject,
Satisfies,
Expand Down
33 changes: 0 additions & 33 deletions src/assertions.jl

This file was deleted.

1 change: 1 addition & 0 deletions src/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ end
# ----------------

include("transforms/utils.jl")
include("transforms/assert.jl")
include("transforms/select.jl")
include("transforms/satisfies.jl")
include("transforms/rename.jl")
Expand Down
64 changes: 64 additions & 0 deletions src/transforms/assert.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
    Assert(; cond, msg="")

Asserts all columns of the table by throwing an `AssertionError(msg)`
if `cond(column)` returns `false`, otherwise returns the input table.

The `msg` argument can be a string, or a function that receives
the column name and returns a string, e.g.: `nm -> "error in column \$nm"`.

    Assert(col₁, col₂, ..., colₙ; cond, msg="")
    Assert([col₁, col₂, ..., colₙ]; cond, msg="")
    Assert((col₁, col₂, ..., colₙ); cond, msg="")

Asserts the selected columns `col₁`, `col₂`, ..., `colₙ`.

    Assert(regex; cond, msg="")

Asserts the columns that match with `regex`.

# Examples

```julia
Assert(cond=allunique, msg="assertion error")
Assert([2, 3, 5], cond=x -> sum(x) > 100)
Assert([:b, :c, :e], cond=x -> eltype(x) <: Integer)
Assert(("b", "c", "e"), cond=allunique, msg=nm -> "error in column \$nm")
Assert(r"[bce]", cond=x -> sum(x) > 100)
```
"""
struct Assert{S<:ColumnSelector,C,M} <: StatelessFeatureTransform
# column selector; called with the table's column names to pick the columns to check
selector::S
# predicate called with each selected column
cond::C
# error message: either a string, or a function of the column name returning one
msg::M
end

# keyword constructor: `cond` is mandatory, `msg` falls back to the empty string
function Assert(selector::ColumnSelector; cond, msg="")
  return Assert(selector, cond, msg)
end

# no column specification: check every column of the table
function Assert(; kwargs...)
  return Assert(AllSelector(); kwargs...)
end

# a single column specification, converted with `selector`
function Assert(cols; kwargs...)
  return Assert(selector(cols); kwargs...)
end

# several columns passed as separate positional arguments
function Assert(cols::C...; kwargs...) where {C<:Column}
  return Assert(selector(cols); kwargs...)
end

# every `Assert` transform is declared revertible
function isrevertible(::Type{<:Assert})
  return true
end

# Check `transform.cond` on each selected column of `feat`.
#
# Throws an `AssertionError` for the first selected column whose condition
# does not hold; the error message is `transform.msg` when it is a string,
# otherwise the result of calling `transform.msg` with the column name.
# When every column passes, the input table is returned unchanged together
# with an empty (`nothing`) cache.
function applyfeat(transform::Assert, feat, prep)
  cols = Tables.columns(feat)
  names = Tables.columnnames(cols)
  snames = transform.selector(names)
  cond = transform.cond
  msg = transform.msg

  for name in snames
    x = Tables.getcolumn(cols, name)
    # short-circuit so the message is only built when the check fails;
    # passing columns no longer pay for (or trip over) the message function
    cond(x) || throw(AssertionError(msg isa AbstractString ? msg : msg(name)))
  end

  return feat, nothing
end

# reverting hands back the given table as is; the cache is not used
function revertfeat(::Assert, newfeat, fcache)
  return newfeat
end
2 changes: 1 addition & 1 deletion src/transforms/center.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Center() = Center(AllSelector())
Center(cols) = Center(selector(cols))
Center(cols::C...) where {C<:Column} = Center(selector(cols))

assertions(transform::Center) = [SciTypeAssertion{Continuous}(transform.selector)]
assertions(transform::Center) = [scitypeassert(Continuous, transform.selector)]

isrevertible(::Type{<:Center}) = true

Expand Down
2 changes: 1 addition & 1 deletion src/transforms/closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct Closure <: StatelessFeatureTransform end

isrevertible(::Type{Closure}) = true

assertions(::Closure) = [SciTypeAssertion{Continuous}()]
assertions(::Closure) = [scitypeassert(Continuous)]

function applyfeat(::Closure, feat, prep)
cols = Tables.columns(feat)
Expand Down
2 changes: 1 addition & 1 deletion src/transforms/eigenanalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end

EigenAnalysis(proj; maxdim=nothing, pratio=1.0) = EigenAnalysis(proj, maxdim, pratio)

assertions(::EigenAnalysis) = [SciTypeAssertion{Continuous}()]
assertions(::EigenAnalysis) = [scitypeassert(Continuous)]

isrevertible(::Type{EigenAnalysis}) = true

Expand Down
2 changes: 1 addition & 1 deletion src/transforms/indicator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ end

Indicator(col::Column; k=10, scale=:quantile, categ=false) = Indicator(selector(col), k, scale, categ)

assertions(transform::Indicator) = [SciTypeAssertion{Continuous}(transform.selector)]
assertions(transform::Indicator) = [scitypeassert(Continuous, transform.selector)]

isrevertible(::Type{<:Indicator}) = true

Expand Down
2 changes: 1 addition & 1 deletion src/transforms/levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Levels(pairs::Pair{C}...; ordered=nothing) where {C<:Column} =

Levels(; kwargs...) = throw(ArgumentError("cannot create Levels transform without arguments"))

assertions(transform::Levels) = [SciTypeAssertion{Categorical}(transform.selector)]
assertions(transform::Levels) = [scitypeassert(Categorical, transform.selector)]

isrevertible(::Type{<:Levels}) = true

Expand Down
2 changes: 1 addition & 1 deletion src/transforms/logratio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function revertmatrix end

isrevertible(::Type{<:LogRatio}) = true

assertions(::LogRatio) = [SciTypeAssertion{Continuous}()]
assertions(::LogRatio) = [scitypeassert(Continuous)]

function applyfeat(transform::LogRatio, feat, prep)
cols = Tables.columns(feat)
Expand Down
2 changes: 1 addition & 1 deletion src/transforms/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end

OneHot(col::Column; categ=false) = OneHot(selector(col), categ)

assertions(transform::OneHot) = [SciTypeAssertion{Categorical}(transform.selector)]
assertions(transform::OneHot) = [scitypeassert(Categorical, transform.selector)]

isrevertible(::Type{<:OneHot}) = true

Expand Down
2 changes: 1 addition & 1 deletion src/transforms/projectionpursuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ end
ProjectionPursuit(; tol=1e-6, maxiter=100, deg=5, perc=0.9, n=100, rng=Random.GLOBAL_RNG) =
ProjectionPursuit{typeof(tol),typeof(rng)}(tol, maxiter, deg, perc, n, rng)

assertions(::ProjectionPursuit) = [SciTypeAssertion{Continuous}()]
assertions(::ProjectionPursuit) = [scitypeassert(Continuous)]

isrevertible(::Type{<:ProjectionPursuit}) = true

Expand Down
2 changes: 1 addition & 1 deletion src/transforms/quantile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Quantile(; dist=Normal()) = Quantile(AllSelector(), dist)
Quantile(cols; dist=Normal()) = Quantile(selector(cols), dist)
Quantile(cols::C...; dist=Normal()) where {C<:Column} = Quantile(selector(cols), dist)

assertions(transform::Quantile) = [SciTypeAssertion{Continuous}(transform.selector)]
assertions(transform::Quantile) = [scitypeassert(Continuous, transform.selector)]

isrevertible(::Type{<:Quantile}) = true

Expand Down
2 changes: 1 addition & 1 deletion src/transforms/remainder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Remainder() = Remainder(nothing)

isrevertible(::Type{<:Remainder}) = true

assertions(::Remainder) = [SciTypeAssertion{Continuous}()]
assertions(::Remainder) = [scitypeassert(Continuous)]

function applyfeat(transform::Remainder, feat, prep)
cols = Tables.columns(feat)
Expand Down
2 changes: 1 addition & 1 deletion src/transforms/scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Scale(; low=0.25, high=0.75) = Scale(AllSelector(), low, high)
Scale(cols; low=0.25, high=0.75) = Scale(selector(cols), low, high)
Scale(cols::C...; low=0.25, high=0.75) where {C<:Column} = Scale(selector(cols), low, high)

assertions(transform::Scale) = [SciTypeAssertion{Continuous}(transform.selector)]
assertions(transform::Scale) = [scitypeassert(Continuous, transform.selector)]

isrevertible(::Type{<:Scale}) = true

Expand Down
10 changes: 10 additions & 0 deletions src/transforms/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,13 @@
zscore(x, μ, σ) = @. (x - μ) / σ

revzscore(y, μ, σ) = @. σ * y + μ

# return `cond` when it holds, otherwise throw `AssertionError(msg)`
function _assert(cond, msg)
  if cond
    return cond
  end
  throw(AssertionError(msg))
end

# Build an `Assert` transform that checks, for the columns picked by `selector`
# (all columns by default), that the element scientific type is a subtype of `S`.
function scitypeassert(S, selector=AllSelector())
  check = x -> elscitype(x) <: S
  describe = nm -> "the elements of the column '$nm' are not of scientific type $S"
  return Assert(selector; cond=check, msg=describe)
end
2 changes: 1 addition & 1 deletion src/transforms/zscore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ ZScore() = ZScore(AllSelector())
ZScore(cols) = ZScore(selector(cols))
ZScore(cols::C...) where {C<:Column} = ZScore(selector(cols))

assertions(transform::ZScore) = [SciTypeAssertion{Continuous}(transform.selector)]
assertions(transform::ZScore) = [scitypeassert(Continuous, transform.selector)]

isrevertible(::Type{<:ZScore}) = true

Expand Down
30 changes: 0 additions & 30 deletions test/assertions.jl

This file was deleted.

3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ Polynomial(args::T...) where {T<:Real} = Polynomial(collect(args))
include("metatable.jl")

# list of tests
testfiles =
["distributions.jl", "assertions.jl", "tableselection.jl", "tablerows.jl", "transforms.jl", "metadata.jl", "shows.jl"]
testfiles = ["distributions.jl", "tableselection.jl", "tablerows.jl", "transforms.jl", "metadata.jl", "shows.jl"]

@testset "TableTransforms.jl" begin
for testfile in testfiles
Expand Down
16 changes: 16 additions & 0 deletions test/shows.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
@testset "Shows" begin
# `show` output of an `Assert` built from three symbol columns and `allunique`
@testset "Assert" begin
T = Assert(:a, :b, :c, cond=allunique)

# compact mode: selector, condition and the (empty default) message on one line
iostr = sprint(show, T)
@test iostr == "Assert([:a, :b, :c], allunique, \"\")"

# full mode: `text/plain` rendering with one field per line
iostr = sprint(show, MIME("text/plain"), T)
@test iostr == """
Assert transform
├─ selector = [:a, :b, :c]
├─ cond = allunique
└─ msg = \"\""""
end

@testset "Select" begin
T = Select(:a, :b, :c)

Expand Down
1 change: 1 addition & 0 deletions test/transforms.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
transformfiles = [
"assert.jl",
"select.jl",
"rename.jl",
"satisfies.jl",
Expand Down
43 changes: 43 additions & 0 deletions test/transforms/assert.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Tests for the `Assert` transform: columns selected by index, symbol, string
# and regex, with the default message, a string message and a message function.
@testset "Assert" begin
@test isrevertible(Assert(cond=allunique))

# `a` and `b` hold unique values and sum to 21;
# `c` and `d` each repeat the value 6 and sum to 22
a = [1, 2, 3, 4, 5, 6]
b = [6, 5, 4, 3, 2, 1]
c = [1, 2, 3, 4, 6, 6]
d = [6, 6, 4, 3, 2, 1]
t = Table(; a, b, c, d)

# selection by index: columns 1 and 2 are unique, column 3 is not
# NOTE: `n, c = apply(...)` rebinds `c` (the vector above) to the cache;
# the table `t` was already built, so it is not affected
T = Assert(1, 2, cond=allunique)
n, c = apply(T, t)
@test n == t
tₒ = revert(T, n, c)
@test tₒ == n == t
T = Assert(1, 2, 3, cond=allunique)
@test_throws AssertionError apply(T, t)

# selection by symbol: `c` and `d` sum to 22 (> 21), `b` sums to exactly 21
T = Assert([:c, :d], cond=x -> sum(x) > 21)
n, c = apply(T, t)
@test n == t
tₒ = revert(T, n, c)
@test tₒ == n == t
T = Assert([:b, :c, :d], cond=x -> sum(x) > 21)
@test_throws AssertionError apply(T, t)

# selection by string, with a fixed string message on failure
T = Assert(("a", "b"), cond=allunique)
n, c = apply(T, t)
@test n == t
tₒ = revert(T, n, c)
@test tₒ == n == t
T = Assert(("a", "b", "c"), cond=allunique, msg="assertion error")
@test_throws AssertionError apply(T, t)
@test_throws "assertion error" apply(T, t)

# selection by regex, with a message function that receives the column name;
# `b` is the first matched column to fail, so its name appears in the message
T = Assert(r"[cd]", cond=x -> sum(x) > 21)
n, c = apply(T, t)
@test n == t
tₒ = revert(T, n, c)
@test tₒ == n == t
T = Assert(r"[bcd]", cond=x -> sum(x) > 21, msg=nm -> "error in column $nm")
@test_throws AssertionError apply(T, t)
@test_throws "error in column b" apply(T, t)
end
Loading