// Copyright (C) 2009 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_TYPE_SAFE_UNIOn_h_
#define DLIB_TYPE_SAFE_UNIOn_h_
#include "type_safe_union_kernel_abstract.h"
#include <new>
#include <iostream>
#include <type_traits>
#include <functional>
#include "../serialize.h"
#include "../invoke.h"
namespace dlib
{
// ---------------------------------------------------------------------
// Exception thrown by type_safe_union::cast_to<T>() when the union does not
// currently hold a T.
class bad_type_safe_union_cast : public std::bad_cast
{
public:
    // `noexcept override` replaces the dynamic exception specification
    // `throw()`, which was deprecated in C++17 and removed in C++20.
    const char * what() const noexcept override
    {
        return "bad_type_safe_union_cast";
    }
};
// ---------------------------------------------------------------------
// Empty tag type used to select an alternative by type, e.g. when constructing
// a type_safe_union in place or when iterating over its alternatives.
template <typename Alternative>
struct in_place_tag
{
    // The alternative type being tagged.
    using type = Alternative;
};
// ---------------------------------------------------------------------
template <typename... Types> class type_safe_union;

// type_safe_union_size<U>::value is the number of alternatives of the
// type_safe_union U. Any cv-qualification on U is ignored.
template <typename U>
struct type_safe_union_size;

template <typename... Alts>
struct type_safe_union_size<type_safe_union<Alts...>>
    : std::integral_constant<std::size_t, sizeof...(Alts)>
{};

// cv-qualified unions forward to the unqualified type. All three are needed so
// that `const volatile U` picks a single most-specialized match.
template <typename U>
struct type_safe_union_size<const volatile U> : type_safe_union_size<U> {};

template <typename U>
struct type_safe_union_size<volatile U> : type_safe_union_size<U> {};

template <typename U>
struct type_safe_union_size<const U> : type_safe_union_size<U> {};
// ---------------------------------------------------------------------
namespace detail
{
    // nth_type<Index, List...>::type is the Index-th (0-based) type of List...
    template <std::size_t Index, typename... List>
    struct nth_type;

    // Terminal case: index 0 selects the head of the list.
    template <typename Head, typename... Tail>
    struct nth_type<0, Head, Tail...>
    {
        using type = Head;
    };

    // Recursive case: drop the head and look up Index-1 in the tail.
    template <std::size_t Index, typename Head, typename... Tail>
    struct nth_type<Index, Head, Tail...> : nth_type<Index - 1, Tail...> {};
}
// type_safe_union_alternative<Idx, U>::type is the Idx-th (0-based) alternative
// of the type_safe_union U, carrying over any cv-qualification of U.
template <size_t Idx, typename U>
struct type_safe_union_alternative;

template <size_t Idx, typename... Alts>
struct type_safe_union_alternative<Idx, type_safe_union<Alts...>>
    : detail::nth_type<Idx, Alts...>
{};

// Convenience alias for type_safe_union_alternative<Idx, U>::type.
template <size_t Idx, typename U>
using type_safe_union_alternative_t = typename type_safe_union_alternative<Idx, U>::type;

// const union -> const alternative
template <size_t Idx, typename U>
struct type_safe_union_alternative<Idx, const U>
{
    using type = typename std::add_const<type_safe_union_alternative_t<Idx, U>>::type;
};

// volatile union -> volatile alternative
template <size_t Idx, typename U>
struct type_safe_union_alternative<Idx, volatile U>
{
    using type = typename std::add_volatile<type_safe_union_alternative_t<Idx, U>>::type;
};

// const volatile union -> const volatile alternative
template <size_t Idx, typename U>
struct type_safe_union_alternative<Idx, const volatile U>
{
    using type = typename std::add_cv<type_safe_union_alternative_t<Idx, U>>::type;
};
// ---------------------------------------------------------------------
namespace detail
{
// ---------------------------------------------------------------------
    // is_any<T, Ts...>::value is true iff T is exactly (std::is_same) one of Ts...
    template <typename T, typename First, typename... Rest>
    struct is_any
        : std::conditional<
              std::is_same<T, First>::value,
              std::true_type,
              is_any<T, Rest...>
          >::type
    {};

