Skip to content

Commit 5210bc0

Browse files
feat(layer): apply initializers to layers
1 parent 1b8cd13 commit 5210bc0

File tree

10 files changed

+217
-157
lines changed

10 files changed

+217
-157
lines changed

.idea/codeStyles/Project.xml

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

layer/fully_connected.cpp

Lines changed: 105 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -20,190 +20,198 @@ using namespace gsl;
2020

2121
using namespace Eigen;
2222

23-
cppbp::layer::FullyConnected::FullyConnected(size_t len, cppbp::layer::IActivationFunction& af)
24-
: act_func_(&af), len_(len), next_(nullptr)
23+
cppbp::layer::FullyConnected::FullyConnected(size_t len, cppbp::layer::IActivationFunction &af)
24+
: act_func_(&af), len_(len), next_(nullptr)
2525
{
26-
id_ = cppbp::layer::FullyConnected::objects_alive;
26+
id_ = cppbp::layer::FullyConnected::objects_alive;
2727
}
2828

29-
cppbp::layer::ILayer& cppbp::layer::FullyConnected::connect(ILayer& next)
29+
cppbp::layer::ILayer &cppbp::layer::FullyConnected::connect(ILayer &next)
3030
{
31-
next.reshape(len_);
31+
next.reshape(len_);
3232

33-
this->set_next(&next);
34-
next.set_prev(this);
33+
this->set_next(&next);
34+
next.set_prev(this);
3535

36-
return next;
36+
return next;
3737
}
3838

3939
void cppbp::layer::FullyConnected::backprop()
4040
{
41-
VectorXd prev_activation(1 + prev()->get().size());
42-
prev_activation << 1, prev()->get();
41+
VectorXd prev_activation(1 + prev()->get().size());
42+
prev_activation << 1, prev()->get();
4343

44-
auto derives = act_func_->derive(activations_);
45-
deltas_ = act_func_->derive(activations_).transpose() * errors_;
44+
auto derives = act_func_->derive(activations_);
45+
deltas_ = act_func_->derive(activations_).transpose() * errors_;
4646

47-
VectorXd errors = deltas_.transpose() * weights_.block(0, 1, weights_.rows(), weights_.cols() - 1);
47+
VectorXd errors = deltas_.transpose() * weights_.block(0, 1, weights_.rows(), weights_.cols() - 1);
4848

49-
if (prev())
50-
{
51-
prev()->set_errors(errors);
52-
prev()->backprop();
53-
}
49+
if (prev())
50+
{
51+
prev()->set_errors(errors);
52+
prev()->backprop();
53+
}
5454
}
5555

5656
void cppbp::layer::FullyConnected::forward()
5757
{
58-
if (next())
59-
{
60-
next()->set(activations_);
61-
next()->forward();
62-
}
58+
if (next())
59+
{
60+
next()->set(activations_);
61+
next()->forward();
62+
}
6363
}
6464

65-
void cppbp::layer::FullyConnected::optimize(cppbp::optimizer::IOptimizer& opt)
65+
void cppbp::layer::FullyConnected::optimize(cppbp::optimizer::IOptimizer &opt)
6666
{
67-
Expects(!weights_.hasNaN());
67+
Expects(!weights_.hasNaN());
6868

69-
VectorXd aug{1 + prev()->get().size()};
70-
aug << 1, prev()->get();
71-
weights_ = opt.optimize(weights_, deltas_ * aug.transpose());
69+
VectorXd aug{1 + prev()->get().size()};
70+
aug << 1, prev()->get();
71+
weights_ = opt.optimize(weights_, deltas_ * aug.transpose());
7272

73-
if (next_)
74-
{
75-
next_->optimize(opt);
76-
}
73+
if (next_)
74+
{
75+
next_->optimize(opt);
76+
}
7777

78-
Ensures(!weights_.hasNaN());
78+
Ensures(!weights_.hasNaN());
7979
}
8080

8181
void cppbp::layer::FullyConnected::set(VectorXd vec)
8282
{
83-
Expects(!vec.hasNaN());
83+
Expects(!vec.hasNaN());
8484

85-
input_ = vec;
86-
VectorXd aug(vec.size() + 1);
87-
aug << 1, vec;
88-
activations_ = act_func_->eval(weights_ * aug);
85+
input_ = vec;
86+
VectorXd aug(vec.size() + 1);
87+
aug << 1, vec;
88+
activations_ = act_func_->eval(weights_ * aug);
8989

90-
Ensures(!activations_.hasNaN());
90+
Ensures(!activations_.hasNaN());
9191
}
9292

9393
void cppbp::layer::FullyConnected::set_deltas(Eigen::VectorXd dlts)
9494
{
95-
deltas_ = dlts;
95+
deltas_ = dlts;
9696
}
9797

