From 898f92cc7c70a192b1cf199c28bc30028025f163 Mon Sep 17 00:00:00 2001 From: Carsten Behring Date: Mon, 14 Oct 2024 21:19:44 +0000 Subject: [PATCH] added handling of tid-text sparse columns --- CHANGELOG.md | 4 ++++ src/scicloj/ml/xgboost.clj | 38 +++++++++++++++++++++++--------- test/scicloj/ml/text_test.clj | 15 +++++++++---- test/scicloj/ml/xgboost_test.clj | 29 ++++++++++++------------ 4 files changed, 57 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4966b3d..f383076 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # ConstantChangeLog +## unreleased +- fixed issue #1 + + ## 6.1.0 Upgrade to xgboost4j_2.12 2.1.1 diff --git a/src/scicloj/ml/xgboost.clj b/src/scicloj/ml/xgboost.clj index dfd0011..d78d5b4 100644 --- a/src/scicloj/ml/xgboost.clj +++ b/src/scicloj/ml/xgboost.clj @@ -16,7 +16,8 @@ [tech.v3.datatype :as dtype] [tech.v3.datatype.errors :as errors] [tech.v3.tensor :as dtt] - [scicloj.ml.xgboost.csr :as csr]) + [scicloj.ml.xgboost.csr :as csr] + [scicloj.metamorph.ml.text :as text]) (:import [java.io ByteArrayInputStream ByteArrayOutputStream] [java.util LinkedHashMap Map] [ml.dmlc.xgboost4j LabeledPoint] @@ -198,8 +199,18 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that nil)) -(defn tidy-text-bow-ds->dmatrix [bow] - (let [zero-baseddocs-map +(defn tidy-text-bow-ds->dmatrix [feature-ds target-ds] + (def feature-ds feature-ds) + (def target-ds target-ds) + + ;(-> feature-ds :word .data .data) + ;(:label target-ds) + + (let [ds (if (some? target-ds) + (assoc feature-ds :label (:label target-ds)) + feature-ds) + bow (text/add-word-idx ds) + zero-baseddocs-map (zipmap (-> bow :document distinct) (range)) @@ -230,7 +241,9 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that (float-array (:values csr)) DMatrix$SparseType/CSR n-col)] - (.setLabel m (float-array labels)) + (def labels labels) + (when target-ds + (.setLabel m (float-array labels))) m)) @@ -287,11 +300,18 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that (defn ->dmatrix [feature-ds target-ds sparse-column n-sparse-columns] (if sparse-column - (sparse-feature->dmatrix feature-ds target-ds sparse-column n-sparse-columns) + (if (= (-> feature-ds (get sparse-column) first class) + SparseArray) + (sparse-feature->dmatrix feature-ds target-ds sparse-column n-sparse-columns) + (tidy-text-bow-ds->dmatrix feature-ds target-ds) + + ) + (dataset->dmatrix feature-ds target-ds))) + (defn- thaw-model [model-data] (-> (if (map? model-data) @@ -393,12 +413,12 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that (ds-mod/inference-target-label-map label-ds))] (train-from-dmatrix train-dmat feature-cnames target-cnames options label-map objective))) - (defn- predict [feature-ds thawed-model {:keys [target-columns target-categorical-maps options]}] (let [sparse-column-or-nil (:sparse-column options) dmatrix (->dmatrix feature-ds nil sparse-column-or-nil (:n-sparse-columns options)) prediction (.predict ^Booster thawed-model dmatrix) + predict-tensor (->> prediction (dtt/->tensor)) target-cname (first target-columns)] @@ -407,17 +427,15 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that (if (multiclass-objective? (options->objective options)) (-> (model/finalize-classification predict-tensor - (ds/row-count feature-ds) target-cname target-categorical-maps) - (tech.v3.dataset.modelling/probability-distributions->label-column - (first target-columns)) + (tech.v3.dataset.modelling/probability-distributions->label-column target-cname) (ds/update-column (first target-columns) #(vary-meta % assoc :column-type :prediction))) (model/finalize-regression predict-tensor target-cname)))) - + (defn- explain diff --git a/test/scicloj/ml/text_test.clj b/test/scicloj/ml/text_test.clj index d202466..d32c197 100644 --- a/test/scicloj/ml/text_test.clj +++ b/test/scicloj/ml/text_test.clj @@ -9,7 +9,9 @@ [scicloj.ml.xgboost.csr :as csr] [tablecloth.api :as tc] [tablecloth.column.api :as tcc] - [scicloj.metamorph.ml :as ml]) + [scicloj.metamorph.ml :as ml] + [tech.v3.dataset.column-filters :as cf] + [tech.v3.dataset :as ds]) (:import [java.util.zip GZIPInputStream] [ml.dmlc.xgboost4j.java XGBoost] [ml.dmlc.xgboost4j.java DMatrix DMatrix$SparseType])) @@ -26,7 +28,7 @@ [(first splitted) (dec (Integer/parseInt (second splitted)))])) #(str/split % #" ") - :max-lines 10000 + :max-lines 1000 :skip-lines 1) (tc/rename-columns {:meta :label}) (tc/drop-rows #(= "" (:word %))) @@ -48,8 +50,10 @@ text/->term-frequency text/add-word-idx) - m-train (xgboost/tidy-text-bow-ds->dmatrix bow-train) - m-test (xgboost/tidy-text-bow-ds->dmatrix bow-test) + m-train (xgboost/tidy-text-bow-ds->dmatrix (cf/feature bow-train) + (tc/select-columns bow-train [:label]) ) + m-test (xgboost/tidy-text-bow-ds->dmatrix (cf/feature bow-test) + (tc/select-columns bow-test [:label])) model (xgboost/train-from-dmatrix @@ -84,6 +88,9 @@ (float-array predition-test) (.getLabel m-test))] + (println :train-accuracy train-accuracy) + (println :test-accuracy test-accuracy) + (is (< 0.95 train-accuracy)) (is (< 0.54 test-accuracy)))) diff --git a/test/scicloj/ml/xgboost_test.clj b/test/scicloj/ml/xgboost_test.clj index 6bd37fc..82c4b0f 100644 --- a/test/scicloj/ml/xgboost_test.clj +++ b/test/scicloj/ml/xgboost_test.clj @@ -1,24 +1,24 @@ (ns scicloj.ml.xgboost-test - (:require [clojure.test :refer [deftest is]] - [fastmath.protocols :as protocols] - - [fastmath.vector :as vec] + (:require [clojure.data.csv :as csv] + [clojure.java.io :as io] + [clojure.string :as str] + [clojure.test :refer [deftest is]] [scicloj.metamorph.ml :as ml] [scicloj.metamorph.ml.gridsearch :as ml-gs] [scicloj.metamorph.ml.loss :as loss] + [scicloj.metamorph.ml.text :as text] [scicloj.metamorph.ml.verify :as verify] [scicloj.ml.smile.discrete-nb :as nb] [scicloj.ml.smile.nlp :as nlp] [scicloj.ml.xgboost] [tablecloth.api :as tc] - [tablecloth.column.api :as tcc] [tech.v3.dataset :as ds] [tech.v3.dataset.categorical :as ds-cat] [tech.v3.dataset.column-filters :as cf] [tech.v3.dataset.modelling :as ds-mod] - [tech.v3.datatype :as dtype] - [tech.v3.datatype.functional :as dfn] - [tech.v3.datatype :as dt])) + [tech.v3.datatype :as dtype] + [tech.v3.datatype.functional :as dfn]) + (:import [java.util.zip GZIPInputStream])) (deftest basic @@ -77,7 +77,7 @@ :sparse-column :bow-sparse :n-sparse-columns 100}) - + explanation (ml/explain model) test-ds (ds/head reviews 100) prediction (ml/predict test-ds model) @@ -88,13 +88,12 @@ :Score) (-> test-ds (ds-cat/reverse-map-categorical-xforms) - :Score)) - ] - (is ( > train-acc 0.97)))) + :Score))] + (is (> train-acc 0.97)))) -(deftest iris - (let [ src-ds (ds/->dataset "test/data/iris.csv") +(deftest iris + (let [src-ds (ds/->dataset "test/data/iris.csv") ds (-> src-ds (ds/categorical->number cf/categorical) (ds-mod/set-inference-target "species")) @@ -104,7 +103,7 @@ test-ds (:test-ds split-data) model (ml/train train-ds {:validate-parameters "true" :seed 123 - :verbosity 1 + :verbosity 0 :model-type :xgboost/classification}) predictions (ml/predict test-ds model) loss