From b8eb6180b4c13d4dad1d7674b22b003d7d8b2f29 Mon Sep 17 00:00:00 2001 From: Pedro Sotorrio Date: Tue, 15 Aug 2023 12:45:45 -0700 Subject: [PATCH] update rank function to use pure python library and move away from mlrose-hiive --- foqus_lib/framework/sdoe/order.py | 26 +++++---------------- foqus_lib/framework/sdoe/test/test_order.py | 2 +- setup.py | 2 +- 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/foqus_lib/framework/sdoe/order.py b/foqus_lib/framework/sdoe/order.py index d16987388..5bc85c759 100644 --- a/foqus_lib/framework/sdoe/order.py +++ b/foqus_lib/framework/sdoe/order.py @@ -13,16 +13,13 @@ # "https://github.com/CCSI-Toolset/FOQUS". ################################################################################# """ -Candidate ordering by TSP Optimization - -Code adopted from: -https://mlrose.readthedocs.io/en/stable/source/tutorial2.html - +Candidate ordering by TSP Optimization using python-tsp +https://pypi.org/project/python-tsp/ """ import logging import os import numpy as np -import mlrose_hiive as mlrose +from python_tsp.exact import solve_tsp_dynamic_programming from .df_utils import load, write _log = logging.getLogger("foqus." + __name__) @@ -40,26 +37,15 @@ def mat2tuples(mat): return lte -def rank(fnames, ga_max_attempts=25): +def rank(fnames): """return fnames ranked""" dist_mat = np.load(fnames["dmat"]) - dist_list = mat2tuples(dist_mat) - - # define fitness function object - fitness_dists = mlrose.TravellingSales(distances=dist_list) - - # define optimization problem object - n_len = dist_mat.shape[0] - problem_fit = mlrose.TSPOpt(length=n_len, fitness_fn=fitness_dists, maximize=False) - # solve problem using the genetic algorithm - best_state = mlrose.genetic_alg( - problem_fit, mutation_prob=0.2, max_attempts=ga_max_attempts, random_state=2 - )[0] + permutation, _distance = solve_tsp_dynamic_programming(dist_mat) # retrieve ranked list cand = load(fnames["cand"]) - ranked_cand = cand.loc[best_state] + ranked_cand = cand.loc[reversed(permutation)] # save the output fname, ext = os.path.splitext(fnames["cand"]) diff --git a/foqus_lib/framework/sdoe/test/test_order.py b/foqus_lib/framework/sdoe/test/test_order.py index 9c0d14806..f052cf8c8 100644 --- a/foqus_lib/framework/sdoe/test/test_order.py +++ b/foqus_lib/framework/sdoe/test/test_order.py @@ -142,7 +142,7 @@ def test_rank(): fnames = {"cand": str(cand_fn), "dmat": str(dmat_fn)} # Make the actual call - fname_ranked = order.rank(fnames, ga_max_attempts=5) + fname_ranked = order.rank(fnames) # Ranked results as a dataframe ret_ranked_df = df_utils.load(fname_ranked) diff --git a/setup.py b/setup.py index 16d4b4886..153e20080 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ "boto3", "cma", "matplotlib<3.6", - "mlrose_hiive==2.1.3", + "python-tsp", "joblib<1.3", # CCSI-Toolset/FOQUS#1154 "mplcursors", "numpy",