Skip to content

Commit

Permalink
added code for model-info
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Nov 26, 2024
1 parent d245740 commit 004de65
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ All notable changes to this project will be documented in this file. This change

## unreleased
- support "all types" for classification target column
- added partial support for register individual models

## 0.1.3
- use tribuo 4.3.1
Expand Down
6 changes: 3 additions & 3 deletions deps.edn
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{:paths ["src" "resources"]
:deps {org.clojure/clojure {:mvn/version "1.12.0"}
org.scicloj/metamorph.ml {:mvn/version "0.9.0"}
cheshire/cheshire {:mvn/version "5.12.0"}
techascent/tech.ml.dataset {:mvn/version "7.029"}
org.scicloj/metamorph.ml {:mvn/version "0.10.4"}
cheshire/cheshire {:mvn/version "5.13.0"}
techascent/tech.ml.dataset {:mvn/version "7.034"}
;; tribuo core deps
org.tribuo/tribuo-classification-core {:mvn/version "4.3.1"}
org.tribuo/tribuo-regression-core {:mvn/version "4.3.1"}
Expand Down
146 changes: 57 additions & 89 deletions src/scicloj/ml/tribuo.clj
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
(ns scicloj.ml.tribuo
(:require
[fastmath.stats :as stats]
[scicloj.metamorph.ml :as ml]
[tech.v3.dataset :as ds]
[tech.v3.dataset.modelling :as ds-mod]
[tech.v3.datatype :as dt]
[tech.v3.datatype.functional :as fun]
[fastmath.stats :as stats]
[tech.v3.datatype.errors :as errors]
[tech.v3.libs.tribuo :as tribuo]
[tablecloth.api :as tc])
(:import [org.tribuo.regression.evaluation RegressionEvaluator]
[org.tribuo.regression Regressor]))
[tech.v3.datatype.functional :as fun]
[tech.v3.libs.tribuo :as tribuo])
(:import
[org.tribuo.regression Regressor]
[org.tribuo.regression.evaluation RegressionEvaluator]))

(defn- make-trainer [options]
(tribuo/trainer (:tribuo-components options)
Expand All @@ -27,10 +27,10 @@
(dt/cast revere-mapped-elemn target-datatype)))