    // Single-candidate case.
    template <typename T, typename First>
    struct is_any<T, First> : std::is_same<T, First> {};

// ---------------------------------------------------------------------
    // type_safe_union_type_id_impl<nTs, T, Ts...>::value is the 1-based position
    // of T within Ts..., or -1 when T is absent. Every recursion level adds 1,
    // so the terminal value -1 - nTs cancels out the nTs levels walked when the
    // whole list is exhausted without a match.
    template <int nTs, typename T, typename... Ts>
    struct type_safe_union_type_id_impl
        : std::integral_constant<int, -1 - nTs>
    {};

    template <int nTs, typename T, typename Head, typename... Tail>
    struct type_safe_union_type_id_impl<nTs, T, Head, Tail...>
        : std::integral_constant<int,
              std::is_same<T, Head>::value
                  ? 1
                  : 1 + type_safe_union_type_id_impl<nTs, T, Tail...>::value>
    {};

    // Public entry point: id of T among Ts... (1-based, -1 if not present).
    template <typename T, typename... Ts>
    struct type_safe_union_type_id
        : type_safe_union_type_id_impl<sizeof...(Ts), T, Ts...>
    {};

    // An in_place_tag<T> has the same id as T itself.
    template <typename T, typename... Ts>
    struct type_safe_union_type_id<in_place_tag<T>, Ts...>
        : type_safe_union_type_id<T, Ts...>
    {};
// ---------------------------------------------------------------------
}
template <typename... Types>
class type_safe_union
{
    /*!
        CONVENTION
            - is_empty() == (type_identity == 0)
            - contains<T>() == (type_identity == get_type_id<T>())
            - mem == the aligned block of storage inside this object where the
              currently held alternative (if any) is constructed
    !*/
public:
    /*!
        ensures
            - returns the 1-based position of T within Types... (an
              in_place_tag<T> is treated as T), or -1 if T is not one of
              Types.... The id 0 is reserved for the empty state.
    !*/
    template <typename T>
    static constexpr int get_type_id ()
    {
        return detail::type_safe_union_type_id<T,Types...>::value;
    }

private:
    // True iff T is exactly one of Types...
    template<typename T>
    struct is_valid : detail::is_any<T,Types...> {};

    // SFINAE helper: only well-formed when T is one of Types...
    template<typename T>
    using is_valid_check = typename std::enable_if<is_valid<T>::value, bool>::type;

    // The I-th (0-based) alternative type.
    template <size_t I>
    using get_type_t = type_safe_union_alternative_t<I, type_safe_union>;

    // C++11-compatible "all of" over a pack of bools: the two packs compare
    // equal iff every B is true. Used by the conditional noexcept
    // specifications on the move operations below.
    template <bool... B> struct bool_pack {};
    template <bool... B>
    using all_of = std::is_same<bool_pack<true, B...>, bool_pack<B..., true>>;

    typename std::aligned_union<0, Types...>::type mem;
    int type_identity = 0;

    // Calls f on the I-th alternative stored in `me` without checking that it
    // is the active one; only reached through the dispatch table below.
    template<
        typename F,
        typename TSU,
        std::size_t I
    >
    static void apply_to_contents_as_type(
        F&& f,
        TSU&& me
    )
    {
        std::forward<F>(f)(me.template unchecked_get<get_type_t<I>>());
    }

    // Dispatches on me.get_current_type_id() through a table of function
    // pointers: slot 0 (empty) does nothing, slot k calls f on alternative k-1.
    template<
        typename F,
        typename TSU,
        std::size_t... I
    >
    static void apply_to_contents_impl(
        F&& f,
        TSU&& me,
        dlib::index_sequence<I...>
    )
    {
        using func_t = void(*)(F&&, TSU&&);
        const func_t vtable[] = {
            /*! Empty (type_identity == 0) case !*/
            [](F&&, TSU&&) {
            },
            /*! Non-empty cases !*/
            &apply_to_contents_as_type<F&&,TSU&&,I>...
        };
        return vtable[me.get_current_type_id()](std::forward<F>(f), std::forward<TSU>(me));
    }