9898
std::string cppbp::layer::FullyConnected::summary() const
9999
{
100-
stringstream ss{};
101-
ss << fmt::format("Fully Connected [{} neurons]:{{\n", len_);
102-
for (const auto& row : weights_.rowwise())
103-
{
104-
ss << fmt::format("[1 Bias, {} weights]=", len_) << row << "\n";// TODO: custom formatter
105-
}
106-
ss << "}";
100+
stringstream ss{};
101+
ss << fmt::format("Fully Connected [{} neurons]:{{\n", len_);
102+
for (const auto &row: weights_.rowwise())
103+
{
104+
ss << fmt::format("[1 Bias, {} weights]=", len_) << row << "\n";// TODO: custom formatter
105+
}
106+
ss << "}";
107107

108-
if (next_)
109-
{
110-
ss << "\n";
111-
ss << next_->summary();
112-
}
108+
if (next_)
109+
{
110+
ss << "\n";
111+
ss << next_->summary();
112+
}
113113

114-
return ss.str();
114+
return ss.str();
115115
}
116116

117117
Eigen::VectorXd cppbp::layer::FullyConnected::get() const
118118
{
119-
return activations_;
119+
return activations_;
120120
}
121121

122122
string cppbp::layer::FullyConnected::name() const
123123
{
124-
return fmt::format("fc {}", id_);
124+
return fmt::format("fc {}", id_);
125125
}
126126

127-
cppbp::layer::ILayer& cppbp::layer::FullyConnected::operator|(cppbp::layer::ILayer& next)
127+
cppbp::layer::ILayer &cppbp::layer::FullyConnected::operator|(cppbp::layer::ILayer &next)
128128
{
129-
return connect(next);
129+
return connect(next);
130130
}
131131

132-
cppbp::layer::ILayer* cppbp::layer::FullyConnected::next()
132+
cppbp::layer::ILayer *cppbp::layer::FullyConnected::next()
133133
{
134-
return next_;
134+
return next_;
135135
}
136136

137-
cppbp::layer::ILayer* cppbp::layer::FullyConnected::prev()
137+
cppbp::layer::ILayer *cppbp::layer::FullyConnected::prev()
138138
{
139-
return prev_;
139+
return prev_;
140140
}
141141

142-
cppbp::layer::IActivationFunction& cppbp::layer::FullyConnected::activation_function()
142+
cppbp::layer::IActivationFunction &cppbp::layer::FullyConnected::activation_function()
143143
{
144-
return *act_func_;
144+
return *act_func_;
145145
}
146146

147147
void cppbp::layer::FullyConnected::reshape(size_t input)
148148
{
149-
if (input == weights_.cols()) return;
149+
if (input == weights_.cols()) return;
150150

151-
weights_ = MatrixXd::Random(len_, input + 1);
151+
if (act_func_)
152+
{
153+
auto initializer = act_func_->default_initializer();
154+
weights_ = initializer->initialize_weights(len_, input + 1, input, len_);
155+
}
156+
else
157+
{
158+
weights_ = MatrixXd::Random(len_, input + 1);
159+
}
152160
}
153161

154162
void cppbp::layer::FullyConnected::set_errors(Eigen::VectorXd errors)
155163
{
156-
errors_ = errors;
164+
errors_ = errors;
157165
}
158166

159-
void cppbp::layer::FullyConnected::set_prev(cppbp::layer::ILayer* prev)
167+
void cppbp::layer::FullyConnected::set_prev(cppbp::layer::ILayer *prev)
160168
{
161-
prev_ = prev;
169+
prev_ = prev;
162170
}
163171

164-
void cppbp::layer::FullyConnected::set_next(cppbp::layer::ILayer* next)
172+
void cppbp::layer::FullyConnected::set_next(cppbp::layer::ILayer *next)
165173
{
166-
next_ = next;
174+
next_ = next;
167175
}
168176

169177
std::tuple<std::shared_ptr<char[]>, size_t> cppbp::layer::FullyConnected::serialize()
170178
{
171-
size_t size = sizeof(LayerDescriptor) + this->weights_.size() * sizeof(double);
179+
size_t size = sizeof(LayerDescriptor) + this->weights_.size() * sizeof(double);
172180

173-
auto ret = make_shared<char[]>(size);
181+
auto ret = make_shared<char[]>(size);
174182

175-
auto desc = reinterpret_cast<LayerDescriptor*>(ret.get());
183+
auto desc = reinterpret_cast<LayerDescriptor *>(ret.get());
176184

177-
desc->type = LayerTypeId<FullyConnected>::value;
178-
desc->act_func = act_func_->type_id();
179-
desc->rows = len_;
180-
desc->cols = weights_.cols();
185+
desc->type = LayerTypeId<FullyConnected>::value;
186+
desc->act_func = act_func_->type_id();
187+
desc->rows = len_;
188+
desc->cols = weights_.cols();
181189

182-
auto w = reinterpret_cast<double*>(ret.get() + sizeof(LayerDescriptor));
183-
for (int i = 0; i < weights_.size(); i++)
184-
{
185-
*(w++) = weights_.coeff(i);
186-
}
190+
auto w = reinterpret_cast<double *>(ret.get() + sizeof(LayerDescriptor));
191+
for (int i = 0; i < weights_.size(); i++)
192+
{
193+
*(w++) = weights_.coeff(i);
194+
}
187195

188-
return make_tuple(ret, size);
196+
return make_tuple(ret, size);
189197
}
190198

