Started optimizer refactoring
MSallermann committed Nov 16, 2023
1 parent 1d50a8a commit ae0b329
Showing 5 changed files with 161 additions and 121 deletions.
37 changes: 29 additions & 8 deletions include/fc_layer.hpp
@@ -13,13 +13,18 @@ class FCLayer : public Layer<scalar>
 {
 protected:
     Matrix<scalar> weights;
-    Vector<scalar> bias;
+    Matrix<scalar> weights_error;
+
+    Matrix<scalar> bias;
+    Matrix<scalar> bias_error;
 
 public:
     FCLayer( size_t input_size, size_t output_size )
         : Layer<scalar>( input_size, output_size ),
           weights( Matrix<scalar>( output_size, input_size ) ),
-          bias( Vector<scalar>( output_size ) )
+          weights_error( Matrix<scalar>( output_size, input_size ) ),
+          bias( Vector<scalar>( output_size ) ),
+          bias_error( Vector<scalar>( output_size ) )
     {
         auto rd  = std::random_device();
         auto gen = std::mt19937( rd() );
@@ -39,18 +44,20 @@ class FCLayer : public Layer<scalar>
     // returns output for a given input
     Matrix<scalar> forward_propagation( const Matrix<scalar> & input_data ) override
     {
-        this->input  = input_data;
-        this->output = ( weights * input_data ).colwise() + bias;
+        this->input = input_data;
+        this->output
+            = ( weights * input_data ).colwise()
+              + bias.col(
+                  0 ); // We use .col(0), so the bias can be treated as a matrix with fixed columns at compile time
         return this->output;
     }
 
     // computes dE/dW, dE/dB for a given output_error=dE/dY. Returns input_error=dE/dX.
     Matrix<scalar> backward_propagation( const Matrix<scalar> & output_error ) override
     {
-        auto input_error             = weights.transpose() * output_error;
-        Matrix<scalar> weights_error = output_error * this->input.transpose() / output_error.cols();
-        Vector<scalar> bias_error    = ( output_error ).rowwise().mean();
-        this->opt->optimize( &weights, &weights_error, &bias, &bias_error );
+        auto input_error = weights.transpose() * output_error;
+        weights_error    = output_error * this->input.transpose() / output_error.cols();
+        bias_error       = ( output_error ).rowwise().mean();
         return input_error;
     }
@@ -60,6 +67,20 @@ class FCLayer : public Layer<scalar>
         return this->weights.size() + this->bias.size();
     }
 
+    // Get ref to trainable parameters
+    std::vector<Eigen::Ref<Matrix<scalar>>> variables() override
+    {
+        return std::vector<Eigen::Ref<Matrix<scalar>>>{ Eigen::Ref<Matrix<scalar>>( weights ),
+                                                        Eigen::Ref<Matrix<scalar>>( bias ) };
+    };
+
+    // Get ref to gradients of parameters
+    std::vector<Eigen::Ref<Matrix<scalar>>> gradients() override
+    {
+        return std::vector<Eigen::Ref<Matrix<scalar>>>{ Eigen::Ref<Matrix<scalar>>( weights_error ),
+                                                        Eigen::Ref<Matrix<scalar>>( bias_error ) };
+    };
+
     // Access the current weights
     Matrix<scalar> get_weights()
     {
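The key to the new accessors is Eigen::Ref: the vectors returned by variables() and gradients() hold non-owning views into the layer's own matrices, so an optimizer that writes through them updates the parameters in place, without copies. The standalone sketch below (not part of this commit) illustrates the aliasing; it assumes Matrix<scalar> is a dynamic-size Eigen matrix, which is what the surrounding code suggests.

#include <Eigen/Dense>
#include <iostream>
#include <vector>

template<typename scalar>
using Matrix = Eigen::Matrix<scalar, Eigen::Dynamic, Eigen::Dynamic>;

int main()
{
    Matrix<double> weights = Matrix<double>::Ones( 2, 2 );

    // A non-owning view of the matrix, as FCLayer::variables() returns
    std::vector<Eigen::Ref<Matrix<double>>> vars{ Eigen::Ref<Matrix<double>>( weights ) };

    // Writing through the Ref ...
    vars[0] *= 0.5;

    // ... is visible in the original matrix: prints 0.5
    std::cout << weights( 0, 0 ) << "\n";
}

This is also why weights_error and bias_error were promoted to members: the optimizer can keep Refs to them, so backward_propagation() only has to overwrite the gradients each step and the optimizer sees the fresh values.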
32 changes: 18 additions & 14 deletions include/layer.hpp
@@ -21,21 +21,10 @@ class Layer
 public:
     Layer() = default;
     Layer( std::optional<size_t> input_size, std::optional<size_t> output_size )
-        : input_size( input_size ),
-          output_size( output_size ),
-          opt( std::make_unique<Optimizers::StochasticGradientDescent<scalar>>( 0.1 ) )
+        : input_size( input_size ), output_size( output_size )
     {
     }
 
-    std::unique_ptr<Optimizers::Optimizer<scalar>> opt;
-
-    // TODO: figure out how to implement copy constructor
-    // Layer( const Layer & l )
-    //     : input( l.input ), output( l.output ), input_size( l.input_size ), output_size( l.output_size )
-    // {
-    //     opt = std::make_unique<Optimizer<scalar>>( l.opt );
-    // }
-
     virtual std::string name() = 0;
 
     std::optional<size_t> get_input_size()
@@ -60,8 +49,23 @@
     // computes dE/dX for a given dE/dY (and update parameters if any)
     virtual Matrix<scalar> backward_propagation( const Matrix<scalar> & output_error ) = 0;
 
-    // Get trainable parameters
-    virtual size_t get_trainable_params() = 0;
+    // Get number of trainable parameters
+    virtual size_t get_trainable_params()
+    {
+        return 0;
+    };
+
+    // Get ref to trainable parameters
+    virtual std::vector<Eigen::Ref<Matrix<scalar>>> variables()
+    {
+        return {}; // Standard behaviour is to return an empty vector, i.e. no trainable params
+    };
+
+    // Get ref to gradients of parameters
+    virtual std::vector<Eigen::Ref<Matrix<scalar>>> gradients()
+    {
+        return {};
+    };
 
     virtual ~Layer() = default;
 };
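Since get_trainable_params(), variables() and gradients() are now virtuals with defaults instead of pure virtuals, a parameter-free layer only has to override name(), forward_propagation() and backward_propagation(). A hypothetical activation layer (not from this repository) showing the effect of the new defaults:

template<typename scalar>
class TanhLayer : public Layer<scalar>
{
public:
    TanhLayer( size_t size ) : Layer<scalar>( size, size ) {}

    std::string name() override
    {
        return "Tanh";
    }

    Matrix<scalar> forward_propagation( const Matrix<scalar> & input_data ) override
    {
        this->input  = input_data;
        this->output = input_data.array().tanh().matrix();
        return this->output;
    }

    Matrix<scalar> backward_propagation( const Matrix<scalar> & output_error ) override
    {
        // Elementwise dE/dX = ( 1 - tanh(x)^2 ) * dE/dY; no parameters to update,
        // so the inherited empty variables()/gradients() keep it out of the optimizer
        return ( ( scalar( 1 ) - this->input.array().tanh().square() ) * output_error.array() ).matrix();
    }
};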
37 changes: 32 additions & 5 deletions include/network.hpp
@@ -22,6 +22,9 @@ class Network
 public:
     std::optional<scalar> loss_tol = std::nullopt;
 
+    std::unique_ptr<Optimizers::Optimizer<scalar>> opt
+        = std::make_unique<Optimizers::StochasticGradientDescent<scalar>>( 0.00001 );
+
     Network() = default;
 
     template<typename LayerT, typename... T>
@@ -64,18 +67,39 @@
         return loss;
     }
 
+    // void set_optimizer( const Optimizers::Optimizer<scalar> & opt )
+    // {
+    //     this->opt = s
+    // }
+
+    void register_optimizer_variables()
+    {
+        this->opt->variables.clear();
+        this->opt->gradients.clear();
+
+        for( auto & layer : layers )
+        {
+            for( auto & v : layer->variables() )
+            {
+                this->opt->variables.push_back( v );
+            }
+
+            for( auto & g : layer->gradients() )
+            {
+                this->opt->gradients.push_back( g );
+            }
+        }
+    }
+
     void
     fit( const std::vector<Matrix<scalar>> & x_train, const std::vector<Matrix<scalar>> & y_train, size_t epochs,
          scalar learning_rate, bool print_progress = false )
     {
+        register_optimizer_variables();
+
         auto n_samples  = x_train.size();
         auto batch_size = x_train[0].cols();
 
-        for( auto & l : layers )
-        {
-            l->opt = std::move( std::make_unique<Optimizers::StochasticGradientDescent<scalar>>( learning_rate ) );
-        }
-
         fmt::print(
             "Fitting with {} samples of batchsize {} ({} total)\n\n", n_samples, batch_size, n_samples * batch_size );
@@ -86,6 +110,7 @@
         {
             auto t_epoch_start = std::chrono::high_resolution_clock::now();
             err = 0;
+
             for( size_t j = 0; j < n_samples; j++ )
             {
                 // forward propagation
@@ -104,6 +129,8 @@
                     auto & layer = layers[i_layer];
                     error = layer->backward_propagation( error );
                 }
+
+                opt->optimize();
             }
 
             auto t_epoch_end = std::chrono::high_resolution_clock::now();
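optimizer.hpp itself is not among the files shown, so the interface these calls rely on has to be read off the call sites: register_optimizer_variables() fills public variables and gradients vectors on the optimizer, and the training loop calls a parameterless optimize(). One plausible shape for that base class, with plain gradient descent as the example (an assumption about the refactoring's direction, not code from this commit):

#include <Eigen/Dense>
#include <vector>

template<typename scalar>
using Matrix = Eigen::Matrix<scalar, Eigen::Dynamic, Eigen::Dynamic>;

namespace Optimizers
{

template<typename scalar>
class Optimizer
{
public:
    // Refs into the layers' parameter and gradient matrices,
    // registered by Network::register_optimizer_variables()
    std::vector<Eigen::Ref<Matrix<scalar>>> variables;
    std::vector<Eigen::Ref<Matrix<scalar>>> gradients;

    virtual void optimize() = 0;
    virtual ~Optimizer() = default;
};

template<typename scalar>
class StochasticGradientDescent : public Optimizer<scalar>
{
    scalar learning_rate;

public:
    StochasticGradientDescent( scalar learning_rate ) : learning_rate( learning_rate ) {}

    void optimize() override
    {
        // One step per registered variable: v <- v - lr * dE/dv
        for( size_t i = 0; i < this->variables.size(); i++ )
            this->variables[i] -= learning_rate * this->gradients[i];
    }
};

} // namespace Optimizers

Holding Refs avoids copying parameters on every step, but it also means the registered views dangle if a layer ever reallocates its matrices, so register_optimizer_variables() has to be re-run after any change to the network's structure.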