    // Reinterprets the storage as a T. The caller must know a T lives there.
    // NOTE(review): strictly conforming C++17 would wrap this in std::launder.
    template <typename T>
    const T& unchecked_get() const
    {
        return *reinterpret_cast<const T*>(&mem);
    }

    template <typename T>
    T& unchecked_get()
    {
        return *reinterpret_cast<T*>(&mem);
    }

    // Calls the destructor of whatever alternative is visited.
    struct destruct_helper
    {
        template <typename T>
        void operator() (T& item) const
        {
            item.~T();
        }
    };

    // Destroys the held value (if any) and marks the union empty.
    void destruct ()
    {
        apply_to_contents(destruct_helper{});
        type_identity = 0;
    }

    // Destroys the current value, then placement-constructs a T from args.
    // If T's constructor throws, the union is left empty.
    template <typename T, typename... Args>
    void construct (
        Args&&... args
    )
    {
        destruct();
        new(&mem) T(std::forward<Args>(args)...);
        type_identity = get_type_id<T>();
    }

    struct assign_to
    {
        /*!
            This class assigns an object to `me` using std::forward.
            If `me` already holds the same (decayed) type it is assigned to,
            otherwise a new value is constructed in place.
        !*/
        assign_to(type_safe_union& me) : _me(me) {}

        template<typename T>
        void operator()(T&& x)
        {
            using U = typename std::decay<T>::type;

            if (_me.type_identity != get_type_id<U>())
            {
                _me.construct<U>(std::forward<T>(x));
            }
            else
            {
                _me.template unchecked_get<U>() = std::forward<T>(x);
            }
        }

        type_safe_union& _me;
    };

    struct move_to
    {
        /*!
            This class move assigns an object to `me` (move-constructing it
            when `me` holds a different type).
        !*/
        move_to(type_safe_union& me) : _me(me) {}

        template<typename T>
        void operator()(T& x)
        {
            if (_me.type_identity != get_type_id<T>())
            {
                _me.construct<T>(std::move(x));
            }
            else
            {
                _me.template unchecked_get<T>() = std::move(x);
            }
        }

        type_safe_union& _me;
    };

    struct swap_to
    {
        /*!
            This class swaps an object with `me`.
        !*/
        swap_to(type_safe_union& me) : _me(me) {}

        template<typename T>
        void operator()(T& x)
        /*!
            requires
                - _me.contains<T>() == true
        !*/
        {
            using std::swap;
            swap(_me.unchecked_get<T>(), x);
        }

        type_safe_union& _me;
    };

public:

    // Constructs an empty union.
    type_safe_union() = default;

    // Copies whatever item holds (or nothing if item is empty).
    type_safe_union (
        const type_safe_union& item
    ) : type_safe_union()
    {
        item.apply_to_contents(assign_to{*this});
    }

    // Copy assignment: becomes empty if item is empty, otherwise copy-assigns
    // (same alternative) or copy-constructs (different alternative).
    type_safe_union& operator=(
        const type_safe_union& item
    )
    {
        if (item.is_empty())
            destruct();
        else
            item.apply_to_contents(assign_to{*this});
        return *this;
    }

    // Moves item's value into *this and leaves item empty.
    // noexcept whenever every alternative is nothrow move constructible, so
    // that std::vector and friends (via std::move_if_noexcept) move instead of
    // copying, and containers of move-only alternatives work.
    type_safe_union (
        type_safe_union&& item
    ) noexcept(all_of<std::is_nothrow_move_constructible<Types>::value...>::value)
        : type_safe_union()
    {
        item.apply_to_contents(move_to{*this});
        item.destruct();
    }

