ComponentArray instead of NamedArray for mode estimation results? #2363
Comments
Thanks for raising this. I agree that the return value of mode estimation could and should be improved. ComponentArrays would be a good candidate. There are some other elements of the return value that we would also like to change, mentioned in #2232. Putting this on our todo list, though unsure when we'll get to it.

If there's interest in the meantime, I'd be willing to make a PR for this feature. It doesn't look like it requires many code changes and would be a huge increase in usability, at least for me...
The issue with ComponentArrays.jl is that it effectively takes a

    @model function demo()
        x = Vector(undef, 100)
        for i in eachindex(x)
            x[i] ~ Normal()
        end
    end

If you "naively" convert the variables here as they occur in the model,
In short, they each have respective pros and cons, and we should probably use a "customized" solution for this. An alternative route @mhauru is to maybe just use
I'm not sure I quite understand--the whole point would be to get that array
Are you saying that in a model like this,

    @model function demo()
        a ~ Normal()
        b ~ Normal()
        ...
        z ~ Normal()
        a1 ~ Normal()
        ...
        z1 ~ Normal()
        ...
        a100 ~ Normal()
        ...
        z100 ~ Normal()
    end

where there are many, many individually named variables (2,600 in this case) the performance of a
Ultimately, it doesn't actually matter whether this hypothetical return structure is a
Yep. But a more likely scenario is the model I mentioned above, i.e.

    @model function demo()
        x = Vector(undef, 100)
        for i in eachindex(x)
            x[i] ~ Normal()
        end
    end

The problem is that

    @model function demo()
        x = Matrix(undef, 1, 100)
        for i in eachindex(x)
            x[i] ~ Normal()
        end
    end

because
That is,

Buuuut we can make this experience quite a bit less painful, which I very much agree with you @ElOceanografo is something we should do. And one way to do this is to use our custom structure
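
A minimal sketch of why the two `demo` models above are hard to tell apart from variable names alone. It uses only Base Julia. `index_labels` is a hypothetical helper, not part of Turing.jl, and the label format `"x[i]"` is an assumption about how the names are derived from the indexing expression.

```julia
# Hypothetical helper (not part of Turing.jl): build the labels that a loop like
# `for i in eachindex(x); x[i] ~ Normal(); end` would touch, one per element.
index_labels(x) = ["x[$i]" for i in eachindex(x)]

vec_labels = index_labels(Vector{Float64}(undef, 100))
mat_labels = index_labels(Matrix{Float64}(undef, 1, 100))

# `eachindex` yields linear indices for both containers, so the labels coincide
# and the 1 x 100 shape cannot be recovered from the names alone.
println(vec_labels == mat_labels)
println(first(vec_labels), " ... ", last(vec_labels))
```

Under these assumptions, any structure built purely from the names has to either guess the shape or fall back to a flat vector.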
Thanks, that example clarifies the issue. Even getting everything back as a flat vector would be an improvement (and I bet would cover most use cases, vectors are more common in stats models than arrays). If
My only other comment is that this interface, whatever it ends up being, should ideally be consistent with the interface for MCMC chains. In fact, a

    using Turing

    # Return the entries of `opt.values` whose base name (the part before "[")
    # matches `varname`.
    function Turing.group(opt::Turing.Optimisation.ModeResult, varname)
        basenames = first.(split.(string.(names(opt.values)[1]), "["))
        idx = findall(x -> x == string(varname), basenames)
        return opt.values[idx]
    end

I also was using this function in a script recently:

    using ComponentArrays

    # Collect the entries of a NamedVector into one flat vector per base name.
    # Works on the argument `v` itself rather than on a global `opt`.
    function named2component(v::Turing.NamedArrays.NamedVector)
        symbols = names(v)[1]
        strings = first.(split.(string.(symbols), "["))
        uniquestrings = unique(strings)
        nt = NamedTuple([Symbol(s) => vec(v[strings .== s]) for s in uniquestrings])
        return ComponentArray(nt)
    end

These will both flatten 2D or higher-dimensional arrays, of course. Still, might be useful to other people reading this...
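
As a rough illustration of how the flattening caveat could be avoided when the labels carry full indices, here is a sketch in Base Julia only. `regroup` is a hypothetical helper, not part of Turing.jl; it assumes labels of the form `"name"` or `"name[i,j,...]"`, `Float64` values, and that every index from 1 up to the maximum is present.

```julia
# Hypothetical helper (not part of Turing.jl): rebuild scalars and arrays from
# flat (label, value) pairs, using the indices embedded in labels like "x[2,3]".
function regroup(labels, vals)
    # base name => list of (index vector => value)
    groups = Dict{String,Vector{Pair{Vector{Int},Float64}}}()
    for (label, value) in zip(labels, vals)
        m = match(r"^([^\[]+)(?:\[(.*)\])?$", label)
        m === nothing && error("unrecognised label: $label")
        base = String(m.captures[1])
        # A label without brackets is a scalar and gets an empty index.
        idx = m.captures[2] === nothing ? Int[] :
              parse.(Int, strip.(split(m.captures[2], ",")))
        push!(get!(groups, base, Pair{Vector{Int},Float64}[]), idx => value)
    end

    out = Dict{String,Any}()
    for (base, entries) in groups
        ndim = length(first(entries).first)
        if ndim == 0
            out[base] = first(entries).second
        else
            # Assumption: the array extent equals the largest index seen per dimension.
            dims = ntuple(d -> maximum(e.first[d] for e in entries), ndim)
            arr = Array{Float64}(undef, dims)
            for (idx, value) in entries
                arr[idx...] = value
            end
            out[base] = arr
        end
    end
    return out
end

labels = ["sigma", "x[1,1]", "x[2,1]", "x[1,2]", "x[2,2]"]
vals = [0.5, 1.0, 2.0, 3.0, 4.0]
result = regroup(labels, vals)
println(size(result["x"]))  # the 2D shape is kept rather than flattened
```

This only works when the model indexes with full Cartesian indices; with linear indexing such as `x[i]` the shape information is not in the labels to begin with.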
When doing an MLE or MAP optimization, the resulting ModeResult.values is a NamedArray, from NamedArrays.jl. If the model contains any array-valued parameters, this can be a real pain to work with: each individual element is indexed with its own unique symbol, requiring the user to do some annoying data munging to get them all together. For instance, to extract a matrix called x that is supposed to have size 20 x 10:
Maybe there's a better way to do this that I don't know about, but wouldn't it be simpler to return the results as a ComponentArray, a data structure designed for just this use case? That would let you simply do:
Is there any compelling reason to prefer a NamedArray here?
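
To make the comparison in the original post concrete, here is a hedged sketch that assumes NamedArrays.jl and ComponentArrays.jl are installed. It does not call Turing at all: the flat named vector is built by hand, and the label format `x[i,j]` and the extra `sigma` parameter are assumptions for illustration, so the real names in `ModeResult.values` may differ.

```julia
using NamedArrays, ComponentArrays

# Stand-in for a fitted 20 x 10 parameter matrix (made-up values).
x_true = reshape(collect(1.0:200.0), 20, 10)

# One symbol per element, in column-major order (assumed label format).
labels = [Symbol("x[$i,$j]") for j in 1:10 for i in 1:20]

# A flat named vector, similar in spirit to what mode estimation returns.
flat = NamedArray(vcat(vec(x_true), 0.5), (vcat(labels, :sigma),), ("param",))

# The "data munging" route: look up every element by its own symbol.
x_munged = [flat[Symbol("x[$i,$j]")] for i in 1:20, j in 1:10]

# The ComponentArray route: the matrix comes back with its shape intact.
ca = ComponentArray(x = x_true, sigma = 0.5)
println(size(ca.x))
println(x_munged == ca.x)
```

The sketch only shows the difference in access patterns; how a `ComponentArray` would be constructed inside the optimisation code is left open here.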