Skip to content

Commit

Permalink
[SPARK-50843][ML][PYTHON][CONNECT][FOLLOW-UP] Optimize the RPC for `T…
Browse files Browse the repository at this point in the history
…reeEnsembleModel.trees`

### What changes were proposed in this pull request?
Optimize the RPC for `TreeEnsembleModel.trees`

### Why are the changes needed?
to send the trees together

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
existing tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49764 from zhengruifeng/ml_connect_trees.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 3, 2025
1 parent f840abb commit 8fc6a20
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
8 changes: 2 additions & 6 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,10 +2296,7 @@ def featureImportances(self) -> Vector:
def trees(self) -> List[DecisionTreeClassificationModel]:
"""Trees in this ensemble. Warning: These have null parent Estimators."""
if is_remote():
n = self.getNumTrees
return [
DecisionTreeClassificationModel(self._call_java("getTree", i)) for i in range(n)
]
return [DecisionTreeClassificationModel(m) for m in self._call_java("trees").split(",")]
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]

@property
Expand Down Expand Up @@ -2789,8 +2786,7 @@ def featureImportances(self) -> Vector:
def trees(self) -> List[DecisionTreeRegressionModel]:
"""Trees in this ensemble. Warning: These have null parent Estimators."""
if is_remote():
n = self.getNumTrees
return [DecisionTreeRegressionModel(self._call_java("getTree", i)) for i in range(n)]
return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

def evaluateEachIteration(self, dataset: DataFrame) -> List[float]:
Expand Down
6 changes: 2 additions & 4 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,8 +1608,7 @@ class RandomForestRegressionModel(
def trees(self) -> List[DecisionTreeRegressionModel]:
"""Trees in this ensemble. Warning: These have null parent Estimators."""
if is_remote():
n = self.getNumTrees
return [DecisionTreeRegressionModel(self._call_java("getTree", i)) for i in range(n)]
return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

@property
Expand Down Expand Up @@ -2000,8 +1999,7 @@ def featureImportances(self) -> Vector:
def trees(self) -> List[DecisionTreeRegressionModel]:
"""Trees in this ensemble. Warning: These have null parent Estimators."""
if is_remote():
n = self.getNumTrees
return [DecisionTreeRegressionModel(self._call_java("getTree", i)) for i in range(n)]
return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,17 @@ private[connect] object MLHandler extends Logging {
.newBuilder()
.setObjRef(proto.ObjectRef.newBuilder().setId(id)))
.build()
case a: Array[_] if a.nonEmpty && a.forall(_.isInstanceOf[Model[_]]) =>
val ids = a.map { m =>
mlCache.register(m.asInstanceOf[Model[_]])
}
proto.MlCommandResult
.newBuilder()
.setOperatorInfo(
proto.MlCommandResult.MlOperatorInfo
.newBuilder()
.setObjRef(proto.ObjectRef.newBuilder().setId(ids.mkString(","))))
.build()
case _ =>
val param = Serializer.serializeParam(attrResult)
proto.MlCommandResult.newBuilder().setParam(param).build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,12 +510,12 @@ private[ml] object MLUtils {
classOf[TreeEnsembleModel[_]],
Set(
"predictLeaf",
"trees",
"treeWeights",
"javaTreeWeights",
"getNumTrees",
"totalNumNodes",
"toDebugString",
"getTree")),
"toDebugString")),
(classOf[DecisionTreeClassificationModel], Set("featureImportances")),
(classOf[RandomForestClassificationModel], Set("featureImportances", "evaluate")),
(classOf[GBTClassificationModel], Set("featureImportances", "evaluateEachIteration")),
Expand Down

0 comments on commit 8fc6a20

Please sign in to comment.