Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
RoberLopez committed Jan 16, 2025
1 parent 8839ba6 commit cc09ea4
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 135 deletions.
6 changes: 2 additions & 4 deletions opennn/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,7 @@ void Batch::print() const
<< "Input dimensions:" << endl;

print_vector(input_dimensions);


/*

if(input_dimensions.size() == 4)
{
const TensorMap<Tensor<type, 4>> inputs((type*)input_tensor.data(),
Expand All @@ -195,7 +193,7 @@ void Batch::print() const

cout << inputs << endl;
}
*/


cout << "Decoder:" << endl
<< "Decoder dimensions:" << endl;
Expand Down
3 changes: 0 additions & 3 deletions opennn/language_data_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ void LanguageDataSet::set_target_vocabulary(const unordered_map<string, Index>&
void LanguageDataSet::set_data_random()
{
/*
data_path.clear();
set(batch_samples_number, decoder_length + 2 * completion_length);
for(Index i = 0; i < batch_samples_number; i++)
{
Expand Down
73 changes: 63 additions & 10 deletions opennn/learning_rate_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,10 @@ void LearningRateAlgorithm::set(LossIndex* new_loss_index)

void LearningRateAlgorithm::set_default()
{
/*
delete thread_pool;
delete thread_pool_device;
const unsigned int threads_number = thread::hardware_concurrency();
thread_pool = new ThreadPool(n);
thread_pool_device = new ThreadPoolDevice(thread_pool, n);
const unsigned int threads_number = thread::hardware_concurrency();

thread_pool = make_unique<ThreadPool>(threads_number);
thread_pool_device = make_unique<ThreadPoolDevice>(thread_pool.get(), threads_number);
*/
// TRAINING OPERATORS

learning_rate_method = LearningRateMethod::BrentMethod;

Expand Down Expand Up @@ -446,6 +436,69 @@ void LearningRateAlgorithm::from_XML(const XMLDocument& document)
set_display(read_xml_bool(root_element, "Display"));
}


LearningRateAlgorithm::Triplet::Triplet()
{
A = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
U = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
B = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
}


type LearningRateAlgorithm::Triplet::get_length() const
{
return abs(B.first - A.first);
}


pair<type, type> LearningRateAlgorithm::Triplet::minimum() const
{
Tensor<type, 1> losses(3);

losses.setValues({ A.second, U.second, B.second });

const Index minimal_index = opennn::minimal_index(losses);

if (minimal_index == 0) return A;
else if (minimal_index == 1) return U;
else return B;
}


string LearningRateAlgorithm::Triplet::struct_to_string() const
{
ostringstream buffer;

buffer << "A = (" << A.first << "," << A.second << ")\n"
<< "U = (" << U.first << "," << U.second << ")\n"
<< "B = (" << B.first << "," << B.second << ")" << endl;

return buffer.str();
}


void LearningRateAlgorithm::Triplet::print() const
{
cout << struct_to_string()
<< "Lenght: " << get_length() << endl;
}


void LearningRateAlgorithm::Triplet::check() const
{
if (U.first < A.first)
throw runtime_error("U is less than A:\n" + struct_to_string());

if (U.first > B.first)
throw runtime_error("U is greater than B:\n" + struct_to_string());

if (U.second >= A.second)
throw runtime_error("fU is equal or greater than fA:\n" + struct_to_string());

if (U.second >= B.second)
throw runtime_error("fU is equal or greater than fB:\n" + struct_to_string());
}

}


Expand Down
89 changes: 7 additions & 82 deletions opennn/learning_rate_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,9 @@ class LearningRateAlgorithm

