// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ONE_VS_ONE_TRAiNER_H__
#define DLIB_ONE_VS_ONE_TRAiNER_H__
#include "one_vs_one_trainer_abstract.h"
#include "one_vs_one_decision_function.h"
#include <vector>
#include "../unordered_pair.h"
#include "multiclass_tools.h"
#include <sstream>
#include <iostream>
#include "../any.h"
#include <map>
#include <set>
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename any_trainer,
typename label_type_ = double
>
class one_vs_one_trainer
{
public:
typedef label_type_ label_type;
typedef typename any_trainer::sample_type sample_type;
typedef typename any_trainer::scalar_type scalar_type;
typedef typename any_trainer::mem_manager_type mem_manager_type;
typedef one_vs_one_decision_function<one_vs_one_trainer> trained_function_type;
one_vs_one_trainer (
) :
verbose(false)
{}
void set_trainer (
const any_trainer& trainer
)
/*!
ensures
- sets the trainer used for all pairs of training
!*/
{
default_trainer = trainer;
trainers.clear();
}
void set_trainer (
const any_trainer& trainer,
const label_type& l1,
const label_type& l2
)
/*!
requires
- l1 != l2
ensures
- sets the trainer used for just the l1 l2 class pair
!*/
{
trainers[make_unordered_pair(l1,l2)] = trainer;
}
void be_verbose (
)
{
verbose = true;
}
void be_quiet (
)
{
verbose = false;
}
struct invalid_label : public dlib::error
{
invalid_label(const std::string& msg, const label_type& l1_, const label_type& l2_
) : dlib::error(msg), l1(l1_), l2(l2_) {};
virtual ~invalid_label(
) throw() {}
label_type l1, l2;
};
trained_function_type train (
const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels
) const
{
const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels);
std::vector<sample_type> samples;
std::vector<scalar_type> labels;
typename trained_function_type::binary_function_table dfs;
for (unsigned long i = 0; i < distinct_labels.size(); ++i)
{
for (unsigned long j = i+1; j < distinct_labels.size(); ++j)
{
samples.clear();
labels.clear();
const unordered_pair<label_type> p(distinct_labels[i], distinct_labels[j]);
// pick out the samples corresponding to these two classes
for (unsigned long k = 0; k < all_samples.size(); ++k)
{
if (all_labels[k] == p.first)
{
samples.push_back(all_samples[k]);
labels.push_back(+1);
}
else if (all_labels[k] == p.second)
{
samples.push_back(all_samples[k]);
labels.push_back(-1);
}
}
if (verbose)
{
std::cout << "Training classifier for " << p.first << " vs. " << p.second << std::endl;
}
// now train a binary classifier using the samples we selected
const typename binary_function_table::const_iterator itr = trainers.find(p);
if (itr != trainers.end())
{
dfs[p] = itr->second.train(samples, labels);
}
else if (default_trainer.is_empty() == false)
{
dfs[p] = default_trainer.train(samples, labels);
}
else
{
std::ostringstream sout;
sout << "In one_vs_one_trainer, no trainer registered for the (" << p.first << ", " << p.second << ") label pair.";
throw invalid_label(sout.str(), p.first, p.second);
}
}
}
return trained_function_type(dfs);
}
private:
any_trainer default_trainer;
typedef std::map<unordered_pair<label_type>, any_trainer> binary_function_table;
binary_function_table trainers;
bool verbose;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ONE_VS_ONE_TRAiNER_H__