// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/control.h>
#include <vector>
#include <sstream>
#include <ctime>
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
// Logger for the tests in this file; used to print the learned weights and
// the action the trained policy picks in each state.
dlib::logger dlog("test.lspi");
template <bool have_prior>
struct chain_model
{
    // Feature extractor for a 4-state chain used to exercise lspi.
    //   - States are 0..3. Action 0 moves left and action 1 moves right.
    //   - There is one indicator feature per (state, action) pair, stored at
    //     index state*2 + action.
    //   - When have_prior is true, one extra trailing feature carries offset()
    //     for the pair, and force_last_weight_to_1 is set. test_lspi_prior1
    //     checks that lspi then leaves that final weight at exactly 1.
    typedef int state_type;
    typedef int action_type; // 0 is move left, 1 is move right

    static const bool force_last_weight_to_1 = have_prior;
    static const int num_states = 4; // not required in the model interface

    // Per-(state, action) prior value, indexed by state*2 + action.
    // All zero when have_prior is false, so it then has no influence.
    matrix<double,8,1> offset;

    chain_model()
    {
        if (have_prior)
        {
            // NOTE(review): these values appear to be the optimal Q-values for
            // the chain when only (state 3, move right) is rewarded with 1 and
            // the discount is 0.8. Confirm 0.8 is lspi's default discount.
            offset = 2.048, 2.56,   // state 0: left, right
                     2.048, 3.2,    // state 1: left, right
                     2.56,  4,      // state 2: left, right
                     3.2,   5;      // state 3: left, right
        }
        else
        {
            offset = 0;
        }
    }

    unsigned long num_features(
    ) const
    {
        // One indicator per (state, action) pair, plus the trailing prior
        // feature when have_prior is set.
        const unsigned long indicator_count = num_states*2;
        return have_prior ? indicator_count + 1 : indicator_count;
    }

    action_type find_best_action (
        const state_type& state,
        const matrix<double,0,1>& w
    ) const
    {
        // Score each action as its indicator weight plus its prior offset.
        // The prior feature's own weight is not read here; with have_prior it
        // is expected to be 1, and without it offset is all zero.
        const long left = state*2;
        const long right = left + 1;
        // Ties go to "move left".
        return (w(left) + offset(left) >= w(right) + offset(right)) ? 0 : 1;
    }