    // Move assignment: becomes empty if item is empty, otherwise takes item's
    // value (move-assigning or move-constructing) and leaves item empty.
    // noexcept whenever every alternative is nothrow move constructible and
    // nothrow move assignable.
    type_safe_union& operator= (
        type_safe_union&& item
    ) noexcept(all_of<(std::is_nothrow_move_constructible<Types>::value &&
                       std::is_nothrow_move_assignable<Types>::value)...>::value)
    {
        if (item.is_empty())
        {
            destruct();
        }
        else
        {
            item.apply_to_contents(move_to{*this});
            item.destruct();
        }
        return *this;
    }

    // Converting constructor from any (decayed) alternative type.
    template <
        typename T,
        is_valid_check<typename std::decay<T>::type> = true
    >
    type_safe_union (
        T&& item
    ) : type_safe_union()
    {
        assign_to{*this}(std::forward<T>(item));
    }

    // Converting assignment from any (decayed) alternative type.
    template <
        typename T,
        is_valid_check<typename std::decay<T>::type> = true
    >
    type_safe_union& operator= (
        T&& item
    )
    {
        assign_to{*this}(std::forward<T>(item));
        return *this;
    }

    // Constructs a T in place from args.
    template <
        typename T,
        typename... Args,
        is_valid_check<T> = true
    >
    type_safe_union (
        in_place_tag<T>,
        Args&&... args
    )
    {
        construct<T>(std::forward<Args>(args)...);
    }

    ~type_safe_union()
    {
        destruct();
    }

    // Destroys the held value; the union becomes empty.
    void clear()
    {
        destruct();
    }

    // Replaces the current value with a T constructed in place from args.
    template <
        typename T,
        typename... Args,
        is_valid_check<T> = true
    >
    void emplace(
        Args&&... args
    )
    {
        construct<T>(std::forward<Args>(args)...);
    }

    // Calls f(value) with the held value; does nothing if the union is empty.
    template <typename F>
    void apply_to_contents(
        F&& f
    )
    {
        apply_to_contents_impl(std::forward<F>(f), *this, dlib::make_index_sequence<sizeof...(Types)>{});
    }

    template <typename F>
    void apply_to_contents(
        F&& f
    ) const
    {
        apply_to_contents_impl(std::forward<F>(f), *this, dlib::make_index_sequence<sizeof...(Types)>{});
    }

    // True iff the union currently holds a T.
    template <typename T>
    bool contains (
    ) const
    {
        return type_identity == get_type_id<T>();
    }

    // True iff the union holds no value.
    bool is_empty (
    ) const
    {
        return type_identity == 0;
    }

    // 0 when empty, otherwise the 1-based index of the held alternative.
    int get_current_type_id() const
    {
        return type_identity;
    }

    // Exchanges the contents (including emptiness) of *this and item.
    void swap (
        type_safe_union& item
    )
    {
        if (type_identity == item.type_identity)
        {
            item.apply_to_contents(swap_to{*this});
        }
        else if (is_empty())
        {
            item.apply_to_contents(move_to{*this});
            item.destruct();
        }
        else if (item.is_empty())
        {
            apply_to_contents(move_to{item});
            destruct();
        }
        else
        {
            // Different non-empty alternatives: rotate through an empty temp.
            type_safe_union tmp;
            swap(tmp);      // this -> tmp
            swap(item);     // item -> this
            tmp.swap(item); // tmp (this) -> item
        }
    }

    // Returns the held T; if the union does not hold a T, the current value is
    // destroyed and a default-constructed T takes its place first.
    template <
        typename T,
        is_valid_check<T> = true
    >
    T& get(
    )
    {
        if (type_identity != get_type_id<T>())
            construct<T>();
        return unchecked_get<T>();
    }

    template <
        typename T
    >
    T& get(
        in_place_tag<T>
    )
    {
        return get<T>();
    }

    // Returns the held T; throws bad_type_safe_union_cast if it holds no T.
    template <
        typename T,
        is_valid_check<T> = true
    >
    const T& cast_to (
    ) const
    {
        if (contains<T>())
            return unchecked_get<T>();
        else
            throw bad_type_safe_union_cast();
    }

