Skip to content

Commit

Permalink
update rank function to use pure python library and move away from ml…
Browse files Browse the repository at this point in the history
…rose-hiive
  • Loading branch information
sotorrio1 committed Aug 15, 2023
1 parent cb85b8f commit b8eb618
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 22 deletions.
26 changes: 6 additions & 20 deletions foqus_lib/framework/sdoe/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@
# "https://github.com/CCSI-Toolset/FOQUS".
#################################################################################
"""
Candidate ordering by TSP Optimization
Code adopted from:
https://mlrose.readthedocs.io/en/stable/source/tutorial2.html
Candidate ordering by TSP Optimization using python-tsp
https://pypi.org/project/python-tsp/
"""
import logging
import os
import numpy as np
import mlrose_hiive as mlrose
from python_tsp.exact import solve_tsp_dynamic_programming
from .df_utils import load, write

_log = logging.getLogger("foqus." + __name__)
Expand All @@ -40,26 +37,15 @@ def mat2tuples(mat):
return lte


def rank(fnames, ga_max_attempts=25):
def rank(fnames):
"""return fnames ranked"""
dist_mat = np.load(fnames["dmat"])
dist_list = mat2tuples(dist_mat)

# define fitness function object
fitness_dists = mlrose.TravellingSales(distances=dist_list)

# define optimization problem object
n_len = dist_mat.shape[0]
problem_fit = mlrose.TSPOpt(length=n_len, fitness_fn=fitness_dists, maximize=False)

# solve problem using the genetic algorithm
best_state = mlrose.genetic_alg(
problem_fit, mutation_prob=0.2, max_attempts=ga_max_attempts, random_state=2
)[0]
permutation, _distance = solve_tsp_dynamic_programming(dist_mat)

# retrieve ranked list
cand = load(fnames["cand"])
ranked_cand = cand.loc[best_state]
ranked_cand = cand.loc[reversed(permutation)]

# save the output
fname, ext = os.path.splitext(fnames["cand"])
Expand Down
2 changes: 1 addition & 1 deletion foqus_lib/framework/sdoe/test/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_rank():
fnames = {"cand": str(cand_fn), "dmat": str(dmat_fn)}

# Make the actual call
fname_ranked = order.rank(fnames, ga_max_attempts=5)
fname_ranked = order.rank(fnames)

# Ranked results as a dataframe
ret_ranked_df = df_utils.load(fname_ranked)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
"boto3",
"cma",
"matplotlib<3.6",
"mlrose_hiive==2.1.3",
"python-tsp",
"joblib<1.3", # CCSI-Toolset/FOQUS#1154
"mplcursors",
"numpy",
Expand Down

0 comments on commit b8eb618

Please sign in to comment.