Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Oct 9, 2024
1 parent 9344ab9 commit c7c861b
Showing 1 changed file with 45 additions and 17 deletions.
62 changes: 45 additions & 17 deletions test/scicloj/ml/xgboost_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
[tech.v3.dataset.categorical :as ds-cat]
[scicloj.metamorph.ml.gridsearch :as ml-gs]))



(deftest basic
(verify/basic-regression {:model-type :xgboost/regression} 0.22))
Expand Down Expand Up @@ -105,23 +105,23 @@
predictions (ml/predict test-ds model)
loss
(loss/classification-accuracy
(->
predictions
ds-cat/reverse-map-categorical-xforms
(get "species"))
(->
test-ds
ds-cat/reverse-map-categorical-xforms
(get "species")))]
(->
predictions
ds-cat/reverse-map-categorical-xforms
(get "species"))
(->
test-ds
ds-cat/reverse-map-categorical-xforms
(get "species")))]

(is (= 0.9555555555555556 loss))
(is (=

(is (=
[{:importance-type "gain", :colname "petal_width", :gain 3.0993214419727266}
{:importance-type "gain", :colname "petal_length", :gain 2.8288314797695904}
{:importance-type "gain", :colname "sepal_width", :gain 0.272344306208}
{:importance-type "gain", :colname "sepal_length", :gain 0.12677490274290323}]
{:importance-type "gain", :colname "sepal_length", :gain 0.12677490274290323}]

(ds/rows
(ml/explain model))))))

Expand All @@ -142,7 +142,7 @@
(get "Survived")))]
(assoc model :accuracy accuracy)))

(deftest titanic
(deftest titanic
(let [titanic (-> (ds/->dataset "test/data/titanic.csv")
(ds/drop-columns ["Name"])
(ds/update-column "Survived" (fn [col]
Expand Down Expand Up @@ -171,8 +171,8 @@
ds-cat/reverse-map-categorical-xforms
(get "Survived")))




opt-map (merge {:model-type :xgboost/classification}
(ml/hyperparameters :xgboost/classification))
Expand All @@ -190,5 +190,33 @@
(-> models first :accuracy (* 100) Math/round)))))



(deftest no-cat-not-working
; https://github.com/scicloj/scicloj.ml.xgboost/issues/1
(let [ iris-no-cat-map
(->
(ds/->dataset "test/data/iris.csv" {:key-fn keyword})
(ds/categorical->number [:species] {} :float64)
(ds-mod/set-inference-target [:species])
(ds/assoc-metadata [:species] :categorical-map nil))

model
(ml/train iris-no-cat-map {:model-type :xgboost/classification
:num-class 3})]
(is ( = []
(:species (ml/predict iris-no-cat-map model))))))


(deftest no-cat-craching
; https://github.com/scicloj/scicloj.ml.xgboost/issues/1
(let [iris-no-cat-map
(->
(ds/->dataset "test/data/iris.csv" {:key-fn keyword})
(ds/categorical->number [:species] {} :float64)
(ds-mod/set-inference-target [:species])
(ds/assoc-metadata [:species] :categorical-map nil))]

(is ( thrown? Exception
(ml/train iris-no-cat-map {:model-type :xgboost/classification
})))
))

0 comments on commit c7c861b

Please sign in to comment.