// Copyright (C) 2007 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SVm_FUNCTION
#define DLIB_SVm_FUNCTION
#include "function_abstract.h"
#include <cmath>
#include <limits>
#include <sstream>
#include "../matrix.h"
#include "../algs.h"
#include "../serialize.h"
#include "../rand.h"
#include "../statistics.h"
#include "kernel_matrix.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
    typename K
    >
struct decision_function
{
    /*!
        A kernel expansion evaluated as
            f(x) == sum over i of alpha(i)*kernel_function(x, basis_vectors(i)), minus b
        The sum runs over i in [0, alpha.nr()), so basis_vectors is assumed to
        hold at least alpha.nr() samples (NOTE(review): presumably exactly as
        many; not checked here).
    !*/
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;
    typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
    typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;

    scalar_vector_type alpha;          // per-term weights of the expansion
    scalar_type b;                     // offset subtracted from the weighted sum
    K kernel_function;                 // kernel used to compare x with each basis vector
    sample_vector_type basis_vectors;  // samples the expansion is built from

    // Empty expansion: no terms and a zero offset.
    decision_function (
    ) : b(0), kernel_function(K()) {}

    // Member-wise copy.
    decision_function (
        const decision_function& d
    ) :
        alpha(d.alpha),
        b(d.b),
        kernel_function(d.kernel_function),
        basis_vectors(d.basis_vectors)
    {}

    // Builds the expansion directly from its four components.
    decision_function (
        const scalar_vector_type& alpha_,
        const scalar_type& b_,
        const K& kernel_function_,
        const sample_vector_type& basis_vectors_
    ) :
        alpha(alpha_),
        b(b_),
        kernel_function(kernel_function_),
        basis_vectors(basis_vectors_)
    {}

    // Member-wise assignment; self-assignment is a no-op.
    decision_function& operator= (
        const decision_function& d
    )
    {
        if (this == &d)
            return *this;

        alpha = d.alpha;
        b = d.b;
        kernel_function = d.kernel_function;
        basis_vectors = d.basis_vectors;
        return *this;
    }

    // Evaluates the expansion at x (terms are summed in index order).
    scalar_type operator() (
        const sample_type& x
    ) const
    {
        scalar_type weighted_sum = 0;
        const long num_terms = alpha.nr();
        for (long idx = 0; idx < num_terms; ++idx)
        {
            weighted_sum += alpha(idx) * kernel_function(x, basis_vectors(idx));
        }
        return weighted_sum - b;
    }
};
template <
typename K
>
void serialize (
const decision_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.b, out);
serialize(item.kernel_function, out);
serialize(item.basis_vectors, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type decision_function");
}
}
template <
    typename K
    >
void deserialize (
    decision_function<K>& item,
    std::istream& in
)
{
    /*!
        Reads alpha, b, kernel_function and basis_vectors from in, in the same
        order serialize() writes them.  A serialization_error raised by any
        member is rethrown with this object's type name appended.  Members
        read before the failure keep their new values.
    !*/
    try
    {
        deserialize(item.alpha, in);
        deserialize(item.b, in);
        deserialize(item.kernel_function, in);
        deserialize(item.basis_vectors, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type decision_function");
    }
}
// ----------------------------------------------------------------------------------------
template <
    typename function_type
    >
struct probabilistic_function
{
    /*!
        Maps the raw output f of decision_funct through a sigmoid:
            (*this)(x) == 1/(1 + exp(alpha*f + beta)),  f == decision_funct(x)
    !*/
    typedef typename function_type::scalar_type scalar_type;
    typedef typename function_type::sample_type sample_type;
    typedef typename function_type::mem_manager_type mem_manager_type;

    scalar_type alpha;              // slope applied to the decision value
    scalar_type beta;               // offset added before the exponential
    function_type decision_funct;   // produces the raw decision value

    // Zero slope and offset, so every input maps to 1/(1+exp(0)).
    probabilistic_function (
    ) : alpha(0), beta(0), decision_funct(function_type()) {}

    // Member-wise copy.
    probabilistic_function (
        const probabilistic_function& d
    ) :
        alpha(d.alpha),
        beta(d.beta),
        decision_funct(d.decision_funct)
    {}

    // Builds the object from a slope, an offset and a decision function.
    probabilistic_function (
        const scalar_type a_,
        const scalar_type b_,
        const function_type& decision_funct_
    ) :
        alpha(a_),
        beta(b_),
        decision_funct(decision_funct_)
    {}

    // Member-wise assignment; self-assignment is a no-op.
    probabilistic_function& operator= (
        const probabilistic_function& d
    )
    {
        if (this == &d)
            return *this;

        alpha = d.alpha;
        beta = d.beta;
        decision_funct = d.decision_funct;
        return *this;
    }

    scalar_type operator() (
        const sample_type& x
    ) const
    {
        const scalar_type raw_score = decision_funct(x);
        const scalar_type exponent = alpha*raw_score + beta;
        return 1/(1 + std::exp(exponent));
    }
};
template <
typename function_type
>
void serialize (
const probabilistic_function<function_type>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.beta, out);
serialize(item.decision_funct, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type probabilistic_function");
}
}
template <
    typename function_type
    >
