diff --git a/Manifest.toml b/Manifest.toml index fae5b2d..e404bfd 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,20 +2,44 @@ julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "be2f32a4722249e4df142631a0bfb91c4fb610ed" +project_hash = "20331d536f31ba7b0653e4e759f698167642aebc" + +[[deps.ADTypes]] +git-tree-sha1 = "016833eb52ba2d6bea9fcb50ca295980e728ee24" +uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +version = "0.2.7" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] +git-tree-sha1 = "c0d491ef0b135fd7d63cbc6404286bc633329425" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.36" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "4.0.4" +weakdeps = ["StaticArrays"] [deps.Adapt.extensions] AdaptStaticArraysExt = "StaticArrays" - [deps.Adapt.weakdeps] - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" @@ -70,6 +94,16 @@ git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" version = "0.7.4" +[[deps.Combinatorics]] +git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.2" + +[[deps.CommonSolve]] +git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" +uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +version = "0.2.4" + [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" @@ -91,6 +125,20 @@ deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "1.1.0+0" +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConcreteStructs]] +git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" +uuid = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +version = "0.2.3" + [[deps.ConstructionBase]] deps = ["LinearAlgebra"] git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" @@ -136,6 +184,12 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + [[deps.DiffResults]] deps = ["StaticArraysCore"] git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" @@ -185,6 +239,16 @@ git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" +[[deps.EnumX]] +git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" +uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +version = "1.0.4" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + [[deps.FilePathsBase]] deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" @@ -227,12 +291,21 @@ deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "Lo git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.36" +weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] ForwardDiffStaticArraysExt = "StaticArrays" - [deps.ForwardDiff.weakdeps] - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.FunctionWrappersWrappers]] +deps = ["FunctionWrappers"] +git-tree-sha1 = "b104d487b34566608f8b4e1c39fb0b10aa279ff8" +uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf" +version = "0.1.3" [[deps.Future]] deps = ["Random"] @@ -244,6 +317,18 @@ git-tree-sha1 = "273bd1cd30768a2fddfa3fd63bbc746ed7249e5f" uuid = "38e38edf-8417-5370-95a0-9cbb8c7f171a" version = "1.9.0" +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.HCubature]] +deps = ["Combinatorics", "DataStructures", "LinearAlgebra", "QuadGK", "StaticArrays"] +git-tree-sha1 = "10f37537bbd83e52c63abf6393f209dbd641fedc" +uuid = "19dc6840-f33b-545b-b366-655c7e3ffd49" +version = "1.6.0" + [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" @@ -256,10 +341,50 @@ git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" version = "1.4.0" +[[deps.IntegerMathUtils]] +git-tree-sha1 = "b8ffb903da9f7b8cf695a8bead8e01814aa24b30" +uuid = "18e54dd8-cb9d-406c-a71d-865a43cbb235" +version = "0.1.2" + +[[deps.Integrals]] +deps = ["CommonSolve", "HCubature", "LinearAlgebra", "MonteCarloIntegration", "QuadGK", "Reexport", "SciMLBase"] +git-tree-sha1 = "ebf5737d823873add85809f2b52e20e3eae71997" +uuid = "de52edbc-65ea-441a-8357-d3a637375a31" +version = "4.4.1" + + [deps.Integrals.extensions] + IntegralsArblibExt = "Arblib" + IntegralsCubaExt = "Cuba" + IntegralsCubatureExt = "Cubature" + IntegralsFastGaussQuadratureExt = "FastGaussQuadrature" + IntegralsForwardDiffExt = "ForwardDiff" + IntegralsMCIntegrationExt = "MCIntegration" + IntegralsZygoteExt = ["Zygote", "ChainRulesCore"] + + [deps.Integrals.weakdeps] + Arblib = "fb37089c-8514-4489-9461-98f9c8763369" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1" + Cubature = "667455a9-e2ce-5579-9412-b964f529a492" + FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "896385798a8d49a255c398bd49162062e4a4c435" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.13" +weakdeps = ["Dates"] + + [deps.InverseFunctions.extensions] + DatesExt = "Dates" + [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" @@ -286,6 +411,12 @@ git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" version = "1.3.1" +[[deps.LatticeRules]] +deps = ["Random"] +git-tree-sha1 = "7f5b02258a3ca0221a6a9710b0a0a2e8fb4957fe" +uuid = "73f95e8e-ec14-4e6a-8b18-0d2e271c4e55" +version = "0.0.1" + [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -366,6 +497,12 @@ version = "1.1.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[deps.MonteCarloIntegration]] +deps = ["Distributions", "QuasiMonteCarlo", "Random"] +git-tree-sha1 = "722ad522068d31954b4a976b66a26aeccbf509ed" +uuid = "4886b29c-78c9-11e9-0a6e-41e1f4161f7b" +version = "0.2.0" + [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" @@ -478,6 +615,12 @@ git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" version = "2.3.1" +[[deps.Primes]] +deps = ["IntegerMathUtils"] +git-tree-sha1 = "cb420f77dc474d23ee47ca8d14c90810cafe69e7" +uuid = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" +version = "0.5.6" + [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -488,6 +631,16 @@ git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" version = "2.9.4" +[[deps.QuasiMonteCarlo]] +deps = ["Accessors", "ConcreteStructs", "LatticeRules", "LinearAlgebra", "Primes", "Random", "Requires", "Sobol", "StatsBase"] +git-tree-sha1 = "cc086f8485bce77b6187141e1413c3b55f9a4341" +uuid = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" +version = "0.3.3" +weakdeps = ["Distributions"] + + [deps.QuasiMonteCarlo.extensions] + QuasiMonteCarloDistributionsExt = "Distributions" + [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -496,6 +649,36 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.RecursiveArrayTools]] +deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "SparseArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "d8f131090f2e44b145084928856a561c83f43b27" +uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" +version = "3.13.0" + + [deps.RecursiveArrayTools.extensions] + RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" + RecursiveArrayToolsForwardDiffExt = "ForwardDiff" + RecursiveArrayToolsMeasurementsExt = "Measurements" + RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" + RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] + RecursiveArrayToolsTrackerExt = "Tracker" + RecursiveArrayToolsZygoteExt = "Zygote" + + [deps.RecursiveArrayTools.weakdeps] + FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.Reexport]] git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" @@ -519,10 +702,52 @@ git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" version = "0.4.0+0" +[[deps.RuntimeGeneratedFunctions]] +deps = ["ExprTools", "SHA", "Serialization"] +git-tree-sha1 = "04c968137612c4a5629fa531334bb81ad5680f00" +uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +version = "0.5.13" + [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.SciMLBase]] +deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "816176bca8a93f8f50a33853e0933d6c4ec116d0" +uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +version = "2.33.1" + + [deps.SciMLBase.extensions] + SciMLBaseChainRulesCoreExt = "ChainRulesCore" + SciMLBaseMakieExt = "Makie" + SciMLBasePartialFunctionsExt = "PartialFunctions" + SciMLBasePyCallExt = "PyCall" + SciMLBasePythonCallExt = "PythonCall" + SciMLBaseRCallExt = "RCall" + SciMLBaseZygoteExt = "Zygote" + + [deps.SciMLBase.weakdeps] + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" + PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" + PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" + PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" + RCall = "6f49c342-dc21-5d91-9882-a32aef131414" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.SciMLOperators]] +deps = ["ArrayInterface", "DocStringExtensions", "LinearAlgebra", "MacroTools", "Setfield", "SparseArrays", "StaticArraysCore"] +git-tree-sha1 = "10499f619ef6e890f3f4a38914481cc868689cd5" +uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +version = "0.3.8" + +[[deps.SciMLStructures]] +git-tree-sha1 = "5833c10ce83d690c124beedfe5f621b50b02ba4d" +uuid = "53ae85a6-f571-4167-b2af-e1d143709226" +version = "1.1.0" + [[deps.SentinelArrays]] deps = ["Dates", "Random"] git-tree-sha1 = "0e7508ff27ba32f26cd459474ca2ede1bc10991f" @@ -543,6 +768,12 @@ git-tree-sha1 = "503688b59397b3307443af35cd953a13e8005c16" uuid = "1277b4bf-5013-50f5-be3d-901d8477a67a" version = "2.0.0" +[[deps.Sobol]] +deps = ["DelimitedFiles", "Random"] +git-tree-sha1 = "5a74ac22a9daef23705f010f72c81d6925b19df8" +uuid = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4" +version = "1.5.0" + [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -569,6 +800,20 @@ version = "2.3.1" [deps.SpecialFunctions.weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.3" + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + + [deps.StaticArrays.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + [[deps.StaticArraysCore]] git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -626,6 +871,12 @@ deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" version = "7.2.1+1" +[[deps.SymbolicIndexingInterface]] +deps = ["Accessors", "ArrayInterface", "MacroTools", "RuntimeGeneratedFunctions", "StaticArraysCore"] +git-tree-sha1 = "40ea524431a92328cd73582d1820a5b08247a40f" +uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +version = "0.3.16" + [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" diff --git a/Project.toml b/Project.toml index eef393e..99ae7d4 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +Integrals = "de52edbc-65ea-441a-8357-d3a637375a31" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/src/UnobservedCountEstimation.jl b/src/UnobservedCountEstimation.jl index e00d34b..91e38b3 100644 --- a/src/UnobservedCountEstimation.jl +++ b/src/UnobservedCountEstimation.jl @@ -4,11 +4,12 @@ module UnobservedCountEstimation # Imports using Optim, Statistics, GLM, Distributions, DataFrames, SpecialFunctions, LinearAlgebra +using Integrals include("zhang_likelihood.jl") include("original_model.jl") -include("binomial_likelihood.jl") +include("binomial_model_sampling.jl") include("binomial_model.jl") include("interface.jl") diff --git a/src/binomial_likelihood.jl b/src/binomial_likelihood.jl deleted file mode 100644 index 76f1b52..0000000 --- a/src/binomial_likelihood.jl +++ /dev/null @@ -1,75 +0,0 @@ -# likelihood and derivatives go here - -function log_lik_binomial_model(M, α, β, m, N, n, X, Z) - μ = (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) - M = exp.(M) .+ m - μ = exp.(μ) ./ (1 .+ exp.(μ)) - p = μ ./ M - - sum(loggamma.(M .+ 1) .- loggamma.(m .+ 1.0) .- loggamma.(M .- m .+ 1.0) .+ m .* log.(p) .+ (M .- m) .* log.(1 .- p)) -end # end function - -function grad_log_lik_binomial_model(M, α, β, m, N, n, X, Z) - μ = (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) - Mprev = copy(M) - M = exp.(M) .+ m - μ = exp.(μ) ./ (1 .+ exp.(μ)) - p = μ ./ M - - dp = m ./ p .- (M .- m) ./ (1 .- p) - - dα = dp .* (1 .- μ) .* μ .* (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* log.(N) ./ M - dβ = dp .* (1 .- μ) .* μ .* (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* log.(n ./ N) ./ M - - dM = digamma.(M .+ 1) .- digamma.(M .- m .+ 1) .- (m ./ M) .+ (M .- m) .* ((1 .- p) .^ -1) .* (p ./ M) .+ log.(1 .- p) - # TODO:: derivative correction M, exp link - - vcat(dM .* exp.(Mprev), X' * dα, Z' * dβ) -end # end function - -function hess_log_lik_binomial_model(M, α, β, m, N, n, X, Z) - μ = (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) - μ = exp.(μ) ./ (1 .+ exp.(μ)) - Mprev = copy(M) - M = exp.(M) .+ m - p = μ ./ M - - dp = m ./ p .- (M .- m) ./ (1 .- p) - dp_2 = -m ./ p .^ 2 .- (M .- m) ./ (1 .- p) .^ 2 - - dα_2 = -2 .* μ .^ 2 .* (1 .- μ) .* (log.(N) .^ 2) .* (N .^ (2 * X * α)) .* ((n ./ N) .^ (2 * Z * β)) - dα_2 += (log.(N) .^ 2) .* (N .^ (2 * X * α)) .* ((n ./ N) .^ (2 * Z * β)) .* μ .* (1 .- μ) - dα_2 += (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* (log.(N) .^ 2) .* μ .* (1 .- μ) - dα_2 .*= dp ./ M - dα_2 += dp_2 .* ((1 .- μ) .* μ .* (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* log.(N) ./ M) .^ 2 - - dβ_2 = -2 .* μ .^ 2 .* (1 .- μ) .* (log.(n ./ N) .^ 2) .* (N .^ (2 * X * α)) .* ((n ./ N) .^ (2 * Z * β)) - dβ_2 += (log.(n ./ N) .^ 2) .* (N .^ (2 * X * α)) .* ((n ./ N) .^ (2 * Z * β)) .* μ .* (1 .- μ) - dβ_2 += (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* (log.(n ./ N) .^ 2) .* μ .* (1 .- μ) - dβ_2 .*= dp ./ M - dβ_2 += dp_2 .* ((1 .- μ) .* μ .* (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* log.(N) ./ M) .^ 2 - - dαdβ = dp .* (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* log.(N) .* log.(n ./ N) .* μ .* (1 .- μ) ./ M - dαdβ .*= ((N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* (exp.((N .^ (X * α)) .* ((n ./ N) .^ (Z * β))) .- 1) - exp.((N .^ (X * α)) .* ((n ./ N) .^ (Z * β))) .- 1) - dαdβ += dp_2 .* ((1 .- μ) .* μ) .^ 2 .* (N .^ (2 * X * α)) .* ((n ./ N) .^ (2 * Z * β)) .* log.(N) .* log.(n ./ N) ./ (M .^ 2) - - dpdM = (1 .- m ./ M) .* ((1 .- p) .^ -2) .- (1 .- p) .^ -1 - dpdM .*= exp.(Mprev) - - dαdM = dpdM .* (1 .- μ) .* μ .* (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* log.(N) ./ M - dβdM = dpdM .* (1 .- μ) .* μ .* (N .^ (X * α)) .* ((n ./ N) .^ (Z * β)) .* log.(n ./ N) ./ M - - dM_2 = trigamma.(M .+ 1) .- trigamma.(M .- m .+ 1) .+ m ./ M .^ 2 - dM_2 += (m ./ M .^ 2) .* p ./ (1 .- p) .+ μ ./ ((1 .- p) .* M .^ 2) - dM_2 -= (1 .- m ./ M) .* (p ./ M) ./ (1 .- p) .^ 2 - dM = digamma.(M .+ 1) .- digamma.(M .- m .+ 1) .- (m ./ M) .+ (M .- m) .* ((1 .- p) .^ -1) .* (p ./ M) .+ log.(1 .- p) - dM_2 = dM_2 .* exp.(2 .* Mprev) .+ dM .* exp.(Mprev) - # TODO:: derivative correction M, exp link - - ## TODO if M is a vector then dM_2 is a diagonal matrix and dαdM - vcat( - hcat(Diagonal(dM_2[:, 1]), Diagonal(dαdM[:, 1]) * X, Diagonal(dβdM[:, 1]) * Z), - hcat(X' * Diagonal(dαdM[:, 1]), X' * (dα_2 .* X), (X' * (dαdβ .* Z))'), - hcat(Z' * Diagonal(dβdM[:, 1]), (X' * (dαdβ .* Z))', Z' * (dβ_2 .* Z)) - ) -end # end function \ No newline at end of file diff --git a/src/binomial_model.jl b/src/binomial_model.jl index 2e918dd..1f0be7d 100644 --- a/src/binomial_model.jl +++ b/src/binomial_model.jl @@ -1,4 +1,6 @@ -function binomial_model(m, N, n; start = "glm") +function binomial_model(m, N, n; start = "glm", iter = 2000, + warm_up = floor(Int, iter / 2), grid, + save_simulation = true) # TODO:: add X, Z arguments and then methods for type X/Z nothing or formula df = DataFrame( y = m, @@ -12,29 +14,15 @@ function binomial_model(m, N, n; start = "glm") Z = ones(length(n)) Z = Z[:, :] - log_l_f = x -> log_lik_binomial_model(x[1:(end - 2)], x[end - 1], x[end], m, N, n, X, Z) * (-1.0) - grad_l_f = x -> grad_log_lik_binomial_model(x[1:(end - 2)], x[end - 1], x[end], m, N, n, X, Z) * (-1.0) - hes_l_f = x -> hess_log_lik_binomial_model(x[1:(end - 2)], x[end - 1], x[end], m, N, n, X, Z) * (-1.0) - - #= result = optimize(log_l_f, [start[1], start[2], 1], Newton(); inplace = false) - result_1 = optimize(log_l_f, grad_l_f, [start[1], start[2], 1], Newton(); inplace = false) =# - # TODO :: dependent α, β - - start = Float64[] - append!(start, zeros(length(N))) + start = [M] if start == "glm" append!(start, coef(glm(@formula(y ~ x1 + x2 + 0), df, Poisson(), LogLink()))) else append!(start, coef(lm(@formula(log(y) ~ x1 + x2 + 0), df))) - end # end if + end # end + + res = gibbs_sampler_binomial_model(start, grid, iter, n, N, m) - optim_problem = optimize(log_l_f, grad_l_f, hes_l_f, start, NewtonTrustRegion(); inplace = false) - α̂ = optim_problem.minimizer[end - 1] - β̂ = optim_problem.minimizer[end] - M̂ = optim_problem.minimizer[1:length(N)] - ξ̂ = N .^ α̂ - #[coef(ols), coef(mm)] - #[start, log_l, grad_l, hes_l] - [[α̂, β̂, ξ̂, M̂, sum(ξ̂), sum(M̂)], optim_problem] + # return object with summary statistics and end # end function \ No newline at end of file diff --git a/src/binomial_model_sampling.jl b/src/binomial_model_sampling.jl new file mode 100644 index 0000000..1cf59aa --- /dev/null +++ b/src/binomial_model_sampling.jl @@ -0,0 +1,117 @@ +function sample_gamma_1_cond(grid, n, N, γ₂, M, m, μ_γ₁, μ_γ₂, ρ, σ_γ₁, σ_γ₂, ε = 1e-6) + # get posteriori normal parameters + μ_γ₁_post = μ_γ₁ + ρ * σ_γ₂ / σ_γ₁ * (γ₂ - μ_γ₂) + σ_γ₁_post = sqrt(1 - ρ ^ 2) * σ_γ₁ + distr = Normal(μ_γ₁, σ_γ₁) + # compute R_i's + #log_N_sq = log.(N) .^ 2 + #= R = zeros(BigFloat, length(M)) + next_iter = true + t = 0 + while next_iter + println(t) + # Rprev = copy(R) + # iterate sum + R_add = cgf.(distr, (M .- t) .* log.(N)) .- logfactorial(t) + R .+= exp.(R_add) .* (-1) ^ t + println(R_add) + # check convergence + next_iter = any(exp.(R_add) .< ε) | t > 1 + t += 1 + end # end while + println(R) =# + f(x, p) = exp.(x .* M .* log.(N) - N .^ x) .* pdf(distr, x) + prob = IntegralProblem(f, [-Inf, Inf]) + R = solve(prob, HCubatureJL(), reltol = ε, abstol = ε) + #println(R) + R .*= exp.(-logfactorial.(M)) # <--- if this fails this is probably why + #println(R) + #error("bcd") + + # shift grid towards the mean since it is most probable + grid1 = grid .+ μ_γ₁_post + # get unscaled density function and evaluate it on a grid + function density_function(x) + μ = (N .^ x) .* ((n ./ N) .^ γ₂) + μ = 1 ./ (1 .+ exp.(-μ)) + lξ = x .* log.(N) + + #res = (μ .^ m) .* ((1 .- μ) .^ (M .- m)) + res = zeros(BigFloat, length(M)) + res .+= m .* log.(μ) .+ (M .- m) .* log.(1 .- μ) .- log.(R) .- logfactorial.(M .- m) .- exp.(lξ) .+ M .* lξ + #println(exp(sum(res))) + exp(sum(res)) .* pdf(Normal(μ_γ₁_post, σ_γ₁_post), x) + end # end funciton + evaluated_denisty = density_function.(grid1) + #println(evaluated_denisty) + evaluated_denisty ./= sum(evaluated_denisty) + # sample acording to evaluation + grid1[rand(Categorical(evaluated_denisty))] +end # end funciton + +function sample_gamma_2_cond(grid, n, N, γ₁, M, m, μ_γ₁, μ_γ₂, ρ, σ_γ₁, σ_γ₂) + # get posteriori normal parameters + μ_γ₂_post = μ_γ₂ + ρ * σ_γ₁ / σ_γ₂ * (γ₁ - μ_γ₁) + σ_γ₂_post = sqrt(1 - ρ ^ 2) * σ_γ₂ + # shift grid towards the mean since it is most probable + grid1 = grid .- μ_γ₂_post + # get unscaled density function and evaluate it on a grid + function density_function(x) + μ = (N .^ γ₁) .* ((n ./ N) .^ x) + μ = 1 ./ (1 .+ exp.(-μ)) + #ξ = N .^ γ₁ + + res = logfactorial.(M) .- logfactorial.(M .- m) .- logfactorial.(m) .+ m .* log.(μ).+ (M .- m) .* log.(1 .- μ) + exp(sum(res)) * pdf(Normal(μ_γ₂_post, σ_γ₂_post), x) + end # end funciton + evaluated_denisty = density_function.(grid1) + evaluated_denisty ./= sum(evaluated_denisty) + # sample acording to evaluation + grid1[rand(Categorical(evaluated_denisty))] +end # end funciton + +function sample_M_cond(n, N, m, γ₁, γ₂) + # compute ξ, μ + μ = (N .^ γ₁) .* ((n ./ N) .^ γ₂) + μ = 1 ./ (1 .+ exp.(-μ)) + ξ = N .^ γ₁ + # draw M-m vector from poisson intependently + M_minus_m = reduce(vcat, rand.(Poisson.(ξ .* (1 .- μ)), 1)) + # return M = (M-m) + increment + m + M_minus_m +end # end funciton + +function gibbs_sampler_binomial_model(start, grid, iter, n, N, m, μ_γ₁, μ_γ₂, σ_γ₁, σ_γ₂, ρ, ε = 1e-6) + # create storage vectors + M = start[1] + γ₁ = start[2] + γ₂ = start[3] + + storage = [[M], [γ₁], [γ₂]] + + for k in iter + # sample M conditional on γ₁ and γ₂ + M = sample_M_cond(n, N, m, γ₁, γ₂) + #println(M) + # sample γ₁ conditional on M and γ₂ + γ₁ = sample_gamma_1_cond( + grid, n, N, γ₂, + M, m, μ_γ₁, μ_γ₂, + ρ, σ_γ₁, σ_γ₂, ε + ) + #println(γ₁) + # sample γ₂ conditional on γ₁ and M + γ₂ = sample_gamma_2_cond( + grid, n, N, γ₁, + M, m, μ_γ₁, μ_γ₂, + ρ, σ_γ₁, σ_γ₂ + ) + #println(γ₂) + # store them + append!(storage[1], M) + append!(storage[2], γ₁) + append!(storage[3], γ₂) + end # end for + # return stored values + storage +end # end funciton \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e41f6b1..130d04f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using Test # TODO:: CSV is only used for testing, maybe specify that in Project.toml using CSV, DataFrames -@testset "UnobservedCountEstimation.jl" begin +@testset "zhang_model.jl" begin df = CSV.read(pwd() * "/test_csv_zhang.csv", DataFrame) a = zhang_model(df[:, :m], df[:, :N], df[:, :n]; start = "glm") @@ -15,3 +15,18 @@ using CSV, DataFrames @test b[1][4] ≈ sum(df[:, :ξ]) rtol=.075 @test b[1][4] ≈ sum(df[:, :M]) rtol=.075 end + +# for now +#= @testset "binomial_model.jl" begin + df = CSV.read(pwd() * "/test_csv_binomial.csv", DataFrame) + + a = binomial_model(df[:, :m], df[:, :N], df[:, :n]; start = "glm") + b = binomial_model(df[:, :m], df[:, :N], df[:, :n]; start = "lm") + + @test a[1][4] ≈ sum(df[:, :ξ]) rtol=.075 + @test a[1][4] ≈ sum(df[:, :M]) rtol=.075 + + @test b[1][4] ≈ sum(df[:, :ξ]) rtol=.075 + @test b[1][4] ≈ sum(df[:, :M]) rtol=.075 +end =# + diff --git a/test/test_csv_binomial.csv b/test/test_csv_binomial.csv new file mode 100644 index 0000000..fd0a262 --- /dev/null +++ b/test/test_csv_binomial.csv @@ -0,0 +1,21 @@ +N,n,m,M,ξ +4946,501,616,907,902.4087332179555 +4886,480,661,965,893.6403502565424 +4935,506,608,887,900.8027961854027 +4943,488,598,895,901.9708212977752 +5079,535,647,943,921.7700017557884 +5054,516,587,897,918.1384813234845 +5027,521,658,922,914.2124012586877 +4890,507,604,862,894.2255764308815 +5022,490,593,868,913.4848871407867 +4928,477,589,888,899.7804635417899 +4983,513,636,930,907.8052912210046 +4968,485,595,917,905.6184666779144 +5144,543,716,970,931.195284188665 +4993,507,643,942,909.2624427471206 +4854,472,605,879,888.9550830898875 +4988,499,588,884,908.5339400169311 +4977,489,617,917,906.9307196114336 +5053,524,617,902,917.9931458823555 +4950,531,629,880,902.9925331538545 +4845,492,551,837,887.6362400289374