struct Triplet
{
Triplet()
{
A = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
U = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
B = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
}
Triplet();

inline bool operator == (const Triplet& other_triplet) const
bool operator == (const Triplet& other_triplet) const
{
if(A == other_triplet.A
&& U == other_triplet.U
Expand All @@ -44,59 +39,15 @@ class LearningRateAlgorithm
return false;
}

inline type get_length() const
{
return abs(B.first - A.first);
}

type get_length() const;

inline pair<type, type> minimum() const
{
Tensor<type, 1> losses(3);

losses.setValues({A.second, U.second, B.second});
pair<type, type> minimum() const;

const Index minimal_index = opennn::minimal_index(losses);
string struct_to_string() const;

if(minimal_index == 0) return A;
else if(minimal_index == 1) return U;
else return B;
}
void print() const;


inline string struct_to_string() const
{
ostringstream buffer;

buffer << "A = (" << A.first << "," << A.second << ")\n"
<< "U = (" << U.first << "," << U.second << ")\n"
<< "B = (" << B.first << "," << B.second << ")" << endl;

return buffer.str();
}


inline void print() const
{
cout << struct_to_string()
<< "Lenght: " << get_length() << endl;
}


inline void check() const
{
if(U.first < A.first)
throw runtime_error("U is less than A:\n" + struct_to_string());

if(U.first > B.first)
throw runtime_error("U is greater than B:\n" + struct_to_string());

if(U.second >= A.second)
throw runtime_error("fU is equal or greater than fA:\n" + struct_to_string());

if(U.second >= B.second)
throw runtime_error("fU is equal or greater than fB:\n" + struct_to_string());
}
void check() const;

pair<type, type> A;

Expand All @@ -105,49 +56,31 @@ class LearningRateAlgorithm
pair<type, type> B;
};

// Get

LossIndex* get_loss_index() const;

bool has_loss_index() const;

// Training operators

const LearningRateMethod& get_learning_rate_method() const;
string write_learning_rate_method() const;

// Training parameters

const type& get_learning_rate_tolerance() const;

// Utilities

const bool& get_display() const;

// Set

void set(LossIndex* = nullptr);

void set_loss_index(LossIndex*);
void set_threads_number(const int&);

// Training operators

void set_learning_rate_method(const LearningRateMethod&);
void set_learning_rate_method(const string&);

// Training parameters

void set_learning_rate_tolerance(const type&);

// Utilities

void set_display(const bool&);

void set_default();

// Learning rate

type calculate_golden_section_learning_rate(const Triplet&) const;
type calculate_Brent_method_learning_rate(const Triplet&) const;

Expand All @@ -160,29 +93,21 @@ class LearningRateAlgorithm
ForwardPropagation&,
BackPropagation&,
OptimizationAlgorithmData&) const;

// Serialization

void from_XML(const XMLDocument&);

void to_XML(XMLPrinter&) const;

private:

// FIELDS

LossIndex* loss_index = nullptr;

// TRAINING OPERATORS

LearningRateMethod learning_rate_method;

type learning_rate_tolerance;

type loss_tolerance;

// UTILITIES

bool display = true;

const type golden_ratio = type(1.618);
Expand Down
41 changes: 5 additions & 36 deletions opennn/strings_utilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,46 +352,15 @@ void replace_double_char_with_label(string &str, const string &target_char, cons

void replace_substring_within_quotes(string &str, const string &target, const string &replacement)
{
/*
regex r("\"([^\"]*)\"");
string result;
smatch match;
string prefix = str;
string::const_iterator search_start(str.begin());
while(regex_search(prefix, match, r))
{
string match_str = match.str();
string replaced_str = match_str;
size_t position = 0;
while((position = replaced_str.find(target, position)) != string::npos)
{
replaced_str.replace(position, target.length(), replacement);
position += replacement.length();
}
result += match.prefix().str() + replaced_str;
prefix = match.suffix().str();
}
result += prefix;
str = result;
*/

regex r("\"([^\"]*)\"");
string result;
string::const_iterator search_start(str.begin());
smatch match;

while (regex_search(search_start, str.cend(), match, r))
{
result += string(search_start, match[0].first); // Append text before match
string quoted_content = match[1].str(); // Extract quoted content
result += string(search_start, match[0].first);
string quoted_content = match[1].str();

size_t position = 0;
while ((position = quoted_content.find(target, position)) != string::npos)
Expand All @@ -400,11 +369,11 @@ void replace_substring_within_quotes(string &str, const string &target, const st
position += replacement.length();
}

result += "\"" + quoted_content + "\""; // Append updated quoted content
search_start = match[0].second; // Move search start past the current match
result += "\"" + quoted_content + "\"";
search_start = match[0].second;
}

result += string(search_start, str.cend()); // Append remaining text
result += string(search_start, str.cend());
str = result;
}

Expand Down

0 comments on commit cc09ea4

Please sign in to comment.