    void get_features (
        const state_type& state,
        const action_type& action,
        matrix<double,0,1>& feats
    ) const
    {
        // Build a one-hot vector for the (state, action) pair. With a prior,
        // the last slot carries that pair's offset.
        const long active = state*2 + action;
        feats.set_size(num_features());
        feats = 0;
        feats(active) = 1;
        if (have_prior)
            feats(num_features()-1) = offset(active);
    }
};
void test_lspi_prior1()
{
    // Chain rewarded only for moving right out of state 3, trained with the
    // prior-augmented model and lambda = 0. The test expects:
    //   - all learned weights except the pinned final weight to be
    //     (numerically) zero;
    //   - the policy to move right in every state.
    print_spinner();
    typedef process_sample<chain_model<true> > sample_type;

    // One row per transition: (state, action, next_state, reward).
    const int transitions[8][4] = {
        {0,0,0,0}, {0,1,1,0},
        {1,0,0,0}, {1,1,2,0},
        {2,0,1,0}, {2,1,3,0},
        {3,0,2,0}, {3,1,3,1}
    };
    std::vector<sample_type> samples;
    for (int i = 0; i < 8; ++i)
    {
        const int* row = transitions[i];
        samples.push_back(sample_type(row[0], row[1], row[2], row[3]));
    }

    lspi<chain_model<true> > trainer;
    //trainer.be_verbose();
    trainer.set_lambda(0);
    policy<chain_model<true> > pol = trainer.train(samples);
    dlog << LINFO << pol.get_weights();

    // Check the size before indexing the last element.
    matrix<double,0,1> w = pol.get_weights();
    DLIB_TEST(pol.get_weights().size() == 9);
    DLIB_TEST(w(w.size()-1) == 1);
    // Drop the pinned prior weight; whatever remains should be ~0.
    w(w.size()-1) = 0;
    DLIB_TEST_MSG(length(w) < 1e-12, length(w));

    for (int s = 0; s < 4; ++s)
        dlog << LINFO << "action: " << pol(s);
    DLIB_TEST(pol(0) == 1);
    DLIB_TEST(pol(1) == 1);
    DLIB_TEST(pol(2) == 1);
    DLIB_TEST(pol(3) == 1);
}
void test_lspi_prior2()
{
print_spinner();
typedef process_sample<chain_model<true> > sample_type;
std::vector<sample_type> samples;
samples.push_back(sample_type(0,0,0,0));
samples.push_back(sample_type(0,1,1,0));
samples.push_back(sample_type(1,0,0,0));
samples.push_back(sample_type(1,1,2,0));
samples.push_back(sample_type(2,0,1,0));
samples.push_back(sample_type(2,1,3,1));
samples.push_back(sample_type(3,0,2,0));
samples.push_back(sample_type(3,1,3,0));
lspi<chain_model<true> > trainer;
//trainer.be_verbose();
trainer.set_lambda(0);
policy<chain_model<true> > pol = trainer.train(samples);
dlog << LINFO << "action: " << pol(0);
dlog << LINFO << "action: " << pol(1);
dlog << LINFO << "action: " << pol(2);
dlog << LINFO << "action: " << pol(3);
DLIB_TEST(pol(0) == 1);
DLIB_TEST(pol(1) == 1);
DLIB_TEST(pol(2) == 1);
DLIB_TEST(pol(3) == 0);
}
void test_lspi_noprior1()
{
    // Plain indicator features (no prior), lambda = 0.01. Only moving right
    // out of state 3 is rewarded, and the policy is expected to move right in
    // every state.
    print_spinner();
    typedef process_sample<chain_model<false> > sample_type;

    // (state, action, next_state, reward) for every transition of the chain.
    const int transitions[][4] = {
        {0,0,0,0}, {0,1,1,0},
        {1,0,0,0}, {1,1,2,0},
        {2,0,1,0}, {2,1,3,0},
        {3,0,2,0}, {3,1,3,1}
    };
    const unsigned long num_transitions = sizeof(transitions)/sizeof(transitions[0]);

    std::vector<sample_type> samples;
    for (unsigned long i = 0; i < num_transitions; ++i)
        samples.push_back(sample_type(transitions[i][0], transitions[i][1],
                                      transitions[i][2], transitions[i][3]));

    lspi<chain_model<false> > trainer;
    //trainer.be_verbose();
    trainer.set_lambda(0.01);
    policy<chain_model<false> > pol = trainer.train(samples);
    dlog << LINFO << pol.get_weights();
    // Without a prior there is no extra feature: just the 8 indicator weights.
    DLIB_TEST(pol.get_weights().size() == 8);

    for (int s = 0; s < 4; ++s)
        dlog << LINFO << "action: " << pol(s);
    DLIB_TEST(pol(0) == 1);
    DLIB_TEST(pol(1) == 1);
    DLIB_TEST(pol(2) == 1);
    DLIB_TEST(pol(3) == 1);
}
void test_lspi_noprior2()
{
    // Plain indicator features (no prior), lambda = 0.01. Only moving right
    // out of state 1 (into state 2) is rewarded. The policy is expected to
    // move right in states 0 and 1 and left in states 2 and 3.
    print_spinner();
    typedef process_sample<chain_model<false> > sample_type;

    // (state, action, next_state, reward) for every transition of the chain.
    const int transitions[][4] = {
        {0,0,0,0}, {0,1,1,0},
        {1,0,0,0}, {1,1,2,1},
        {2,0,1,0}, {2,1,3,0},
        {3,0,2,0}, {3,1,3,0}
    };
    const unsigned long num_transitions = sizeof(transitions)/sizeof(transitions[0]);

    std::vector<sample_type> samples;
    for (unsigned long i = 0; i < num_transitions; ++i)
        samples.push_back(sample_type(transitions[i][0], transitions[i][1],
                                      transitions[i][2], transitions[i][3]));

    lspi<chain_model<false> > trainer;
    //trainer.be_verbose();
    trainer.set_lambda(0.01);
    policy<chain_model<false> > pol = trainer.train(samples);
    dlog << LINFO << pol.get_weights();
    // Without a prior there is no extra feature: just the 8 indicator weights.
    DLIB_TEST(pol.get_weights().size() == 8);

    for (int s = 0; s < 4; ++s)
        dlog << LINFO << "action: " << pol(s);
    DLIB_TEST(pol(0) == 1);
    DLIB_TEST(pol(1) == 1);
    DLIB_TEST(pol(2) == 0);
    DLIB_TEST(pol(3) == 0);
}
// Test driver exposed under the command-line name "test_lspi".
// perform_test() runs the four LSPI tests above, in order.
class lspi_tester : public tester
{
public:
    lspi_tester (
    ) : tester("test_lspi",                      // the command line argument name for this test
               "Run tests on the lspi object.",  // the command line argument description
               0)                                // the number of command line arguments for this test
    {}

    void perform_test (
    )
    {
        typedef void (*test_function)();
        const test_function tests[] = {
            &test_lspi_prior1,
            &test_lspi_prior2,
            &test_lspi_noprior1,
            &test_lspi_noprior2
        };
        const unsigned long num_tests = sizeof(tests)/sizeof(tests[0]);
        for (unsigned long i = 0; i < num_tests; ++i)
            tests[i]();
    }
};

// File-scope instance of the driver.
// NOTE(review): presumably this is how the tester harness discovers the
// test. Confirm against tester.h.
lspi_tester a;
}