diff --git a/include/ensmallen.hpp b/include/ensmallen.hpp index a6ca139dd..8963d15c0 100644 --- a/include/ensmallen.hpp +++ b/include/ensmallen.hpp @@ -120,6 +120,8 @@ #include "ensmallen_bits/agemoea/agemoea.hpp" #include "ensmallen_bits/moead/moead.hpp" #include "ensmallen_bits/nsga2/nsga2.hpp" +#include "ensmallen_bits/nsga3/normalization.hpp" +#include "ensmallen_bits/nsga3/nsga3.hpp" #include "ensmallen_bits/padam/padam.hpp" #include "ensmallen_bits/parallel_sgd/parallel_sgd.hpp" #include "ensmallen_bits/pso/pso.hpp" diff --git a/include/ensmallen_bits/nsga3/normalization.hpp b/include/ensmallen_bits/nsga3/normalization.hpp new file mode 100644 index 000000000..01f0bd436 --- /dev/null +++ b/include/ensmallen_bits/nsga3/normalization.hpp @@ -0,0 +1,235 @@ +/** + * @file normalization.hpp + * @author Satyam Shukla + * + * The Optimized normalization technique as described in Investigating the + * normalization procedure of nsga-iii. + * + * ensmallen is free software; you may redistribute it and/or modify it under + * the terms of the 3-clause BSD license. You should have received a copy of + * the 3-clause BSD license along with ensmallen. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef ENSMALLEN_NSGA3_NORMALIZATION_HPP +#define ENSMALLEN_NSGA3_NORMALIZATION_HPP + +namespace ens { + +/** + * + * This normalization technique is an improved version of the algorithm mentioned + * in the NSGA III paper. + * + * This class solves the problem of negative intercepts and non unique hyperplane + * by using worst point estimation as a backup value when ther are any abnormal values + * in the nadir point estimation. + * + * For more information, see the following: + * + * @code + * @inproceedings{blank2019investigating, + * title={Investigating the normalization procedure of NSGA-III}, + * author={Blank, Julian and Deb, Kalyanmoy and Roy, Proteek Chandan}, + * booktitle={International Conference on Evolutionary Multi-Criterion Optimization}, + * pages={229--240}, + * year={2019}, + * organization={Springer} + * } + * @endcode + */ +template +class Normalization +{ + public: + + typedef typename MatType::elem_type ElemType; + + /** + * @param dimension The no. of elements in a single point. + */ + Normalization(size_t dimensions = 3): + dimensions(dimensions), + idealPoint(arma::Col(dimensions, + arma::fill::value(arma::datum::inf))), + worstPoint(arma::Col(dimensions, + arma::fill::value(-1 * arma::datum::inf))) + {/* Nothing to do here */} + + /** + * @param calculatedObjectives The given population. + * @param indexes The index of the lements of the front + */ + void update(const std::vector>& calculatedObjectives, + const std::vector indexes) + { + // Calculate the front worst, the population worst and the ideal point. + arma::Col worstOfFront(dimensions, + arma::fill::value(-1 * arma::datum::inf)); + arma::Col worstOfPop(dimensions, + arma::fill::value(-1 * arma::datum::inf)); + + for (const arma::Col member: calculatedObjectives) + { + idealPoint = arma::min(idealPoint, member); + worstPoint = arma::max(worstPoint, member); + worstOfPop = arma::max(worstOfPop, member); + } + + for (size_t index : indexes) + { + worstOfFront = arma::max(worstOfFront, calculatedObjectives[index]); + } + + arma::Mat f(dimensions, dimensions); + + // Find the extremes. + FindExtremes(calculatedObjectives, indexes, f); + vectorizedExtremes = f; + + // Update the nadir point. + GetNadirPoint(calculatedObjectives, indexes, worstOfPop, worstOfFront); + + } + + /** + * @param calculatedObjectives The given population. + * @param indexes The index of the lements of the front + * @param f The matrix to store the vector extremes. + * @param useCurrentExtremes If the previously calculaeted extremes should + * be considered. + */ + void FindExtremes(const std::vector>& calculatedObjectives, + const std::vector& indexes, + arma::Mat& f, + bool useCurrentExtremes = true) + { + arma::Mat W(dimensions, dimensions, arma::fill::eye); + W = W + (W == 0) * 1e6; + arma::Mat vectorizedObjectives(dimensions, indexes.size()); + + for (size_t i = 0; i < indexes.size(); i++) + { + vectorizedObjectives.col(i) = calculatedObjectives[indexes[i]]; + } + + if (useCurrentExtremes) + { + vectorizedObjectives = arma::join_rows(vectorizedObjectives, + vectorizedExtremes); + } + vectorizedObjectives.each_col() -= idealPoint; + vectorizedObjectives = vectorizedObjectives + (vectorizedObjectives < 1e-3) * 0.; + + //Calculate ASF score and get the extreme vectors. + arma::Mat asfScore(vectorizedObjectives.n_cols, dimensions); + for(size_t i = 0; i < vectorizedObjectives.n_cols; i++) + { + for(size_t j = 0; j < dimensions; j++) + { + asfScore(i, j) = arma::max(W.col(j) % vectorizedObjectives.col(i)); + } + } + arma::urowvec extremes = arma::index_min(asfScore); + for(size_t i = 0; i < dimensions; i++) + { + if(extremes(i) >= indexes.size()) + { + f.col(i) = vectorizedExtremes.col(extremes(i) - indexes.size()); + } + else + { + f.col(i) = calculatedObjectives[indexes[extremes(i)]]; + } + } + } + + /** + * @param calculatedObjectives The given population. + * @param indexes The index of the lements of the front + * @param worstOfFront Worst point of the given front. + * @param worstOfPop Worst point of the population. + */ + void GetNadirPoint(const std::vector>& calculatedObjectives, + const std::vector& indexes, + const arma::Col worstOfFront, + const arma::Col worstOfPop) + { + try + { + arma::Mat M = vectorizedExtremes; + M.each_col() -= idealPoint; + arma::Col b(dimensions, arma::fill::ones); + arma::Col hyperplane = arma::solve(M.t(), b); + + if (hyperplane.has_inf() || hyperplane.has_nan() || (arma::accu(hyperplane < 0.0) > 0)) + { + throw 1024; + } + + arma::Col intercepts = 1.0 / hyperplane; + + if(arma::accu(arma::abs((M.t() * hyperplane) - b) > 1e-8) || + arma::accu(intercepts < 1e-6) > 0 || intercepts.has_inf() || + intercepts.has_nan()) + { + throw 1025; + } + + nadirPoint = idealPoint + intercepts; + nadirPoint = nadirPoint % (nadirPoint <= worstPoint); + nadirPoint += (nadirPoint == 0) % worstPoint; + } + catch(...) + { + nadirPoint = worstOfFront; + } + nadirPoint = ((nadirPoint - idealPoint) >= 1e-6) % nadirPoint + + ((nadirPoint - idealPoint) < 1e-6) % worstOfPop; + } + + //! Retrieve value of dimensions. + size_t Dimensions() const { return dimensions; } + //! Modify value of dimensions. + size_t& Dimensions() {return dimensions;} + + //! Retrieve value of ideal point. + arma::Col IdealPoint() const {return idealPoint; } + //! Modify value of ideal point. + arma::Col& IdealPoint() {return idealPoint; } + + //! Retrieve value of worst point. + arma::Col WorstPoint() const {return worstPoint; } + //! Modify value of worst point. + arma::Col& WorstPoint() {return worstPoint; } + + //! Retrieve value of nadir point. + arma::Col NadirPoint() const {return nadirPoint; } + //! Modify value of nadir point. + arma::Col& NadirPoint() {return nadirPoint; } + + //! Retrieve value of extreme points. + arma::Mat VectorizedExtremes() const {return vectorizedExtremes; } + //! Modify value of extreme points. + arma::Mat& VectorizedExtremes() {return vectorizedExtremes; } + + private: + + size_t dimensions; + + //The ideal point of the current population . + //! includes previous ideal point in calculations as well. + arma::Col idealPoint; + + // The worst point in the current population. + //! includes previous worst point in calculations as well. + arma::Col worstPoint; + + // The nadir point of the previous front. + arma::Col nadirPoint; + + // A Matrix containing the extreme vectors as columns. + arma::Mat vectorizedExtremes; +}; +} + +#endif diff --git a/include/ensmallen_bits/nsga3/nsga3.hpp b/include/ensmallen_bits/nsga3/nsga3.hpp new file mode 100644 index 000000000..9a58f025f --- /dev/null +++ b/include/ensmallen_bits/nsga3/nsga3.hpp @@ -0,0 +1,432 @@ +/** + * @file nsga3.hpp + * @author Satyam Shukla + * + * NSGA-III is a multi-objective optimization algorithm, widely used in + * many real-world applications. NSGA-III generates offsprings using + * crossover and mutation and then selects the next generation according + * to non-dominated-sorting and crowding distance comparison.The maintenance + * of diversity among population members in NSGA-III is aided by supplying and + * adaptively updating a number of well-spread reference points. + * + * ensmallen is free software; you may redistribute it and/or modify it under + * the terms of the 3-clause BSD license. You should have received a copy of + * the 3-clause BSD license along with ensmallen. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#ifndef ENSMALLEN_NSGA3_NSGA3_HPP +#define ENSMALLEN_NSGA3_NSGA3_HPP + +namespace ens { + +/** + * NSGA-III (Non-dominated Sorting Genetic Algorithm - III) is a multi-objective + * optimization algorithm. This class implements the NSGA-III algorithm. + * + * The algorithm works by generating a candidate population from a fixed + * starting point. At each stage of optimization, a new population of children + * is generated. This new population along with its predecessor is sorted using + * non-domination as the metric. Following this, the population is further + * segregated in fronts. A new population is generated from these fronts having + * size equal to that of the starting population. + * + * During evolution, two parents are randomly chosen using binary tournament + * selection. A pair of children are generated by crossing over these two + * candidates followed by mutation. + * + * The best front (Pareto optimal) is returned by the Optimize() method. + * + * For more information, see the following: + * + * @code + * @article{deb2013evolutionary, + * title={An evolutionary many-objective optimization algorithm using reference-point-based nondominated sorting approach, part I: solving problems with box constraints}, + * author={Deb, Kalyanmoy and Jain, Himanshu}, + * journal={IEEE transactions on evolutionary computation}, + * volume={18}, + * number={4}, + * pages={577--601}, + * year={2013}, + * publisher={IEEE} + * } + * @endcode + * + * NSGA-III can optimize arbitrary multi-objective functions. For more details, + * see the documentation on function types included with this distribution or + * on the ensmallen website. + */ +template +class NSGA3 +{ + public: + /** + * Constructor for the NSGA3 optimizer. + * + * The default values provided over here are not necessarily suitable for a + * given function. Therefore it is highly recommended to adjust the + * parameters according to the problem. + * + * @param referencePoints The reference points to be used. + * @param populationSize The number of candidates in the population. + * This should be atleast 4 in size and a multiple of 4. + * @param maxGenerations The maximum number of generations allowed for NSGA-II. + * @param crossoverProb The probability that a crossover will occur. + * @param mutationProb The probability that a mutation will occur. + * @param mutationStrength The strength of the mutation. + * @param lowerBound Lower bound of the coordinates of the initial population. + * @param upperBound Upper bound of the coordinates of the initial population. + */ + NSGA3(const arma::Mat& referencePoints, + const size_t populationSize = 100, + const size_t maxGenerations = 2000, + const double crossoverProb = 0.6, + const double distributionIndex = 20, + const double eta = 20, + const arma::vec& lowerBound = arma::zeros(1, 1), + const arma::vec& upperBound = arma::ones(1, 1)); + + /** + * Constructor for the NSGA3 optimizer. This constructor provides an overload + * to use `lowerBound` and `upperBound` of type double. + * + * The default values provided over here are not necessarily suitable for a + * given function. Therefore it is highly recommended to adjust the + * parameters according to the problem. + * + * @param referencePoints The reference points to be used. + * @param populationSize The number of candidates in the population. + * This should be atleast 4 in size and a multiple of 4. + * @param maxGenerations The maximum number of generations allowed for NSGA-II. + * @param crossoverProb The probability that a crossover will occur. + * @param mutationProb The probability that a mutation will occur. + * @param mutationStrength The strength of the mutation. + * @param lowerBound Lower bound of the coordinates of the initial population. + * @param upperBound Upper bound of the coordinates of the initial population. + */ + NSGA3(const arma::Mat& referencePoints, + const size_t populationSize = 100, + const size_t maxGenerations = 2000, + const double crossoverProb = 0.6, + const double distributionIndex = 20, + const double eta = 20, + const double lowerBound = 0, + const double upperBound = 1); + + /** + * Optimize a set of objectives. The initial population is generated using the + * starting point. The output is the best generated front. + * + * @tparam ArbitraryFunctionType std::tuple of multiple objectives. + * @tparam MatType Type of matrix to optimize. + * @tparam CallbackTypes Types of callback functions. + * @param objectives Vector of objective functions to optimize for. + * @param iterate Starting point. + * @param callbacks Callback functions. + * @return MatType::elem_type The minimum of the accumulated sum over the + * objective values in the best front. + */ + template + typename MatType::elem_type Optimize( + std::tuple& objectives, + MatType& iterate, + CallbackTypes&&... callbacks); + + //! Get the population size. + size_t PopulationSize() const { return populationSize; } + //! Modify the population size. + size_t& PopulationSize() { return populationSize; } + + //! Get the maximum number of generations. + size_t MaxGenerations() const { return maxGenerations; } + //! Modify the maximum number of generations. + size_t& MaxGenerations() { return maxGenerations; } + + //! Get the crossover rate. + double CrossoverRate() const { return crossoverProb; } + //! Modify the crossover rate. + double& CrossoverRate() { return crossoverProb; } + + //! Retrieve value of the distribution index. + double DistributionIndex() const { return distributionIndex; } + //! Modify the value of the distribution index. + double& DistributionIndex() { return distributionIndex; } + + //! Retrieve value of eta. + double Eta() const { return eta; } + //! Modify the value of eta. + double& Eta() { return eta; } + + //! Retrieve value of lowerBound. + const arma::vec& LowerBound() const { return lowerBound; } + //! Modify value of lowerBound. + arma::vec& LowerBound() { return lowerBound; } + + //! Retrieve value of upperBound. + const arma::vec& UpperBound() const { return upperBound; } + //! Modify value of upperBound. + arma::vec& UpperBound() { return upperBound; } + + //! Retrieve the Pareto optimal points in variable space. This returns an empty cube + //! until `Optimize()` has been called. + const arma::cube& ParetoSet() const { return paretoSet; } + + //! Retrieve the best front (the Pareto frontier). This returns an empty cube until + //! `Optimize()` has been called. + const arma::cube& ParetoFront() const { return paretoFront; } + + //! Get the reference points. + const arma::Mat& ReferencePoints() const { return referencePoints; } + //! Modify the reference points. + arma::Mat& ReferencePoints() { return referencePoints; } + + /** + * Retrieve the best front (the Pareto frontier). This returns an empty + * vector until `Optimize()` has been called. Note that this function is + * deprecated and will be removed in ensmallen 3.x! Use `ParetoFront()` + * instead. + */ + [[deprecated("use ParetoFront() instead")]] const std::vector& Front() + { + if (rcFront.size() == 0) + { + // Match the old return format. + for (size_t i = 0; i < paretoFront.n_slices; ++i) + { + rcFront.push_back(arma::mat(paretoFront.slice(i))); + } + } + + return rcFront; + } + + private: + /** + * Evaluate objectives for the elite population. + * + * @tparam ArbitraryFunctionType std::tuple of multiple function types. + * @tparam MatType Type of matrix to optimize. + * @param population The elite population. + * @param objectives The set of objectives. + * @param calculatedObjectives Vector to store calculated objectives. + */ + template + typename std::enable_if::type + EvaluateObjectives(std::vector&, + std::tuple&, + std::vector >&); + + template + typename std::enable_if::type + EvaluateObjectives(std::vector& population, + std::tuple& objectives, + std::vector >& + calculatedObjectives); + + /** + * Reproduce candidates from the elite population to generate a new + * population. + * + * @tparam BaseMatType Type of matrix to optimize. + * @param population The elite population. + * @param lowerBound Lower bound of the coordinates of the initial population. + * @param upperBound Upper bound of the coordinates of the initial population. + */ + template + void BinaryTournamentSelection(std::vector& population, + const MatType& lowerBound, + const MatType& upperBound); + + /** + * Crossover two parents to create a pair of new children. + * + * @tparam MatType Type of matrix to optimize. + * @param childA A newly generated candidate. + * @param childB Another newly generated candidate. + * @param parentA First parent from elite population. + * @param parentB Second parent from elite population. + * @param lowerBound Lower Bound of the offspring. + * @param upperBound Upper Boundn of the offspring. + */ + template + void Crossover(MatType& childA, + MatType& childB, + const MatType& parentA, + const MatType& parentB, + const MatType& lowerBound, + const MatType& upperBound); + + /** + * Mutate the coordinates for a candidate. + * + * @tparam MatType Type of matrix to optimize. + * @param candidate The candidate whose coordinates are being modified. + * @param mutationRate The probability of mutation. + * @param lowerBound Lower bound of the coordinates of the initial population. + * @param upperBound Upper bound of the coordinates of the initial population. + */ + template + void Mutate(MatType& candidate, + double mutationRate, + const MatType& lowerBound, + const MatType& upperBound); + + /** + * Sort the candidate population using their domination count and the set of + * dominated nodes. + * + * @tparam MatType Type of matrix to optimize. + * @param fronts The population is sorted into these Pareto fronts. The first + * front is the best, the second worse and so on. + * @param ranks The assigned ranks, used for crowding distance based sorting. + * @param calculatedObjectives The previously calculated objectives. + */ + template + void FastNonDominatedSort( + std::vector >& fronts, + std::vector& ranks, + std::vector& calculatedObjectives); + + /** + * Finding the distance of each point in the front from the line formed + * by pointA and pointB. + * + * @param distance The vector containing the distances of the points in the fron from the line. + * @param calculatedObjectives Reference to the current population evaluated Objectives. + * @param front The front of the current generation(indices of population). + * @param pointA The first point on the line. + * @param pointB The second point on the line. + */ + template + void PointToLineDistance(arma::Row& distances, + const std::vector& calculatedObjectives, + const std::vector& front, + const ColType& pointA, + const ColType& pointB); + + /** + * Finding the point in the reference front associated with the members in + * the given front and also stores the distance between the two points. + * + * @param refIndex Vector containing the index of the point in the refrence directons associated. + * @param dists Vector of distances from the corresponding point in the front to the associated reference direction. + * @param calculatedObjectives The points of the currently generated population. + * @param St The index of points belonging to the given front. + */ + template + void Associate(arma::urowvec& refIndex, + arma::Row& dists, + const std::vector& calculatedObjectives, + const std::vector& St); + + /** + * Find the niche count for each reference direction. + * + * @param count The no. of points associated with each niche direction. + * @param refIndex Vector containing the index of the point in the reference + * directons associated. + */ + void NicheCount(arma::Row& count, + const arma::urowvec& refIndex, + const std::vector& nextPopulation); + + /** + * The niche preserving operation to select the final points form the given front + * aranges the front in descending order of priority for the top K points in the front. + * + * @param K The no. of remaining points to select from the given front for the next population. + * @param nicheCount The count of no. of points associated to each reference point in St. + * @param refIndex The index of the rerference points associated with the in the given front. + * @param dists The distances of th points in the front form their associated reference points line. + * @param front The index of teh points within the population which are a part of the given front. + * @param population The set St (selected points). + */ + void Niching(size_t K, + arma::Row& nicheCount, + const arma::urowvec& refIndex, + const arma::Row& dists, + const std::vector& front, + std::vector& population); + + /** + * Operator to check if one candidate Pareto-dominates the other. + * + * A candidate is said to dominate the other if it is at least as good as the + * other candidate for all the objectives and there exists at least one + * objective for which it is strictly better than the other candidate. + * + * @tparam MatType Type of matrix to optimize. + * @param calculatedObjectives The previously calculated objectives. + * @param candidateP The candidate being compared from the elite population. + * @param candidateQ The candidate being compared against. + * @return true if candidateP Pareto dominates candidateQ, otherwise, false. + */ + template + bool Dominates( + std::vector& calculatedObjectives, + size_t candidateP, + size_t candidateQ); + + //! The number of objectives being optimised for. + size_t numObjectives; + + //! The numbeer of variables used per objectives. + size_t numVariables; + + //! The number of candidates in the population. + size_t populationSize; + + //! Maximum number of generations before termination criteria is met. + size_t maxGenerations; + + //! Probability that crossover will occur. + double crossoverProb; + + //! Probability that mutation will occur. + double mutationProb; + + //! Strength of the mutation. + double mutationStrength; + + //! The crowding degree of the mutation. Higher value produces a mutant + //! resembling its parent. + double distributionIndex; + + //! The distance parameters of the crossover distribution. + double eta; + + //! Lower bound of the initial swarm. + arma::vec lowerBound; + + //! Upper bound of the initial swarm. + arma::vec upperBound; + + //! The set of all the Pareto optimal points. + //! Stored after Optimize() is called. + arma::cube paretoSet; + + //! The set of all the Pareto optimal objective vectors. + //! Stored after Optimize() is called. + arma::cube paretoFront; + + //! The reference points. + arma::Mat referencePoints; + + //! A different representation of the Pareto front, for reverse compatibility + //! purposes. This can be removed when ensmallen 3.x is released! (Along + //! with `Front()`.) This is only populated when `Front()` is called. + std::vector rcFront; +}; + +} // namespace ens + +// Include implementation. +#include "nsga3_impl.hpp" + +#endif diff --git a/include/ensmallen_bits/nsga3/nsga3_impl.hpp b/include/ensmallen_bits/nsga3/nsga3_impl.hpp new file mode 100644 index 000000000..9b0e5c1d9 --- /dev/null +++ b/include/ensmallen_bits/nsga3/nsga3_impl.hpp @@ -0,0 +1,637 @@ +/** + * @file nsga3_impl.hpp + * @author Satyam Shukla + * + * Implementation of the NSGA3 algorithm. Used for multi-objective + * optimization problems on arbitrary functions. + * + * ensmallen is free software; you may redistribute it and/or modify it under + * the terms of the 3-clause BSD license. You should have received a copy of + * the 3-clause BSD license along with ensmallen. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more Information. + */ + +#ifndef ENSMALLEN_NSGA3_NSGA3_IMPL_HPP +#define ENSMALLEN_NSGA3_NSGA3_IMPL_HPP + +#include "nsga3.hpp" +#include +#include "normalization.hpp" + +namespace ens { + +template +inline NSGA3::NSGA3( + const arma::Mat& referencePoints, + const size_t populationSize, + const size_t maxGenerations, + const double crossoverProb, + const double distributionIndex, + const double eta, + const arma::vec& lowerBound, + const arma::vec& upperBound): + referencePoints(referencePoints), + numObjectives(0), + numVariables(0), + populationSize(populationSize), + maxGenerations(maxGenerations), + crossoverProb(crossoverProb), + distributionIndex(distributionIndex), + eta(eta), + lowerBound(lowerBound), + upperBound(upperBound) +{ /* Nothing to do here. */ } + +template +inline NSGA3::NSGA3( + const arma::Mat& referencePoints, + const size_t populationSize, + const size_t maxGenerations, + const double crossoverProb, + const double distributionIndex, + const double eta, + const double lowerBound, + const double upperBound): + referencePoints(referencePoints), + numObjectives(0), + numVariables(0), + populationSize(populationSize), + maxGenerations(maxGenerations), + crossoverProb(crossoverProb), + distributionIndex(distributionIndex), + eta(eta), + lowerBound(lowerBound * arma::ones(1, 1)), + upperBound(upperBound * arma::ones(1, 1)) +{ /* Nothing to do here. */ } + +//! Optimize the function. +template +template +typename MatType::elem_type NSGA3::Optimize( + std::tuple& objectives, + MatType& iterateIn, + CallbackTypes&&... callbacks) +{ + // Make sure for evolution to work at least four candidates are present. + if (populationSize < 4 && populationSize % 4 != 0) + { + throw std::logic_error("NSGA3::Optimize(): population size should be at" + " least 4, and, a multiple of 4!"); + } + + // Convenience typedefs. + typedef typename MatType::elem_type ElemType; + typedef typename MatTypeTraits::BaseMatType BaseMatType; + + BaseMatType& iterate = (BaseMatType&) iterateIn; + + // Make sure that we have the methods that we need. Long name... + traits::CheckArbitraryFunctionTypeAPI(); + RequireDenseFloatingPointType(); + + // Check if lower bound is a vector of a single dimension. + if (lowerBound.n_rows == 1) + lowerBound = lowerBound(0, 0) * arma::ones(iterate.n_rows, iterate.n_cols); + + // Check if upper bound is a vector of a single dimension. + if (upperBound.n_rows == 1) + upperBound = upperBound(0, 0) * arma::ones(iterate.n_rows, iterate.n_cols); + + // Check the dimensions of lowerBound and upperBound. + assert(lowerBound.n_rows == iterate.n_rows && "The dimensions of " + "lowerBound are not the same as the dimensions of iterate."); + assert(upperBound.n_rows == iterate.n_rows && "The dimensions of " + "upperBound are not the same as the dimensions of iterate."); + + numObjectives = sizeof...(ArbitraryFunctionType); + numVariables = iterate.n_rows; + + assert(numObjectives == referencePoints.n_rows && "The dimensions of " + "reference points do not match the number of functions."); + + // Cache calculated objectives. + std::vector > calculatedObjectives(populationSize); + + // Population size reserved to 2 * populationSize + 1 to accommodate + // for the size of intermediate candidate population. + std::vector population; + std::vector tempPopulation; + population.reserve(2 * populationSize + 1); + tempPopulation.reserve(populationSize); + + // Pareto fronts, initialized during non-dominated sorting. + // Stores indices of population belonging to a certain front. + std::vector > fronts; + // Initialised during non-dominated sorting. + std::vector ranks; + + //! Useful temporaries for float-like comparisons. + const BaseMatType castedLowerBound = arma::conv_to::from(lowerBound); + const BaseMatType castedUpperBound = arma::conv_to::from(upperBound); + + // Controls early termination of the optimization process. + bool terminate = false; + + // Generate the population based on a uniform distribution around the given + // starting point. + for (size_t i = 0; i < populationSize; i++) + { + population.push_back(arma::randu(iterate.n_rows, + iterate.n_cols) - 0.5 + iterate); + + // Constrain all genes to be within bounds. + population[i] = arma::min(arma::max(population[i], castedLowerBound), castedUpperBound); + } + Normalization hpn(numObjectives); + + Info << "NSGA3 initialized successfully. Optimization started." << std::endl; + + // Iterate until maximum number of generations is obtained. + Callback::BeginOptimization(*this, objectives, iterate, callbacks...); + + for (size_t generation = 1; generation <= maxGenerations && !terminate; generation++) + { + // Create new population of candidate from the present elite population. + // Have P_t, generate G_t using P_t. + BinaryTournamentSelection(population, castedLowerBound, castedUpperBound); + + // Evaluate the objectives for the new population. + calculatedObjectives.resize(population.size()); + std::fill(calculatedObjectives.begin(), calculatedObjectives.end(), + arma::Col(numObjectives, arma::fill::zeros)); + + EvaluateObjectives(population, objectives, calculatedObjectives); + + // Perform fast non dominated sort on P_t ∪ G_t. + ranks.resize(population.size()); + FastNonDominatedSort>(fronts, ranks, calculatedObjectives); + + hpn.update(calculatedObjectives, fronts[0]); + arma::Col denom = hpn.NadirPoint() - hpn.IdealPoint(); + + // S_t and P_t+1 declared. + std::vector selectedPoints; + std::vector nextPopulation; + + size_t index = 0; + while (nextPopulation.size() + fronts[index].size() < populationSize) + { + selectedPoints.insert(selectedPoints.end(), fronts[index].begin(), fronts[index].end()); + nextPopulation.insert(nextPopulation.end(), fronts[index].begin(), fronts[index].end()); + index++; + } + + if(nextPopulation.size() != populationSize) + { + selectedPoints.insert(selectedPoints.end(), fronts[index].begin(), fronts[index].end()); + + size_t lastFront = index; + + for (index = 0; index < selectedPoints.size(); index++) + { + calculatedObjectives[selectedPoints[index]] = + calculatedObjectives[selectedPoints[index]] - hpn.IdealPoint(); + calculatedObjectives[selectedPoints[index]] = + calculatedObjectives[selectedPoints[index]] / denom; + } + + // Find the associated reference directions to the selected points. + arma::urowvec refIndex(selectedPoints.size()); + arma::Row dists(selectedPoints.size()); + + Associate>(refIndex, dists, calculatedObjectives, + selectedPoints); + + // Calculate the niche count of S_t and performing the niching operation. + arma::Row count(referencePoints.n_cols, arma::fill::zeros); + + NicheCount(count, refIndex, nextPopulation); + Niching(populationSize - nextPopulation.size(), count, refIndex, + dists, fronts[lastFront], nextPopulation); + } + for (size_t i : nextPopulation) + { + tempPopulation.push_back(population[i]); + } + population = tempPopulation; + tempPopulation.erase(tempPopulation.begin(), tempPopulation.end()); + + terminate |= Callback::GenerationalStepTaken(*this, objectives, iterate, + calculatedObjectives, fronts, callbacks...); + } + EvaluateObjectives(population, objectives, calculatedObjectives); + // Set the candidates from the Pareto Set as the output. + paretoSet.set_size(population[0].n_rows, population[0].n_cols, + population.size()); + // The Pareto Set is stored, can be obtained via ParetoSet() getter. + for (size_t solutionIdx = 0; solutionIdx < population.size(); ++solutionIdx) + { + paretoSet.slice(solutionIdx) = + arma::conv_to::from(population[solutionIdx]); + } + + // Set the candidates from the Pareto Front as the output. + paretoFront.set_size(calculatedObjectives[0].n_rows, + calculatedObjectives[0].n_cols, population.size()); + // The Pareto Front is stored, can be obtained via ParetoFront() getter. + for (size_t solutionIdx = 0; solutionIdx < population.size(); ++solutionIdx) + { + paretoFront.slice(solutionIdx) = + arma::conv_to::from(calculatedObjectives[solutionIdx]); + } + + // Clear rcFront, in case it is later requested by the user for reverse + // compatibility reasons. + rcFront.clear(); + + // Assign iterate to first element of the Pareto Set. + iterate = population[fronts[0][0]]; + + Callback::EndOptimization(*this, objectives, iterate, callbacks...); + + ElemType performance = std::numeric_limits::max(); + + for (const arma::Col& objective: calculatedObjectives) + if (arma::accu(objective) < performance) + performance = arma::accu(objective); + + return performance; +} + +//! No objectives to evaluate. +template +template +typename std::enable_if::type +NSGA3::EvaluateObjectives( + std::vector&, + std::tuple&, + std::vector >&) +{ + // Nothing to do here. +} + +//! Evaluate the objectives for the entire population. +template +template +typename std::enable_if::type +NSGA3::EvaluateObjectives( + std::vector& population, + std::tuple& objectives, + std::vector >& calculatedObjectives) +{ + for (size_t i = 0; i < population.size(); i++) + { + calculatedObjectives[i](I) = std::get(objectives).Evaluate(population[i]); + } + EvaluateObjectives(population, objectives, + calculatedObjectives); +} + +//! Reproduce and generate new candidates. +template +template +inline void NSGA3::BinaryTournamentSelection( + std::vector& population, + const MatType& lowerBound, + const MatType& upperBound) +{ + std::vector children; + + while (children.size() < populationSize) + { + // Choose two random parents for reproduction from the elite population. + size_t indexA = arma::randi(arma::distr_param(0, populationSize - 1)); + size_t indexB = arma::randi(arma::distr_param(0, populationSize - 1)); + + // Make sure that the parents differ. + if (indexA == indexB) + { + if (indexB < populationSize - 1) + indexB++; + else + indexB--; + } + + // Initialize the children to the respective parents. + MatType childA = population[indexA], childB = population[indexB]; + + if(arma::randu() <= crossoverProb) + Crossover(childA, childB, population[indexA], population[indexB], + lowerBound, upperBound); + + Mutate(childA, 1.0 / static_cast(numVariables), + lowerBound, upperBound); + Mutate(childB, 1.0 / static_cast(numVariables), + lowerBound, upperBound); + + // Add the children to the candidate population. + children.push_back(childA); + children.push_back(childB); + } + + // Add the candidates to the elite population. + population.insert(std::end(population), std::begin(children), std::end(children)); +} + +//! Perform simulated binary crossover (SBX) of genes for the children. +template +template +inline void NSGA3::Crossover(MatType& childA, + MatType& childB, + const MatType& parentA, + const MatType& parentB, + const MatType& lowerBound, + const MatType& upperBound) +{ + //! Generates a child from two parent individuals + // according to the polynomial probability distribution. + arma::Cube parents(parentA.n_rows, parentA.n_cols, 2); + parents.slice(0) = parentA; + parents.slice(1) = parentB; + MatType current_min = arma::min(parents, 2); + MatType current_max = arma::max(parents, 2); + + if (arma::accu(parentA - parentB < 1e-14)) + { + childA = parentA; + childB = parentB; + return; + } + MatType current_diff = current_max - current_min; + current_diff.transform( [](typename MatType::elem_type val) + { return (val < 1e-10 ? 1e-10:val); } ); + + // Calculating beta used for the final crossover. + MatType beta1 = 1 + 2.0 * (current_min - lowerBound) / current_diff; + MatType beta2 = 1 + 2.0 * (upperBound - current_max) / current_diff; + MatType alpha1 = 2 - arma::pow(beta1, -(eta + 1)); + MatType alpha2 = 2 - arma::pow(beta2, -(eta + 1)); + + MatType us(arma::size(alpha1), arma::fill::randu); + arma::umat mask1 = us > (1.0 / alpha1); + MatType betaq1 = arma::pow(us % alpha1, 1. / (eta + 1)); + betaq1 = betaq1 % (mask1 != 1.0) + arma::pow((1.0 / (2.0 - us % alpha1)), 1.0 / (eta + 1)) % mask1; + arma::umat mask2 = us > (1.0 / alpha2); + MatType betaq2 = arma::pow(us % alpha2, 1 / (eta + 1)); + betaq2 = betaq2 % (mask1 != 1.0) + arma::pow((1.0 / (2.0 - us % alpha2)), 1.0 / (eta + 1)) % mask2; + + // Variables after the cross over for all of them. + MatType c1 = 0.5 * ((current_min + current_max) - betaq1 % current_diff); + MatType c2 = 0.5 * ((current_min + current_max) + betaq2 % current_diff); + c1 = arma::min(arma::max(c1, lowerBound), upperBound); + c2 = arma::min(arma::max(c2, lowerBound), upperBound); + + // Decision for the crossover between the two parents for each variable. + us.randu(); + childA = parentA % (us <= 0.5); + childB = parentB % (us <= 0.5); + us.randu(); + childA = childA + c1 % ((us <= 0.5) % (childA == 0)); + childA = childA + c2 % ((us > 0.5) % (childA == 0)); + childB = childB + c2 % ((us <= 0.5) % (childB == 0)); + childB = childB + c1 % ((us > 0.5) % (childB == 0)); +} + +//! Perform Polynomial mutation of the candidate. +template +template +inline void NSGA3::Mutate(MatType& candidate, + double mutationRate, + const MatType& lowerBound, + const MatType& upperBound) +{ + const size_t numVariables = candidate.n_rows; + for (size_t geneIdx = 0; geneIdx < numVariables; ++geneIdx) + { + // Should this gene be mutated? + if (arma::randu() > mutationRate) + continue; + + const double geneRange = upperBound(geneIdx) - lowerBound(geneIdx); + // Normalised distance from the bounds. + const double lowerDelta = (candidate(geneIdx) - lowerBound(geneIdx)) / geneRange; + const double upperDelta = (upperBound(geneIdx) - candidate(geneIdx)) / geneRange; + const double mutationPower = 1. / (distributionIndex + 1.0); + const double rand = arma::randu(); + double value, perturbationFactor; + if (rand < 0.5) + { + value = 2.0 * rand + (1.0 - 2.0 * rand) * + std::pow(upperDelta, distributionIndex + 1.0); + perturbationFactor = std::pow(value, mutationPower) - 1.0; + } + else + { + value = 2.0 * (1.0 - rand) + 2.0 *(rand - 0.5) * + std::pow(lowerDelta, distributionIndex + 1.0); + perturbationFactor = 1.0 - std::pow(value, mutationPower); + } + + candidate(geneIdx) += perturbationFactor * geneRange; + } + //! Enforce bounds. + candidate = arma::min(arma::max(candidate, lowerBound), upperBound); +} + +//! Find the distance of a front from a line formed by two points. +template +template +inline void NSGA3::PointToLineDistance( + arma::Row& distances, + const std::vector& calculatedObjectives, + const std::vector& front, + const ColType& pointA, + const ColType& pointB) +{ + arma::Row distancesTemp(front.size()); + ColType ba = pointB - pointA; + ColType pa; + + for (size_t i = 0; i < front.size(); i++) + { + size_t ind = front[i]; + + pa = (calculatedObjectives[ind] - pointA); + double t = arma::dot(pa, ba) / arma::dot(ba, ba); + distancesTemp[i] = std::pow(arma::accu(arma::pow((pa - t * ba), 2)), 0.5); + } + distances = distancesTemp; +} + +//! Sort population into Pareto fronts. +template +template +inline void NSGA3::FastNonDominatedSort( + std::vector >& fronts, + std::vector& ranks, + std::vector& calculatedObjectives) +{ + std::map dominationCount; + std::map > dominated; + + // Reset and initialize fronts. + fronts.clear(); + fronts.push_back(std::vector()); + + for (size_t p = 0; p < calculatedObjectives.size(); p++) + { + dominated[p] = std::set(); + dominationCount[p] = 0; + + for (size_t q = 0; q < calculatedObjectives.size(); q++) + { + if (Dominates>(calculatedObjectives, p, q)) + dominated[p].insert(q); + else if (Dominates>(calculatedObjectives, q, p)) + dominationCount[p] += 1; + } + + if (dominationCount[p] == 0) + { + ranks[p] = 0; + fronts[0].push_back(p); + } + } + + size_t i = 0; + + while (!fronts[i].empty()) + { + std::vector nextFront; + + for (size_t p: fronts[i]) + { + for (size_t q: dominated[p]) + { + dominationCount[q]--; + + if (dominationCount[q] == 0) + { + ranks[q] = i + 1; + nextFront.push_back(q); + } + } + } + + i++; + fronts.push_back(nextFront); + } + // Remove the empty final set. + fronts.pop_back(); +} + +template +inline void NSGA3::Niching(size_t K, + arma::Row& nicheCount, + const arma::urowvec& refIndex, + const arma::Row& dists, + const std::vector& front, + std::vector& population) +{ + arma::Row popMask(front.size(), arma::fill::zeros); + int nextPopSize = population.size(); + size_t k = 0; + while (k < K) + { + size_t jMin = arma::index_min(nicheCount); + std::vector I; + for (size_t i = 0; i < front.size(); i++) + { + if(refIndex[nextPopSize + i] == jMin && !popMask[i]) + I.push_back(i); + } + if (I.size() != 0) + { + size_t min = 0; + if(nicheCount[jMin] == 0) + { + for (size_t i = 0; i < I.size(); i++) + { + if(dists[nextPopSize + I[i]] < dists[nextPopSize + I[min]]) + { + min = i; + } + } + } + population.push_back(front[I[min]]); + + nicheCount[jMin] += 1; + popMask[I[min]] = 1; + k++; + } + else + { + nicheCount[jMin] = 100000; + } + } +} + + +template +template +inline void NSGA3::Associate( + arma::urowvec& refIndex, + arma::Row& dists, + const std::vector& calculatedObjectives, + const std::vector& St) +{ + arma::Mat d(referencePoints.n_cols, St.size()); + ColType zero(arma::size(calculatedObjectives[0]),arma::fill::zeros); + arma::Row temp; + for (size_t i = 0; i < referencePoints.n_cols; i++) + { + PointToLineDistance(temp, calculatedObjectives, St, + zero, referencePoints.col(i)); + d.row(i) = temp; + } + refIndex = arma::index_min(d, 0); + dists = arma::min(d, 0); +} + +template +inline void NSGA3::NicheCount(arma::Row& count, + const arma::urowvec& refIndex, + const std::vector& nextPopulation) +{ + for (size_t i = 0; i < nextPopulation.size(); i++) + { + count[refIndex[i]] += 1; + } +} + +//! Check if a candidate Pareto dominates another candidate. +template +template +inline bool NSGA3::Dominates( + std::vector& calculatedObjectives, + size_t candidateP, + size_t candidateQ) +{ + bool allBetterOrEqual = true; + bool atleastOneBetter = false; + size_t n_objectives = calculatedObjectives[0].n_elem; + + for (size_t i = 0; i < n_objectives; i++) + { + // P is worse than Q for the i-th objective function. + if (calculatedObjectives[candidateP](i) > calculatedObjectives[candidateQ](i)) + allBetterOrEqual = false; + + // P is better than Q for the i-th objective function. + else if (calculatedObjectives[candidateP](i) < calculatedObjectives[candidateQ](i)) + atleastOneBetter = true; + } + + return allBetterOrEqual && atleastOneBetter; +} + +} // namespace ens + +#endif \ No newline at end of file diff --git a/include/ensmallen_bits/problems/dtlz/dtlz1_function.hpp b/include/ensmallen_bits/problems/dtlz/dtlz1_function.hpp index d1ff5223c..46b815de4 100644 --- a/include/ensmallen_bits/problems/dtlz/dtlz1_function.hpp +++ b/include/ensmallen_bits/problems/dtlz/dtlz1_function.hpp @@ -179,10 +179,6 @@ namespace test { { value = value * (1. - coords[stop]); } - else - { - value = value * coords[stop]; - } value = value * (1. + dtlz.g(coords)[0]); return value;