void deserialize (
    probabilistic_function<function_type>& item,
    std::istream& in
)
{
    /*!
        Reads alpha, beta and decision_funct from in, in the order serialize()
        writes them.  A serialization_error from any member is rethrown with
        this object's type name appended.  Members read before a failure keep
        their new values.
    !*/
    try
    {
        deserialize(item.alpha, in);
        deserialize(item.beta, in);
        deserialize(item.decision_funct, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type probabilistic_function");
    }
}
// ----------------------------------------------------------------------------------------
template <
    typename K
    >
struct probabilistic_decision_function
{
    /*!
        A decision_function<K> whose output f is passed through a sigmoid:
            (*this)(x) == 1/(1 + exp(alpha*f + beta)),  f == decision_funct(x)
    !*/
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;

    scalar_type alpha;                      // slope applied to the decision value
    scalar_type beta;                       // offset added before the exponential
    decision_function<K> decision_funct;    // produces the raw decision value

    // Zero slope and offset with an empty decision function.
    probabilistic_decision_function (
    ) : alpha(0), beta(0), decision_funct(decision_function<K>()) {}

    // Implicit conversion from the generic probabilistic_function wrapper
    // around a decision_function<K>; copies all three members.
    probabilistic_decision_function (
        const probabilistic_function<decision_function<K> >& d
    ) :
        alpha(d.alpha),
        beta(d.beta),
        decision_funct(d.decision_funct)
    {}

    // Member-wise copy.
    probabilistic_decision_function (
        const probabilistic_decision_function& d
    ) :
        alpha(d.alpha),
        beta(d.beta),
        decision_funct(d.decision_funct)
    {}

    // Builds the object from a slope, an offset and a decision function.
    probabilistic_decision_function (
        const scalar_type a_,
        const scalar_type b_,
        const decision_function<K>& decision_funct_
    ) :
        alpha(a_),
        beta(b_),
        decision_funct(decision_funct_)
    {}

    // Member-wise assignment; self-assignment is a no-op.
    probabilistic_decision_function& operator= (
        const probabilistic_decision_function& d
    )
    {
        if (this == &d)
            return *this;

        alpha = d.alpha;
        beta = d.beta;
        decision_funct = d.decision_funct;
        return *this;
    }

    scalar_type operator() (
        const sample_type& x
    ) const
    {
        const scalar_type raw_score = decision_funct(x);
        const scalar_type exponent = alpha*raw_score + beta;
        return 1/(1 + std::exp(exponent));
    }
};
template <
typename K
>
void serialize (
const probabilistic_decision_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.beta, out);
serialize(item.decision_funct, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type probabilistic_decision_function");
}
}
template <
    typename K
    >
