Skip to content

Commit

Permalink
5-2d-decision-boundary
Browse files Browse the repository at this point in the history
5-2d-decision-boundary 删除 eval_model()方法 解决版本冲突问题
  • Loading branch information
math4mad committed Oct 2, 2023
1 parent be157f3 commit a32ac05
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 159 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"hash": "d0f60e00225a88211b6939da5990e36b",
"result": {
"markdown": "---\ntitle: \"5-2d-decision-boundary\"\nauthor: \"math4mads\"\ncode-fold: true\n---\n\n```julia\n\n```\n\n:::{.callout-note title=\"决策边界图\"}\n probml page 84,figure 2.13\n\n- 概率值预测方法 参见: https://discourse.julialang.org/t/extracting-values-from-univariatefinite/62794/3\n\n- 平面上每个点作为预测数据集\n- 通过学习的模型对平面上每个点给出概率值\n- 利用 `contour`方法绘出概率值\n:::\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nimport MLJ:fit!,predict,predict_mode,predict_mean,machine\nusing MLJ,GLMakie,Random,DataFrames\n\niris = load_iris(); \niris =DataFrame(iris);\nnums=100\n\niris[!, :target] = [r.target == \"virginica\" ? 1.0 : 0.0 for r in eachrow(iris)]\niris=coerce(iris, :target=> Multiclass )\ngdf=groupby(iris, :target)\nX,y=iris[:,3:4],iris[:,:target]\n\ncats=levels(y)\n\nfunction boundary_data(df,;n=nums)\n n1=n2=n\n xlow,xhigh=extrema(df[:,1])\n ylow,yhigh=extrema(df[:,2])\n tx = LinRange(xlow,xhigh,n1)\n ty = LinRange(ylow,yhigh,n2)\n x_test = mapreduce(collect, hcat, Iterators.product(tx, ty));\n x_test=MLJ.table(x_test')\n return tx,ty,x_test\nend\ntx,ty,x_test=boundary_data(X)\n\n\nusing MLJ\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels\n#X, y = make_blobs(centers = 2)\nmach = fit!(machine(LogisticClassifier(), X, y))\npredict(mach, X)\nfitted_params(mach)\nprobs=predict(mach, x_test)|>Array #返回分类概率值\nprobs_res=broadcast(pdf, probs, 1.0).|>(d->round(d,digits=2))|>d->reshape(d,nums,nums) #返回概率为1.0(\"virginica\")的概率值\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: Training machine(LogisticClassifier(lambda = 2.220446049250313e-16, …), …).\n┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, NamedTuple{(), Tuple{}}}\n│ optim_options: Optim.Options{Float64, Nothing}\n└ lbfgs_options: NamedTuple{(), Tuple{}} NamedTuple()\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\nimport MLJLinearModels ✔\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=2}\n```\n100×100 Matrix{Float64}:\n 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n ⋮ ⋮ ⋱ ⋮ \n 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.0 0.0 0.0 0.0 0.0 0.0 … 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.0 0.0 0.0 0.0 0.01 0.01 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.0 0.0 0.0 0.01 0.01 0.01 … 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.0 0.0 0.01 0.01 0.01 0.01 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.01 0.01 0.01 0.01 0.01 0.02 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.01 0.01 0.01 0.02 0.02 0.03 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n 0.01 0.01 0.02 0.02 0.03 0.04 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n```\n:::\n:::\n\n\n## plot results\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n fig=Figure(resolution=(1600,800))\n\n function plot_res_contour()\n\n ax = Axis(fig[1, 1], xlabel=\"petal-length\", ylabel=\"petal-width\", title=\"2d2class-contour\")\n contour!(ax, tx, ty, probs_res; labels=true)\n colors = [:red, :blue]\n for i in 1:2\n scatter!(ax, gdf[i][:, 3], gdf[i][:, 4], color=(colors[i], 0.8), marker=:circle, markersize=10, strokewidth=1, strokecolor=:black, label=gdf[i][1, 5] == 1 ? \"virginica\" : \"non-virginica\")\n end\n axislegend(ax, position=:lt)\n #save(\"./imgs/iris-logreg-2d-2class-contourf.png\",fig)\n fig\n end\n\n function plot_res_contourf()\n\n ax = Axis(fig[1, 2], xlabel=\"petal-length\", ylabel=\"petal-width\", title=\"2d2class-contourf\")\n contourf!(ax, tx, ty, probs_res; levels=6, colormap=(:heat, 0.5))\n #contourf!(ax,tx,ty,yhat;levels=length(cats),colormap=(:heat,0.5))\n colors = [:red, :blue]\n for i in 1:2\n scatter!(ax, gdf[i][:, 3], gdf[i][:, 4], color=(colors[i], 0.8), marker=:circle, markersize=10, strokewidth=1, strokecolor=:black, label=gdf[i][1, 5] == 1.0 ? \"virginica\" : \"non-virginica\")\n end\n axislegend(ax, position=:lt)\n #save(\"./imgs/iris-logreg-2d-2class-contourf.png\",fig)\n fig\n end\n\n plot_res_contourf()\n plot_res_contour()\n```\n\n::: {.cell-output .cell-output-display execution_count=3}\n![](5-2d-decision-boundary_files/figure-html/cell-3-output-1.png){}\n:::\n:::\n\n\n",
"supporting": [
"5-2d-decision-boundary_files"
],
"filters": [],
"includes": {}
}
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit a32ac05

Please sign in to comment.