(defn- cast-target [prediction-ds column target-datatype target-categorical-maps]
(println :prediction--meta
(-> prediction-ds (tc/head) (get column) meta)
:data
(-> prediction-ds (tc/head) (get column) seq))
;; (println :prediction--meta
;; (-> prediction-ds (tc/head) (get column) meta)
;; :data
;; (-> prediction-ds (tc/head) (get column) seq))
(->
(ds/update-column
prediction-ds
Expand All @@ -55,42 +55,35 @@




(ml/define-model! :scicloj.ml.tribuo/classification

(fn [feature-ds target-ds options]

(defn- train-classification [feature-ds target-ds options]
;; (println :target-ts--meta
;; (map meta (-> target-ds vals))
;; :data
;; (map #(take 5 %) (vals target-ds)))
(let [target-data-type (-> target-ds ds/columns first meta :datatype)]

(when (= :object target-data-type)
(errors/throwf ":object target column not supported"))

{:model-instance
(tribuo/train-classification (make-trainer options)
(ds/append-columns feature-ds (ds/columns target-ds)))}))

(let [target-data-type (-> target-ds ds/columns first meta :datatype)]

(when (= :object target-data-type)
(errors/throwf ":object target column not supported"))
(defn predict-classification [feature-ds thawed-model {:keys [model-data
target-columns
target-categorical-maps
target-datatypes] :as model}]

{:model-instance
(tribuo/train-classification (make-trainer options)
(ds/append-columns feature-ds (ds/columns target-ds)))}))
(let [model-instance (:model-instance model-data)
target-column-name (first target-columns)
prediction
(->
(tribuo/predict-classification model-instance
feature-ds)
(post-process-prediction-classification target-column-name target-datatypes target-categorical-maps))]
(ds/assoc-metadata prediction [target-column-name] :categorical-map (get target-categorical-maps target-column-name))))


(fn [feature-ds thawed-model {:keys [model-data
target-columns
target-categorical-maps
target-datatypes] :as model}]

(let [model-instance (:model-instance model-data)
target-column-name (first target-columns)
prediction
(->
(tribuo/predict-classification model-instance
feature-ds)
(post-process-prediction-classification target-column-name target-datatypes target-categorical-maps))]
(ds/assoc-metadata prediction [target-column-name] :categorical-map (get target-categorical-maps target-column-name))))
{})


(defn evaluate [model]
Expand Down Expand Up @@ -162,62 +155,37 @@
(ds/add-column (ds/new-column :.resid residuos)))))


(defn- train-regression [feature-ds target-ds options]
{:target-ds target-ds
:feature-ds feature-ds
:model
(tribuo/train-regression (make-trainer options) (ds/append-columns feature-ds (ds/columns target-ds)))})

(defn- predict-regression [feature-ds thawed-model {:keys [model-data target-columns]}]
(let [model (:model model-data)]
(->
(tribuo/predict-regression model feature-ds)
(post-process-prediction-regression (first target-columns)))))

(ml/define-model! :scicloj.ml.tribuo/regression
(fn [feature-ds target-ds options]
{:target-ds target-ds
:feature-ds feature-ds
:model
(tribuo/train-regression (make-trainer options) (ds/append-columns feature-ds (ds/columns target-ds)))})

(fn [feature-ds thawed-model {:keys [model-data target-columns]}]
(let [model (:model model-data)]
(->
(tribuo/predict-regression model feature-ds)
(post-process-prediction-regression (first target-columns)))))
train-regression
predict-regression
{:glance-fn glance-fn-regression
:augment-fn augment-fn-regression})


(comment

(import '[com.oracle.labs.mlrg.olcut.config DescribeConfigurable]
'[com.oracle.labs.mlrg.olcut.config Configurable]
'[org.tribuo.regression.sgd.linear LinearSGDTrainer])


(defn configurable->docu [class]
(->>
(DescribeConfigurable/generateFieldInfo class)
vals
(map (fn [field-info]
(def field-info field-info)
(map :name
(:members (clojure.reflect/reflect field-info)))
{:name (.name field-info)
:description (.description field-info)
:type (.getGenericType (.field field-info))
:default (.defaultVal field-info)}))))

(def tribuo-trainers
["org.tribuo.regression.liblinear.LibLinearRegressionTrainer"
"org.tribuo.regression.liblinear.LinearRegressionType"
"org.tribuo.classification.ensemble.AdaBoostTrainer"
"org.tribuo.classification.dtree.CARTClassificationTrainer"])


(defn safe-class-for-name [s]
(try
(Class/forName s)
(catch Exception e nil)))


(def trainer-classes
(->>
(map safe-class-for-name tribuo-trainers)
(remove nil?)))

(map
configurable->docu
trainer-classes)
)
(ml/define-model! :scicloj.ml.tribuo/classification
train-classification
predict-classification
{})

(ml/define-model! :scicloj.ml.tribuo/classification
train-classification
predict-classification
{})


;(model-info/register-models train-classification predict-classification train-regression predict-regression)



93 changes: 93 additions & 0 deletions src/scicloj/ml/tribuo/model_info.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
(ns scicloj.ml.tribuo.model-info
(:require
[clojure.java.classpath]
[clojure.reflect]
[clojure.string :as str]
[scicloj.metamorph.ml :as ml])
(:import
[com.oracle.labs.mlrg.olcut.config DescribeConfigurable])
)

(defn all-configurables[interface]

(->> (clojure.java.classpath/classpath-jarfiles)
(filter (fn [^java.util.jar.JarFile jf]
(re-matches #".*tribuo.*" (.getName jf))))
(mapcat clojure.java.classpath/filenames-in-jar)
(map (fn [class-filename]
(try (some-> class-filename
(str/replace #"/" ".")
(str/replace #"\.class$" "")
(Class/forName))
(catch Exception _ nil))))
(filter (fn [cls]
(->> cls
supers
(some #(= % interface
;org.tribuo.Trainer
;com.oracle.labs.mlrg.olcut.config.Configurable

)))))))


(defn- configurable->docu [class]
(->>
(DescribeConfigurable/generateFieldInfo class)
vals
(map (fn [field-info]
{:name (.name field-info)
:description (.description field-info)
:type (.getGenericType (.field field-info))
:default (.defaultVal field-info)}))))


(defn- safe-configurable->docu [class]
{:class class
:options
(try
(configurable->docu class)
(catch Exception _ nil))})

(defn- trainer-infos []
(->> (all-configurables org.tribuo.Trainer)
(map safe-configurable->docu)
(remove #(empty? (:options %)))))


(defn- class->tribuo-url [class]
(if (nil? class)
""
(str "https://tribuo.org/learn/4.3/javadoc/"
(str/replace (.getName class)
"." "/")
".html")))

(defn- train-wrapper [trainer-class train-fn]
(fn [feature-ds target-ds options]
(train-fn feature-ds target-ds
(assoc options :trainer-class-name trainer-class))))

(defn register-models [train-classification predict-classification train-regression predict-regression ]
(run!
(fn [trainer-info]
(let [opts {:options (:options trainer-info)
:documentation {:javadoc (class->tribuo-url (:class trainer-info))}}
fq-class-name (.getName (:class trainer-info))

name-pieces (str/split fq-class-name #"\.")
type (nth name-pieces 2)
trainer-name (last name-pieces) ;(-> name-pieces last (str/replace "Trainer" "") csk/->kebab-case)

[train-fn predict-fn]
(cond
(str/starts-with? fq-class-name "org.tribuo.classification")
[train-classification predict-classification]
(str/starts-with? fq-class-name "org.tribuo.regression")
[train-regression predict-regression]
:else nil)]
(when (some? train-fn)
(ml/define-model! (keyword "scicloj.ml.tribuo" (format "%s.%s" type trainer-name))
(train-wrapper fq-class-name train-fn)
predict-fn
opts))))
(trainer-infos)))
47 changes: 36 additions & 11 deletions test/scicloj/ml/linear_regression_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@
(toydata/diabetes-ds))




(t/deftest tidy
(let [tribuo-linear-sdg
(ml/train diabetes
{:model-type :scicloj.ml.tribuo/regression
:tribuo-components [{:name "squared"
:type "org.tribuo.regression.sgd.objectives.SquaredLoss"}
{:name "trainer"
:type "org.tribuo.regression.sgd.linear.LinearSGDTrainer"
:properties {
:epochs "100"
:minibatchSize "1"
:objective "squared"}}]
:tribuo-trainer-name "trainer"})]
(ml/train
diabetes
{:model-type :scicloj.ml.tribuo/regression
:tribuo-components [{:name "squared"
:type "org.tribuo.regression.sgd.objectives.SquaredLoss"}
{:name "trainer"
:type "org.tribuo.regression.sgd.linear.LinearSGDTrainer"
:properties {:epochs "100"
:minibatchSize "1"
:objective "squared"}}]
:tribuo-trainer-name "trainer"})]


(is (= [{:disease-progression 163.65426518599335} {:disease-progression 113.19672792128675} {:disease-progression 156.74746391022254} {:disease-progression 151.2497638618952} {:disease-progression 139.5246990528341}]
Expand Down Expand Up @@ -53,3 +55,26 @@
(ml/augment diabetes)
(ds/head 10)
(ds/rows))))))



(comment

(ml/train
diabetes
{:model-type :scicloj.ml.tribuo/regression.linear-sgd
;; trainer options
:epochs 100
:minibatchSize 1})


(ml/train
diabetes
{:model-type :scicloj.ml.tribuo/regression.linear-sgd
;; train options
:epochs 100
:minibatchSize 1
:objective "squared"
;; other components and maybe options
:tribuo-components [{:name "squared"
:type "org.tribuo.regression.sgd.objectives.SquaredLoss"}]}))

0 comments on commit 004de65

Please sign in to comment.