generated from math4mad/quarto-course-website-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
classification models comparison
- Loading branch information
Showing
24 changed files
with
4,230 additions
and
41 deletions.
There are no files selected for viewing
11 changes: 11 additions & 0 deletions
11
_freeze/machinelearning/1-classfication-comparison/execute-results/html.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"hash": "b74b4e6a239ca930fd8bf46b0d19c5d0", | ||
"result": { | ||
"markdown": "---\ntitle: \"classfication models comparison\"\n---\n\n## 1. load package\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\n import MLJ:predict,predict_mode\n using MLJ,GLMakie,DataFrames,Random\n Random.seed!(1222)\n```\n\n::: {.cell-output .cell-output-display execution_count=2}\n```\nTaskLocalRNG()\n```\n:::\n:::\n\n\n## 2. make data \n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n function circle_data()\n X, y = make_circles(400; noise=0.1, factor=0.3)\n df = DataFrame(X)\n df.y = y\n return df\n end\n function moons_data()\n X, y = make_moons(400; noise=0.1)\n df = DataFrame(X)\n df.y = y\n return df\n end\n function blob_data()\n X, y = make_blobs(400, 2; centers=2, cluster_std=[1.0, 2.0])\n df = DataFrame(X)\n df.y = y\n return df\n end\n #cat=df1.y|>levels|>unique\n colors=[:green, :purple]\n```\n\n::: {.cell-output .cell-output-display execution_count=3}\n```\n2-element Vector{Symbol}:\n :green\n :purple\n```\n:::\n:::\n\n\n## 3. define function \n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nfunction plot_origin_data(df)\n fig=Figure()\n ax=Axis(fig[1,1])\n local cat=df.y|>levels|>unique\n \n local colors=[:green, :purple]\n for (i,c) in enumerate(cat)\n d=df[y.==c,:]\n scatter!(ax, d[:,1],d[:,2],color=(colors[i],0.6))\n #@show d\n end\n fig\nend\n\nnums=100\nfunction boundary_data(df,;n=nums)\n n1=n2=n\n xlow,xhigh=extrema(df[:,:x1])\n ylow,yhigh=extrema(df[:,:x2])\n tx = LinRange(xlow,xhigh,n1)\n ty = LinRange(ylow,yhigh,n2)\n x_test = mapreduce(collect, hcat, Iterators.product(tx, ty));\n x_test=MLJ.table(x_test')\n return tx,ty,x_test\nend\n\nfunction plot_desc_boudary(fig,ytest,i;df=df1,row=1)\n tx,ty,xs,ys, xtest=boundary_data(df)\n local ax=Axis(fig[row,i],title=\"$(names[i])\")\n\n contourf!(ax, tx,ty,ytest,levels=length(cat),colormap=:phase)\n\n for (i,c) in enumerate(cat)\n d=df[y.==c,:]\n scatter!(ax, d[:,1],d[:,2],color=(colors[i],0.6))\n end\n hidedecorations!(ax)\nend\n\n```\n\n::: {.cell-output .cell-output-display execution_count=4}\n```\nplot_desc_boudary (generic function with 1 method)\n```\n:::\n:::\n\n\n## 4. define machine learning models\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\n using CatBoost.MLJCatBoostInterface\n SVC = @load SVC pkg=LIBSVM \n KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels\n DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree\n RandomForestClassifier = @load RandomForestClassifier pkg=DecisionTree\n CatBoostClassifier = @load CatBoostClassifier pkg=CatBoost\n BayesianLDA = @load BayesianLDA pkg=MultivariateStats\n Booster = @load AdaBoostStumpClassifier pkg=DecisionTree\n \n models=[KNNClassifier,DecisionTreeClassifier,RandomForestClassifier,CatBoostClassifier,BayesianLDA,SVC]\n names=[\"KNN\",\"DecisionTree\",\"RandomForest\",\"CatBoost\",\"BayesianLDA\",\"SVC\"]\n function _fit(df::DataFrame,m)\n X,y=df[:,1:2],df[:,3]\n _,_,xtest=boundary_data(df;n=nums)\n local predict= m==MLJLIBSVMInterface.SVC ? MLJ.predict : MLJ.predict_mode \n model=m()\n mach = machine(model, X, y)|>fit!\n yhat=predict(mach, xtest)\n ytest=yhat|>Array|>d->reshape(d,nums,nums)\n return ytest\nend\n\n\n\nfunction plot_desc_boudary(fig,ytest,i;df=df1,row=1)\n tx,ty,_=boundary_data(df)\n local y=df.y\n local ax=Axis(fig[row,i],title=\"$(names[i])\")\n cat=y|>levels|>unique\n contourf!(ax, tx,ty,ytest,levels=length(cat),colormap=:redsblues)\n\n for (i,c) in enumerate(cat)\n d=df[y.==c,:]\n scatter!(ax, d[:,1],d[:,2],color=(colors[i],0.6))\n end\n hidedecorations!(ax)\n \n\nend\n\nfunction plot_comparsion(testdata,df;row=1)\n \n for (i,data) in enumerate(testdata)\n plot_desc_boudary(fig,data,i;df=df,row=row)\n end\n fig\nend\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: For silent loading, specify `verbosity=0`. \n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\nimport MLJLIBSVMInterface ✔\nimport NearestNeighborModels ✔\nimport MLJDecisionTreeInterface ✔\nimport MLJDecisionTreeInterface ✔\nimport CatBoost ✔\nimport MLJMultivariateStatsInterface ✔\nimport MLJDecisionTreeInterface ✔\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=5}\n```\nplot_comparsion (generic function with 1 method)\n```\n:::\n:::\n\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nfig=Figure(resolution=(2100,1000))\nfunction plot_comparsion(testdata,df,row=1)\n \n for i in eachindex(testdata)\n plot_desc_boudary(fig,testdata[i],i;df=df,row=row)\n end\n fig\nend\n\n\n\ndf1=circle_data()\n\nytest1=[_fit(df1,m) for (i,m) in enumerate(models)]\n\ndf2=moons_data()\nytest2=[_fit(df2,m) for (i,m) in enumerate(models)]\n\ndf3=blob_data()\nytest3=[_fit(df3,m) for (i,m) in enumerate(models)]\n\ndfs=[df2,df1,df3]\nytests=[ytest2,ytest1,ytest3]\n\nfig=Figure(resolution=(2100,1000))\n\nfor (df, data,i) in zip(dfs,ytests,[1,2,3])\n plot_comparsion(data,df;row=i)\nend\n\nfig\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: Training machine(KNNClassifier(K = 5, …), …).\n[ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).\n[ Info: Training machine(RandomForestClassifier(max_depth = -1, …), …).\n[ Info: Training machine(CatBoostClassifier(iterations = 1000, …), …).\n[ Info: Training machine(BayesianLDA(method = gevd, …), …).\n[ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n[ Info: Training machine(KNNClassifier(K = 5, …), …).\n[ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).\n[ Info: Training machine(RandomForestClassifier(max_depth = -1, …), …).\n[ Info: Training machine(CatBoostClassifier(iterations = 1000, …), …).\n[ Info: Training machine(BayesianLDA(method = gevd, …), …).\n[ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n[ Info: Training machine(KNNClassifier(K = 5, …), …).\n[ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).\n[ Info: Training machine(RandomForestClassifier(max_depth = -1, …), …).\n[ Info: Training machine(CatBoostClassifier(iterations = 1000, …), …).\n[ Info: Training machine(BayesianLDA(method = gevd, …), …).\n[ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=6}\n![](1-classfication-comparison_files/figure-html/cell-6-output-2.png){}\n:::\n:::\n\n\n", | ||
"supporting": [ | ||
"1-classfication-comparison_files" | ||
], | ||
"filters": [], | ||
"includes": {} | ||
} | ||
} |
Binary file added
BIN
+770 KB
_freeze/machinelearning/1-classfication-comparison/figure-html/cell-6-output-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.