191-
char* cppbp::layer::FullyConnected::deserialize(char* data)
199+
char *cppbp::layer::FullyConnected::deserialize(char *data)
192200
{
193-
auto desc = reinterpret_cast<LayerDescriptor*>(data);
194-
data += sizeof(LayerDescriptor);
201+
auto desc = reinterpret_cast<LayerDescriptor *>(data);
202+
data += sizeof(LayerDescriptor);
195203

196-
// TODO: restore information
197-
restored_act_func_ = ActivationFunctionFactory::from_id(desc->act_func);
198-
act_func_ = restored_act_func_.get();
204+
// TODO: restore information
205+
restored_act_func_ = ActivationFunctionFactory::from_id(desc->act_func);
206+
act_func_ = restored_act_func_.get();
199207

200-
len_ = desc->rows;
201-
reshape(desc->cols);
202-
auto w = reinterpret_cast<double*>(data);
203-
for (int i = 0; i < desc->rows * desc->cols; i++)
204-
{
205-
weights_.coeffRef(i) = *(w++);
206-
}
208+
len_ = desc->rows;
209+
reshape(desc->cols);
210+
auto w = reinterpret_cast<double *>(data);
211+
for (int i = 0; i < desc->rows * desc->cols; i++)
212+
{
213+
weights_.coeffRef(i) = *(w++);
214+
}
207215

208-
return data;
216+
return data;
209217
}

layer/include/layer/activation_function.h

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,48 @@
55

66
#include <base/type_id.h>
77

8+
#include <layer/initializer.h>
9+
810
#include <Eigen/Eigen>
911

1012
#include <memory>
1113

1214
namespace cppbp::layer
1315
{
1416
class IActivationFunction
15-
: public base::ITypeId
17+
: public base::ITypeId
1618
{
17-
public:
18-
virtual double operator()(double x) = 0;
19+
public:
20+
virtual double operator()(double x) = 0;
21+
22+
virtual double eval(double x) = 0;
23+
24+
virtual double derive(double y) = 0;
1925

20-
virtual double eval(double x) = 0;
21-
virtual double derive(double y) = 0;
26+
virtual Eigen::VectorXd eval(Eigen::VectorXd x) = 0;
2227

23-
virtual Eigen::VectorXd eval(Eigen::VectorXd x) = 0;
24-
virtual Eigen::MatrixXd derive(Eigen::VectorXd y) = 0;
28+
virtual Eigen::MatrixXd derive(Eigen::VectorXd y) = 0;
2529

30+
virtual std::shared_ptr<IWeightInitializer> default_initializer() = 0;
2631
};
2732

2833
class ActivationFunctionFactory final
2934
{
30-
public:
31-
static std::shared_ptr<cppbp::layer::IActivationFunction> from_id(uint32_t id);
32-
33-
template<typename Callback>
34-
static std::shared_ptr<cppbp::layer::IActivationFunction> from_id(uint32_t id, Callback cbk)
35-
{
36-
try
37-
{
38-
from_id(id);
39-
}
40-
catch (...)
41-
{
42-
return cbk(id);
43-
}
44-
}
35+
public:
36+
static std::shared_ptr<cppbp::layer::IActivationFunction> from_id(uint32_t id);
37+
38+
template<typename Callback>
39+
static std::shared_ptr<cppbp::layer::IActivationFunction> from_id(uint32_t id, Callback cbk)
40+
{
41+
try
42+
{
43+
from_id(id);
44+
}
45+
catch (...)
46+
{
47+
return cbk(id);
48+
}
49+
}
4550

4651
};
4752

layer/include/layer/initializer.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,20 @@
66

77
#include <base/base.h>
88

9+
#include <concepts>
910
#include <memory>
1011

1112
namespace cppbp::layer
1213
{
1314
class IWeightInitializer
1415
{
15-
public:
16-
virtual base::MatrixType initialize_weights(size_t rows, size_t cols, size_t ni, size_t no) = 0;
16+
public:
17+
virtual base::MatrixType initialize_weights(size_t rows, size_t cols, size_t ni, size_t no) = 0;
18+
19+
template<std::derived_from<IWeightInitializer> T, typename...TArgs>
20+
static inline std::shared_ptr<IWeightInitializer> make(TArgs &&... args)
21+
{
22+
return std::make_shared<T>(std::forward<TArgs>(args)...);
23+
}
1724
};
1825
}// namespace cppbp::layer

0 commit comments

Comments
 (0)