void deserialize (
    probabilistic_decision_function<K>& item,
    std::istream& in
)
{
    /*!
        Reads alpha, beta and decision_funct from in, in the order serialize()
        writes them.  A serialization_error from any member is rethrown with
        this object's type name appended.  Members read before a failure keep
        their new values.
    !*/
    try
    {
        deserialize(item.alpha, in);
        deserialize(item.beta, in);
        deserialize(item.decision_funct, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type probabilistic_decision_function");
    }
}
// ----------------------------------------------------------------------------------------
template <
    typename K
    >
struct distance_function
{
    /*!
        Represents a point in kernel feature space as the expansion
            sum over i of alpha(i)*phi(basis_vectors(i))
        and measures distances to it.  The formulas below treat b as the
        squared norm of that expansion, i.e.
            sum over i,j of alpha(i)*alpha(j)*kernel_function(basis_vectors(i), basis_vectors(j))
        (NOTE(review): this is an assumption about how b is filled in by the
        producer of the object; it is not computed or checked here).
    !*/
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;
    typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
    typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;

    scalar_vector_type alpha;          // expansion coefficients
    scalar_type b;                     // squared-norm term (see note above)
    K kernel_function;                 // kernel defining the feature space
    sample_vector_type basis_vectors;  // samples the expansion is built from

    // Empty expansion with a zero norm term.
    distance_function (
    ) : b(0), kernel_function(K()) {}

    // Member-wise copy.
    distance_function (
        const distance_function& d
    ) :
        alpha(d.alpha),
        b(d.b),
        kernel_function(d.kernel_function),
        basis_vectors(d.basis_vectors)
    {}

    // Builds the expansion directly from its four components.
    distance_function (
        const scalar_vector_type& alpha_,
        const scalar_type& b_,
        const K& kernel_function_,
        const sample_vector_type& basis_vectors_
    ) :
        alpha(alpha_),
        b(b_),
        kernel_function(kernel_function_),
        basis_vectors(basis_vectors_)
    {}

    // Member-wise assignment; self-assignment is a no-op.
    distance_function& operator= (
        const distance_function& d
    )
    {
        if (this == &d)
            return *this;

        alpha = d.alpha;
        b = d.b;
        kernel_function = d.kernel_function;
        basis_vectors = d.basis_vectors;
        return *this;
    }

    // Distance between the represented point and phi(x):
    //   sqrt(b + k(x,x) - 2*sum_i alpha(i)*k(x, basis_vectors(i)))
    scalar_type operator() (
        const sample_type& x
    ) const
    {
        scalar_type cross_term = 0;
        const long num_terms = alpha.nr();
        for (long idx = 0; idx < num_terms; ++idx)
            cross_term += alpha(idx) * kernel_function(x, basis_vectors(idx));

        const scalar_type squared_dist = b + kernel_function(x,x) - 2*cross_term;

        // Rounding can leave a tiny negative value; anything not strictly
        // positive (including NaN, since the comparison then fails) maps to 0.
        if (!(squared_dist > 0))
            return 0;
        return std::sqrt(squared_dist);
    }

    // Distance between the points represented by *this and x, using this
    // object's kernel_function for all cross evaluations:
    //   sqrt(b + x.b - 2*sum_i sum_j alpha(i)*x.alpha(j)*k(basis_vectors(i), x.basis_vectors(j)))
    scalar_type operator() (
        const distance_function& x
    ) const
    {
        scalar_type cross_term = 0;
        const long my_terms = alpha.nr();
        const long other_terms = x.alpha.nr();
        for (long r = 0; r < my_terms; ++r)
        {
            for (long c = 0; c < other_terms; ++c)
                cross_term += alpha(r)*x.alpha(c) * kernel_function(basis_vectors(r), x.basis_vectors(c));
        }

        const scalar_type squared_dist = b + x.b - 2*cross_term;

        // Same clamping rule as the single-sample overload.
        if (!(squared_dist > 0))
            return 0;
        return std::sqrt(squared_dist);
    }
};
template <
typename K
>
void serialize (
const distance_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.b, out);
serialize(item.kernel_function, out);
serialize(item.basis_vectors, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type distance_function");
}
}
template <
    typename K
    >
