From c7c861b9692b2732a4823a89775748aca309e596 Mon Sep 17 00:00:00 2001 From: Carsten Behring Date: Wed, 9 Oct 2024 13:24:51 +0000 Subject: [PATCH] more tests --- test/scicloj/ml/xgboost_test.clj | 62 +++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/test/scicloj/ml/xgboost_test.clj b/test/scicloj/ml/xgboost_test.clj index fc9ec94..6fed72a 100644 --- a/test/scicloj/ml/xgboost_test.clj +++ b/test/scicloj/ml/xgboost_test.clj @@ -14,7 +14,7 @@ [tech.v3.dataset.categorical :as ds-cat] [scicloj.metamorph.ml.gridsearch :as ml-gs])) - + (deftest basic (verify/basic-regression {:model-type :xgboost/regression} 0.22)) @@ -105,23 +105,23 @@ predictions (ml/predict test-ds model) loss (loss/classification-accuracy - (-> - predictions - ds-cat/reverse-map-categorical-xforms - (get "species")) - (-> - test-ds - ds-cat/reverse-map-categorical-xforms - (get "species")))] + (-> + predictions + ds-cat/reverse-map-categorical-xforms + (get "species")) + (-> + test-ds + ds-cat/reverse-map-categorical-xforms + (get "species")))] (is (= 0.9555555555555556 loss)) - - (is (= + + (is (= [{:importance-type "gain", :colname "petal_width", :gain 3.0993214419727266} {:importance-type "gain", :colname "petal_length", :gain 2.8288314797695904} {:importance-type "gain", :colname "sepal_width", :gain 0.272344306208} - {:importance-type "gain", :colname "sepal_length", :gain 0.12677490274290323}] - + {:importance-type "gain", :colname "sepal_length", :gain 0.12677490274290323}] + (ds/rows (ml/explain model)))))) @@ -142,7 +142,7 @@ (get "Survived")))] (assoc model :accuracy accuracy))) -(deftest titanic +(deftest titanic (let [titanic (-> (ds/->dataset "test/data/titanic.csv") (ds/drop-columns ["Name"]) (ds/update-column "Survived" (fn [col] @@ -171,8 +171,8 @@ ds-cat/reverse-map-categorical-xforms (get "Survived"))) - - + + opt-map (merge {:model-type :xgboost/classification} (ml/hyperparameters :xgboost/classification)) @@ -190,5 +190,33 @@ (-> models first :accuracy (* 100) Math/round))))) - +(deftest no-cat-not-working +; https://github.com/scicloj/scicloj.ml.xgboost/issues/1 + (let [ iris-no-cat-map + (-> + (ds/->dataset "test/data/iris.csv" {:key-fn keyword}) + (ds/categorical->number [:species] {} :float64) + (ds-mod/set-inference-target [:species]) + (ds/assoc-metadata [:species] :categorical-map nil)) + + model + (ml/train iris-no-cat-map {:model-type :xgboost/classification + :num-class 3})] + (is ( = [] + (:species (ml/predict iris-no-cat-map model)))))) + + +(deftest no-cat-craching +; https://github.com/scicloj/scicloj.ml.xgboost/issues/1 + (let [iris-no-cat-map + (-> + (ds/->dataset "test/data/iris.csv" {:key-fn keyword}) + (ds/categorical->number [:species] {} :float64) + (ds-mod/set-inference-target [:species]) + (ds/assoc-metadata [:species] :categorical-map nil))] + + (is ( thrown? Exception + (ml/train iris-no-cat-map {:model-type :xgboost/classification + }))) + ))