Skip to content

Commit

Permalink
style: move some of the functions (#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
oameye authored Oct 27, 2024
1 parent 19292ce commit 8715ee5
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 133 deletions.
8 changes: 8 additions & 0 deletions src/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,11 @@ function filter_result!(res::Result, class::String)
res.classes[c] = [s[bools] for s in res.classes[c]]
end
end

function _classify_default!(result)
classify_solutions!(result, _is_physical, "physical")
classify_solutions!(result, _is_stable(result), "stable")
classify_solutions!(result, _is_Hopf_unstable(result), "Hopf")
order_branches!(result, ["physical", "stable"]) # shuffle the branches to have relevant ones first
return classify_binaries!(result) # assign binaries to solutions depending on which branches are stable
end
39 changes: 0 additions & 39 deletions src/plotting_Plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,42 +67,6 @@ Similar to `plot` but adds a plot onto an existing plot.
function Plots.plot!(res::Result, varargs...; kwargs...)::Plots.Plot
return plot(res, varargs...; add=true, _set_Plots_default..., kwargs...)
end
"""
$(TYPEDSIGNATURES)
Return an array of bools to mark solutions in `res` which fall into `classes` but not `not_classes`.
Only `branches` are considered.
"""
function _get_mask(res, classes, not_classes=[]; branches=1:branch_count(res))
classes == "all" && return fill(trues(length(branches)), size(res.solutions))
bools = vcat(
[res.classes[c] for c in _str_to_vec(classes)],
[map(.!, res.classes[c]) for c in _str_to_vec(not_classes)],
)
#m = map( x -> [getindex(x, b) for b in [branches...]], map(.*, bools...))

return m = map(x -> x[[branches...]], map(.*, bools...))
end

"""
$(TYPEDSIGNATURES)
Go over a solution and an equally-sized array (a "mask") of booleans.
true -> solution unchanged
false -> changed to NaN (omitted from plotting)
"""
function _apply_mask(solns::Array{Vector{ComplexF64}}, booleans)
factors = replace.(booleans, 0 => NaN)
return map(.*, solns, factors)
end
function _apply_mask(solns::Vector{Vector{Vector{ComplexF64}}}, booleans)
Nan_vector = NaN .* similar(solns[1][1])
new_solns = [
[booleans[i][j] ? solns[i][j] : Nan_vector for j in eachindex(solns[i])] for
i in eachindex(solns)
]
return new_solns
end

""" Project the array `a` into the real axis, warning if its contents are complex. """
function _realify(a::Array{T} where {T<:Number}; warning="")
Expand All @@ -119,9 +83,6 @@ function _realify(a::Array{T} where {T<:Number}; warning="")
return a_real
end

_str_to_vec(s::Vector) = s
_str_to_vec(s) = [s]

# return true if p already has a label for branch index idx
function _is_labeled(p::Plot, idx::Int64)
return in(string(idx), [sub[:label] for sub in p.series_list])
Expand Down
92 changes: 0 additions & 92 deletions src/solve_homotopy.jl
Original file line number Diff line number Diff line change
@@ -1,49 +1,3 @@
# assume this order of variables in all compiled function (transform_solutions, Jacobians)
function _free_symbols(res::Result)
return cat(res.problem.variables, collect(keys(res.swept_parameters)); dims=1)
end
function _free_symbols(p::Problem, varied)
return cat(p.variables, collect(keys(OrderedDict(varied))); dims=1)
end
_symidx(sym::Num, args...) = findfirst(x -> isequal(x, sym), _free_symbols(args...))

"""
$(TYPEDSIGNATURES)
Return an ordered dictionary specifying all variables and parameters of the solution
in `result` on `branch` at the position `index`.
"""
function get_single_solution(res::Result; branch::Int64, index)::OrderedDict{Num,ComplexF64}

# check if the dimensionality of index matches the solutions
if length(size(res.solutions)) !== length(index)
# if index is a number, use linear indexing
index = if length(index) == 1
CartesianIndices(res.solutions)[index]
else
error("Index ", index, " undefined for a solution of size ", size(res.solutions))
end
else
index = CartesianIndex(index)
end

vars = OrderedDict(zip(res.problem.variables, res.solutions[index][branch]))

# collect the swept parameters required for this call
swept_params = OrderedDict(
key => res.swept_parameters[key][index[i]] for
(i, key) in enumerate(keys(res.swept_parameters))
)
full_solution = merge(vars, swept_params, res.fixed_parameters)

return OrderedDict(zip(keys(full_solution), ComplexF64.(values(full_solution))))
end

function get_single_solution(res::Result, index)
return [
get_single_solution(res; index=index, branch=b) for b in 1:length(res.solutions[1])
]
end

"""
get_steady_states(prob::Problem,
swept_parameters::ParameterRange,
Expand Down Expand Up @@ -162,14 +116,6 @@ function get_steady_states(
return result
end

function _classify_default!(result)
classify_solutions!(result, _is_physical, "physical")
classify_solutions!(result, _is_stable(result), "stable")
classify_solutions!(result, _is_Hopf_unstable(result), "Hopf")
order_branches!(result, ["physical", "stable"]) # shuffle the branches to have relevant ones first
return classify_binaries!(result) # assign binaries to solutions depending on which branches are stable
end

function get_steady_states(p::Problem, swept, fixed; kwargs...)
return get_steady_states(p, ParameterRange(swept), ParameterList(fixed); kwargs...)
end
Expand Down Expand Up @@ -218,33 +164,6 @@ function compile_matrix(mat, variables; rules=Dict(), postproc=x -> x)
return m
end

"Find a branch order according `classification`. Place branches where true occurs earlier first."
function find_branch_order(classification::Vector{BitVector})
branches = [getindex.(classification, k) for k in 1:length(classification[1])] # array of branches
indices = replace(findfirst.(branches), nothing => Inf)
negative = findall(x -> x == Inf, indices) # branches not true anywhere - leave out
return order = setdiff(sortperm(indices), negative)
end

find_branch_order(classification::Array) = collect(1:length(classification[1])) # no ordering for >1D

"Order the solution branches in `res` such that close classified positively by `classes` are first."
function order_branches!(res::Result, classes::Vector{String})
for class in classes
order_branches!(res, find_branch_order(res.classes[class]))
end
end

order_branches!(res::Result, class::String) = order_branches!(res, [class])

"Reorder the solutions in `res` to match the index permutation `order`."
function order_branches!(res::Result, order::Vector{Int64})
res.solutions = _reorder_nested(res.solutions, order)
for key in keys(res.classes)
res.classes[key] = _reorder_nested(res.classes[key], order)
end
end

"Reorder EACH ELEMENT of `a` to match the index permutation `order`. If length(order) < length(array), the remanining positions are kept."
function _reorder_nested(a::Array, order::Vector{Int64})
a[1] isa Union{Array,BitVector} || return a
Expand Down Expand Up @@ -387,8 +306,6 @@ function pad_solutions(solutions::Array{Vector{Vector{ComplexF64}}}; padding_val
return padded_solutions
end

tuple_to_vector(t::Tuple) = [i for i in t]

function newton(prob::Problem, soln::OrderedDict)
vars = _convert_or_zero.(substitute_all(prob.variables, soln))
pars = _convert_or_zero.(substitute_all(prob.parameters, soln))
Expand All @@ -405,12 +322,3 @@ Any variables/parameters not present in `soln` are set to zero.
"""
newton(res::Result, soln::OrderedDict) = newton(res.problem, soln)
newton(res::Result; branch, index) = newton(res, res[index][branch])

function _convert_or_zero(x, t=ComplexF64)
try
convert(t, x)
catch ArgumentError
@warn string(x) * " not supplied: setting to zero"
return 0
end
end
27 changes: 27 additions & 0 deletions src/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,30 @@ function sort_2D(
end
return sorted_solns
end

"Find a branch order according `classification`. Place branches where true occurs earlier first."
function find_branch_order(classification::Vector{BitVector})
branches = [getindex.(classification, k) for k in 1:length(classification[1])] # array of branches
indices = replace(findfirst.(branches), nothing => Inf)
negative = findall(x -> x == Inf, indices) # branches not true anywhere - leave out
return order = setdiff(sortperm(indices), negative)
end

find_branch_order(classification::Array) = collect(1:length(classification[1])) # no ordering for >1D

"Order the solution branches in `res` such that close classified positively by `classes` are first."
function order_branches!(res::Result, classes::Vector{String})
for class in classes
order_branches!(res, find_branch_order(res.classes[class]))
end
end

order_branches!(res::Result, class::String) = order_branches!(res, [class])

"Reorder the solutions in `res` to match the index permutation `order`."
function order_branches!(res::Result, order::Vector{Int64})
res.solutions = _reorder_nested(res.solutions, order)
for key in keys(res.classes)
res.classes[key] = _reorder_nested(res.classes[key], order)
end
end
74 changes: 72 additions & 2 deletions src/transform_solutions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,39 @@
_parse_expression(exp) = exp isa String ? Num(eval(Meta.parse(exp))) : exp
"""
$(TYPEDSIGNATURES)
Return an ordered dictionary specifying all variables and parameters of the solution
in `result` on `branch` at the position `index`.
"""
function get_single_solution(res::Result; branch::Int64, index)::OrderedDict{Num,ComplexF64}

# check if the dimensionality of index matches the solutions
if length(size(res.solutions)) !== length(index)
# if index is a number, use linear indexing
index = if length(index) == 1
CartesianIndices(res.solutions)[index]
else
error("Index ", index, " undefined for a solution of size ", size(res.solutions))
end
else
index = CartesianIndex(index)
end

vars = OrderedDict(zip(res.problem.variables, res.solutions[index][branch]))

# collect the swept parameters required for this call
swept_params = OrderedDict(
key => res.swept_parameters[key][index[i]] for
(i, key) in enumerate(keys(res.swept_parameters))
)
full_solution = merge(vars, swept_params, res.fixed_parameters)

return OrderedDict(zip(keys(full_solution), ComplexF64.(values(full_solution))))
end

function get_single_solution(res::Result, index)
return [
get_single_solution(res; index=index, branch=b) for b in 1:length(res.solutions[1])
]
end

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -90,7 +125,42 @@ function _similar(type, res::Result; branches=1:branch_count(res))
return [type(undef, length(branches)) for k in res.solutions]
end

## TODO move masks here
"""
$(TYPEDSIGNATURES)
Return an array of bools to mark solutions in `res` which fall into `classes` but not `not_classes`.
Only `branches` are considered.
"""
function _get_mask(res, classes, not_classes=[]; branches=1:branch_count(res))
classes == "all" && return fill(trues(length(branches)), size(res.solutions))
bools = vcat(
[res.classes[c] for c in _str_to_vec(classes)],
[map(.!, res.classes[c]) for c in _str_to_vec(not_classes)],
)
#m = map( x -> [getindex(x, b) for b in [branches...]], map(.*, bools...))

return m = map(x -> x[[branches...]], map(.*, bools...))
end

"""
$(TYPEDSIGNATURES)
Go over a solution and an equally-sized array (a "mask") of booleans.
true -> solution unchanged
false -> changed to NaN (omitted from plotting)
"""
function _apply_mask(solns::Array{Vector{ComplexF64}}, booleans)
factors = replace.(booleans, 0 => NaN)
return map(.*, solns, factors)
end
function _apply_mask(solns::Vector{Vector{Vector{ComplexF64}}}, booleans)
Nan_vector = NaN .* similar(solns[1][1])
new_solns = [
[booleans[i][j] ? solns[i][j] : Nan_vector for j in eachindex(solns[i])] for
i in eachindex(solns)
]
return new_solns
end

###
# TRANSFORMATIONS TO THE LAB frame
Expand Down
10 changes: 10 additions & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ function Base.show(io::IO, p::Problem)
return println(io, "Symbolic Jacobian: ", !(p.jacobian == false))
end

# assume this order of variables in all compiled function (transform_solutions, Jacobians)
function _free_symbols(p::Problem, varied)
return cat(p.variables, collect(keys(OrderedDict(varied))); dims=1)
end

"""
$(TYPEDEF)
Expand Down Expand Up @@ -263,6 +268,11 @@ function Base.show(io::IO, r::Result)
return println(io, "\nClasses: ", join(keys(r.classes), ", "))
end

# assume this order of variables in all compiled function (transform_solutions, Jacobians)
function _free_symbols(res::Result)
return cat(res.problem.variables, collect(keys(res.swept_parameters)); dims=1)
end

# overload to use [] for indexing
Base.getindex(r::Result, idx::Int...) = get_single_solution(r, idx)
Base.size(r::Result) = size(r.solutions)
Expand Down
16 changes: 16 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,19 @@ is_real(x) = abs(imag(x)) / abs(real(x)) < IM_TOL::Float64 || abs(x) < 1e-70
is_real(x::Array) = is_real.(x)

flatten(a) = collect(Iterators.flatten(a))

_parse_expression(exp) = exp isa String ? Num(eval(Meta.parse(exp))) : exp
_symidx(sym::Num, args...) = findfirst(x -> isequal(x, sym), _free_symbols(args...))

tuple_to_vector(t::Tuple) = [i for i in t]
_str_to_vec(s::Vector) = s
_str_to_vec(s) = [s]

function _convert_or_zero(x, t=ComplexF64)
try
convert(t, x)
catch ArgumentError
@warn string(x) * " not supplied: setting to zero"
return 0
end
end

0 comments on commit 8715ee5

Please sign in to comment.