    template <
        typename T,
        is_valid_check<T> = true
    >
    T& cast_to (
    )
    {
        if (contains<T>())
            return unchecked_get<T>();
        else
            throw bad_type_safe_union_cast();
    }
};
// Non-member swap in the same namespace as type_safe_union, so unqualified
// `swap(a, b)` calls (e.g. after `using std::swap;`) find it by argument-
// dependent lookup. Forwards to the member swap.
template <typename... Types>
inline void swap (
    type_safe_union<Types...>& a,
    type_safe_union<Types...>& b
)
{
    a.swap(b);
}
namespace detail
{
template<
typename F,
typename TSU,
std::size_t... I
>
void for_each_type_impl(
F&& f,
TSU&& tsu,
dlib::index_sequence<I...>
)
{
using Tsu = typename std::decay<TSU>::type;
(void)std::initializer_list<int>{
(std::forward<F>(f)(
in_place_tag<type_safe_union_alternative_t<I, Tsu>>{},
std::forward<TSU>(tsu)),
0
)...
};
}
template<
typename R,
typename F,
typename TSU,
std::size_t I
>
R visit_impl_as_type(
F&& f,
TSU&& tsu
)
{
using Tsu = typename std::decay<TSU>::type;
using T = type_safe_union_alternative_t<I, Tsu>;
return dlib::invoke(std::forward<F>(f), tsu.template cast_to<T>());
}
template<
typename R,
typename F,
typename TSU,
std::size_t... I
>
R visit_impl(
F&& f,
TSU&& tsu,
dlib::index_sequence<I...>
)
{
using func_t = R(*)(F&&, TSU&&);
const func_t vtable[] = {
/*! Empty (type_identity == 0) case !*/
[](F&&, TSU&&) {
return R();
},
/*! Non-empty cases !*/
&visit_impl_as_type<R,F&&,TSU&&,I>...
};
return vtable[tsu.get_current_type_id()](std::forward<F>(f), std::forward<TSU>(tsu));
}
}
template<
typename TSU,
typename F
>
void for_each_type(
F&& f,
TSU&& tsu
)
{
using Tsu = typename std::decay<TSU>::type;
static constexpr std::size_t Size = type_safe_union_size<Tsu>::value;
detail::for_each_type_impl(std::forward<F>(f), std::forward<TSU>(tsu), dlib::make_index_sequence<Size>{});
}
// Calls f on the value held by tsu and returns its result. The return type is
// what f returns for the first alternative; results for other alternatives are
// converted to it. If tsu is empty, f is not called and a value-initialized
// result is returned.
template<
    typename F,
    typename TSU,
    typename Tsu = typename std::decay<TSU>::type,
    typename T0 = type_safe_union_alternative_t<0, Tsu>
>
auto visit(
    F&& f,
    TSU&& tsu
) -> dlib::invoke_result_t<F, decltype(tsu.template cast_to<T0>())>
{
    using result_type = dlib::invoke_result_t<F, decltype(tsu.template cast_to<T0>())>;
    using indices = dlib::make_index_sequence<type_safe_union_size<Tsu>::value>;
    return detail::visit_impl<result_type>(std::forward<F>(f), std::forward<TSU>(tsu), indices{});
}
namespace detail
{
    // Function object for apply_to_contents(): writes the visited value to
    // `out` with the serialize() overload found for its type.
    struct serialize_helper
    {
        serialize_helper(std::ostream& stream) : out(stream) {}

        template <typename T>
        void operator() (const T& item) const
        {
            serialize(item, out);
        }

        std::ostream& out;
    };

    // Function object for for_each_type(): for the alternative whose type id
    // equals `index`, switches the union to that type via get<T>() (which
    // default-constructs it if not already held) and deserializes into it.
    struct deserialize_helper
    {
        deserialize_helper(
            std::istream& stream,
            int id
        ) : index(id), in(stream)
        {}

        template<typename T, typename TSU>
        void operator()(in_place_tag<T>, TSU&& x)
        {
            if (x.template get_type_id<T>() != index)
                return;
            deserialize(x.template get<T>(), in);
        }

