// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_RLs_Hh_
#define DLIB_RLs_Hh_
#include "rls_abstract.h"
#include "../matrix.h"
#include "function.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class rls
{
public:
explicit rls(
double forget_factor_,
double C_ = 1000,
bool apply_forget_factor_to_C_ = false
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < forget_factor_ && forget_factor_ <= 1 &&
0 < C_,
"\t rls::rls()"
<< "\n\t invalid arguments were given to this function"
<< "\n\t forget_factor_: " << forget_factor_
<< "\n\t C_: " << C_
<< "\n\t this: " << this
);
C = C_;
forget_factor = forget_factor_;
apply_forget_factor_to_C = apply_forget_factor_to_C_;
}
rls(
)
{
C = 1000;
forget_factor = 1;
apply_forget_factor_to_C = false;
}
double get_c(
) const
{
return C;
}
double get_forget_factor(
) const
{
return forget_factor;
}
bool should_apply_forget_factor_to_C (
) const
{
return apply_forget_factor_to_C;
}
template <typename EXP>
void train (
const matrix_exp<EXP>& x,
double y
)
{
// make sure requires clause is not broken
DLIB_ASSERT(is_col_vector(x) &&
(get_w().size() == 0 || get_w().size() == x.size()),
"\t void rls::train()"
<< "\n\t invalid arguments were given to this function"
<< "\n\t is_col_vector(x): " << is_col_vector(x)
<< "\n\t x.size(): " << x.size()
<< "\n\t get_w().size(): " << get_w().size()
<< "\n\t this: " << this
);
if (R.size() == 0)
{
R = identity_matrix<double>(x.size())*C;
w.set_size(x.size());
w = 0;
}
// multiply by forget factor and incorporate x*trans(x) into R.
const double l = 1.0/forget_factor;
const double temp = 1 + l*trans(x)*R*x;
tmp = R*x;
R = l*R - l*l*(tmp*trans(tmp))/temp;
// Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
// identity matrix back in to keep the regularization alive.
if (forget_factor != 1 && !apply_forget_factor_to_C)
add_eye_to_inv(R, (1-forget_factor)/C);
// R should always be symmetric. This line improves numeric stability of this algorithm.
if (cnt%10 == 0)
R = 0.5*(R + trans(R));
++cnt;
w = w + R*x*(y - trans(x)*w);
}
const matrix<double,0,1>& get_w(
) const
{
return w;
}
template <typename EXP>
double operator() (
const matrix_exp<EXP>& x
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_col_vector(x) && get_w().size() == x.size(),
"\t double rls::operator()()"
<< "\n\t invalid arguments were given to this function"
<< "\n\t is_col_vector(x): " << is_col_vector(x)
<< "\n\t x.size(): " << x.size()
<< "\n\t get_w().size(): " << get_w().size()
<< "\n\t this: " << this
);
return dot(x,w);
}
decision_function<linear_kernel<matrix<double,0,1> > > get_decision_function (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(get_w().size() != 0,
"\t decision_function rls::get_decision_function()"
<< "\n\t invalid arguments were given to this function"
<< "\n\t get_w().size(): " << get_w().size()
<< "\n\t this: " << this
);
decision_function<linear_kernel<matrix<double,0,1> > > df;
df.alpha.set_size(1);
df.basis_vectors.set_size(1);
df.b = 0;
df.alpha = 1;
df.basis_vectors(0) = w;
return df;
}
friend inline void serialize(const rls& item, std::ostream& out)
{
int version = 2;
serialize(version, out);
serialize(item.w, out);
serialize(item.R, out);
serialize(item.C, out);
serialize(item.forget_factor, out);
serialize(item.cnt, out);
serialize(item.apply_forget_factor_to_C, out);
}
friend inline void deserialize(rls& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (!(1 <= version && version <= 2))
throw dlib::serialization_error("Unknown version number found while deserializing rls object.");
if (version >= 1)
{
deserialize(item.w, in);
deserialize(item.R, in);
deserialize(item.C, in);
deserialize(item.forget_factor, in);
}
item.cnt = 0;
item.apply_forget_factor_to_C = false;
if (version >= 2)
{
deserialize(item.cnt, in);
deserialize(item.apply_forget_factor_to_C, in);
}
}
private:
void add_eye_to_inv(
matrix<double>& m,
double C
)
/*!
ensures
- Let m == inv(M)
- this function returns inv(M + C*identity_matrix<double>(m.nr()))
!*/
{
for (long r = 0; r < m.nr(); ++r)
{
m = m - colm(m,r)*trans(colm(m,r))/(1/C + m(r,r));
}
}
matrix<double,0,1> w;
matrix<double> R;
double C;
double forget_factor;
int cnt = 0;
bool apply_forget_factor_to_C;
// This object is here only to avoid reallocation during training. It don't
// logically contribute to the state of this object.
matrix<double,0,1> tmp;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_RLs_Hh_