void deserialize (
    distance_function<K>& item,
    std::istream& in
)
{
    /*!
        Reads alpha, b, kernel_function and basis_vectors from in, in the order
        serialize() writes them.  A serialization_error from any member is
        rethrown with this object's type name appended.  Members read before a
        failure keep their new values.
    !*/
    try
    {
        deserialize(item.alpha, in);
        deserialize(item.b, in);
        deserialize(item.kernel_function, in);
        deserialize(item.basis_vectors, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type distance_function");
    }
}
// ----------------------------------------------------------------------------------------
template <
    typename function_type,
    typename normalizer_type = vector_normalizer<typename function_type::sample_type>
    >
struct normalized_function
{
    /*!
        Wraps function so that every input sample is first passed through
        normalizer:
            (*this)(x) == function(normalizer(x))
    !*/
    typedef typename function_type::scalar_type scalar_type;
    typedef typename function_type::sample_type sample_type;
    typedef typename function_type::mem_manager_type mem_manager_type;

    normalizer_type normalizer;   // applied to each sample before evaluation
    function_type function;       // evaluated on the normalized sample

    // Default-constructs both the normalizer and the wrapped function.
    normalized_function (
    ){}

    // Member-wise copy.
    normalized_function (
        const normalized_function& f
    ) :
        normalizer(f.normalizer),
        function(f.function)
    {}

    // Builds the wrapper from a normalizer and a function.  The normalizer is
    // taken as normalizer_type rather than the hard-coded
    // vector_normalizer<sample_type>, so this constructor also works when a
    // different normalizer_type is given as the second template argument.
    // With the default template argument the signature is unchanged.
    normalized_function (
        const normalizer_type& normalizer_,
        const function_type& funct
    ) : normalizer(normalizer_), function(funct) {}

    // Normalizes x, then evaluates the wrapped function on the result.
    scalar_type operator() (
        const sample_type& x
    ) const { return function(normalizer(x)); }
};
template <
typename function_type,
typename normalizer_type
>
void serialize (
const normalized_function<function_type,normalizer_type>& item,
std::ostream& out
)
{
try
{
serialize(item.normalizer, out);
serialize(item.function, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type normalized_function");
}
}
template <
    typename function_type,
    typename normalizer_type
    >
void deserialize (
    normalized_function<function_type,normalizer_type>& item,
    std::istream& in
)
{
    /*!
        Reads normalizer and then function from in, matching serialize().  A
        serialization_error from either member is rethrown with this object's
        type name appended.  If function fails, normalizer keeps its new value.
    !*/
    try
    {
        deserialize(item.normalizer, in);
        deserialize(item.function, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type normalized_function");
    }
}
// ----------------------------------------------------------------------------------------
template <
    typename K
    >
struct projection_function
{
    /*!
        Projects a sample x to the vector
            weights * kernel_matrix(kernel_function, basis_vectors, x)
        i.e. the column of kernel evaluations between x and the basis vectors,
        multiplied by the weights matrix.
    !*/
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;
    typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
    typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
    typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;

    scalar_matrix_type weights;        // one row per output dimension
    K kernel_function;                 // kernel compared against each basis vector
    sample_vector_type basis_vectors;  // samples defining the projection

    // Empty projection.
    projection_function (
    ) {}

    // Copies the public members only; the private scratch buffers start
    // empty and are filled on the first call to operator().
    projection_function (
        const projection_function& f
    ) :
        weights(f.weights),
        kernel_function(f.kernel_function),
        basis_vectors(f.basis_vectors)
    {}

    // Builds the projection from its three components.
    projection_function (
        const scalar_matrix_type& weights_,
        const K& kernel_function_,
        const sample_vector_type& basis_vectors_
    ) :
        weights(weights_),
        kernel_function(kernel_function_),
        basis_vectors(basis_vectors_)
    {}

    // Length of the vector returned by operator() (the row count of weights).
    long out_vector_size (
    ) const { return weights.nr(); }

    // Returns the projection of x.  The result lives in an internal buffer,
    // so the returned reference is overwritten by the next call on this
    // object.  Because that buffer is written from this const method,
    // concurrent calls on the same object are not safe.
    const scalar_vector_type& operator() (
        const sample_type& x
    ) const
    {
        // Reusing member buffers avoids reallocating them on every call.
        kernel_column = kernel_matrix(kernel_function, basis_vectors, x);
        projected = weights*kernel_column;
        return projected;
    }

private:
    // Scratch storage reused across calls to operator().
    mutable scalar_vector_type kernel_column, projected;
};
template <
typename K
>
void serialize (
const projection_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.weights, out);
serialize(item.kernel_function, out);
serialize(item.basis_vectors, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type projection_function");
}
}
template <
    typename K
    >
void deserialize (
    projection_function<K>& item,
    std::istream& in
)
{
    /*!
        Reads weights, kernel_function and basis_vectors from in, in the order
        serialize() writes them.  A serialization_error from any member is
        rethrown with this object's type name appended.  Members read before a
        failure keep their new values.
    !*/
    try
    {
        deserialize(item.weights, in);
        deserialize(item.kernel_function, in);
        deserialize(item.basis_vectors, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type projection_function");
    }
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SVm_FUNCTION