        const int index = -1;
        std::istream& in;
    };
} // namespace detail
// Writes item as its current type id (0 when empty) followed by the
// serialized held value, if any. Any serialization_error is rethrown with
// added context.
template<typename... Types>
inline void serialize (
    const type_safe_union<Types...>& item,
    std::ostream& out
)
{
    try
    {
        const int id = item.get_current_type_id();
        serialize(id, out);
        item.apply_to_contents(detail::serialize_helper(out));
    }
    catch (serialization_error& e)
    {
        throw serialization_error(e.info + "\n while serializing an object of type type_safe_union");
    }
}
// Reads a type id followed by the matching value, as written by serialize().
// An id of 0 clears item. Ids 1..sizeof...(Types) select an alternative, and
// the value is deserialized into it. Any other id is rejected with
// serialization_error. All serialization_errors are rethrown with added context.
template<typename... Types>
inline void deserialize (
    type_safe_union<Types...>& item,
    std::istream& in
)
{
    try
    {
        int index = -1;
        deserialize(index, in);

        if (index == 0)
            item.clear();
        else if (index > 0 && index <= static_cast<int>(sizeof...(Types)))
            for_each_type(detail::deserialize_helper(in, index), item);
        else
            // Valid ids are 0 (empty) through sizeof...(Types) inclusive.
            throw serialization_error("bad index value. Should be in range [0,sizeof...(Types)]");
    }
    catch(const serialization_error& e)
    {
        throw serialization_error(e.info + "\n while deserializing an object of type type_safe_union");
    }
}
#if __cplusplus >= 201703L
    // overloaded_helper<Base...> inherits from every callable in Base... and
    // re-exports all of their operator() overloads, so a single object can
    // dispatch on argument type.
    template<typename ...Base>
    struct overloaded_helper : Base...
    {
        // Builds each base from the corresponding argument. The arity
        // constraint stops this template from hijacking copy construction.
        // Without it, copying a non-const lvalue overloaded_helper binds T to
        // overloaded_helper& and fails to compile when there are 2+ bases.
        template<
            typename... T,
            typename std::enable_if<sizeof...(T) == sizeof...(Base), bool>::type = true
        >
        overloaded_helper(T&& ... t) : Base{std::forward<T>(t)}... {}

        using Base::operator()...;
    };
#else
    // Pre-C++17 fallback: using-declarations cannot be pack-expanded, so the
    // bases are peeled off recursively and each level re-exports its operator().
    template<typename Base, typename ... BaseRest>
    struct overloaded_helper: Base, overloaded_helper<BaseRest...>
    {
        // One argument per base; the constraint keeps copy construction from
        // non-const lvalues going to the implicit copy constructor.
        template<
            typename T,
            typename ... TRest,
            typename std::enable_if<sizeof...(TRest) == sizeof...(BaseRest), bool>::type = true
        >
        overloaded_helper(T&& t, TRest&& ...trest) :
            Base{std::forward<T>(t)},
            overloaded_helper<BaseRest...>{std::forward<TRest>(trest)...}
        {}

        using Base::operator();
        using overloaded_helper<BaseRest...>::operator();
    };

    template<typename Base>
    struct overloaded_helper<Base> : Base
    {
        // Plain injected-class-name: spelling the constructor as
        // `overloaded_helper<Base>(...)` is ill-formed since C++20 (CWG 2237).
        template<typename T>
        overloaded_helper(T&& t) : Base{std::forward<T>(t)}
        {}

        using Base::operator();
    };
#endif //__cplusplus >= 201703L
// Combines the given callables into one function object whose operator() is
// the overload set of all of them. Each callable is stored by value, decayed.
template<typename... T>
overloaded_helper<typename std::decay<T>::type...> overloaded(T&&... t)
{
    using result_type = overloaded_helper<typename std::decay<T>::type...>;
    return result_type{std::forward<T>(t)...};
}
}
#endif // DLIB_TYPE_SAFE_UNIOn_h_