簡易 variant を作ってみた

使い方は Boost.Variant とほとんど一緒だけど、内部実装は全然違うやつ。
Boost.Variant は、テンプレート引数として渡されたデータを全部保持できる共有の領域を用意して、それを Boost.Variant オブジェクトの一部として構築する(つまり boost::variant をスタック上に置けばスタック上に共有の領域があることになる)。これは時に Boost.Variant の assign の処理 - melpon日記 - HaskellもC++もまともに扱えないへたれのページ みたいな複雑な処理を行わないといけなかったりすることがあったりする。
それに対してこの variant は共有の領域なんて用意せずに Boost.Any みたいな何でも入るクラスに適当に設定してるだけなので、代入とかコピー操作を行うたびに new しててめちゃめちゃ効率悪かったりするんだけど、おかげで造りが簡単になってたり。


以下実装


variant.hpp

#ifndef MTL_VARIANT_HPP_INCLUDED
#define MTL_VARIANT_HPP_INCLUDED

#include <cassert>
#include <exception>
#include <utility>

namespace mtl
{
    namespace variant_detail
    {
        // boost::any みたいなクラス
        // variant クラスは必ず何らかの値を持つから ptr_ == 0 という状態は無いので、実は少しだけ最適化できる。
        // というか invalid な状態を許可する variant クラスを作ってもいいのかもしれない。
        class any_value
        {
        public:
            any_value() : ptr_(0)
            {
            }

            any_value(const any_value& v)
                : ptr_(v.ptr_ ? v.op_->clone(v.ptr_) : 0)
                , op_(v.op_)
            {
            }

            template<typename T>
            any_value(const T& v)
                : ptr_(reinterpret_cast<void*>(new T(v)))
                , op_(get_operand<T>())
            {
            }

            ~any_value()
            {
                if (ptr_) op_->destroy(ptr_);
            }

            template<typename T>
            any_value& operator=(const T& rhs)
            {
                any(rhs).swap(*this);
                return *this;
            }

            any_value& operator=(any_value rhs)
            {
                rhs.swap(*this);
                return *this;
            }

            any_value& swap(any_value& rhs)
            {
                std::swap(ptr_, rhs.ptr_);
                std::swap(op_, rhs.op_);
                return *this;
            }

            template<class T>
            T* unsafe_get() const
            {
                return reinterpret_cast<T*>(ptr_);
            }

        private:
            struct operand
            {
                void* (*clone)(void*);
                void (*destroy)(void*);
            };
            template<class T>
            static void* clone_impl(void* p)
            {
                return reinterpret_cast<void*>(new T(*reinterpret_cast<T*>(p)));
            }
            template<class T>
            static void destroy_impl(void* p)
            {
                delete reinterpret_cast<T*>(p);
            }
            template<class T>
            static operand* get_operand()
            {
                static operand op =
                {
                    &any_value::clone_impl<T>,
                    &any_value::destroy_impl<T>,
                };
                return &op;
            }

            void* ptr_;
            operand* op_;
        };

        // for get_visitor
        template<class T>
        T* addressof(T& v)
        {
            return reinterpret_cast<T*>(&const_cast<char&>(reinterpret_cast<const volatile char &>(v)));
        }

        template<class T>
        struct get_visitor
        {
            typedef T* result_type;

            result_type operator()(T& v)
            {
                return addressof(v);
            }

            template<class U>
            result_type operator()(U&)
            {
                return 0;
            }
        };

        template<bool B> struct bool_ { };
    }

#define MTL_VARIANT_CAT_II(a,b) a##b
#define MTL_VARIANT_CAT_I(a,b) MTL_VARIANT_CAT_II(a,b)
#define MTL_VARIANT_CAT(a,b) MTL_VARIANT_CAT_I(a,b)

// 最大数を増やしたいときはこの辺弄ればOK。ちょっと面倒だけど。
    namespace variant_detail
    {
        struct void1_ { };
        struct void2_ { };
        struct void3_ { };
        struct void4_ { };
        struct void5_ { };
        struct void6_ { };
        struct void7_ { };
        struct void8_ { };
        template<class T> struct is_void { static const bool value = false; };
        template<> struct is_void<void1_> { static const bool value = true; };
        template<> struct is_void<void2_> { static const bool value = true; };
        template<> struct is_void<void3_> { static const bool value = true; };
        template<> struct is_void<void4_> { static const bool value = true; };
        template<> struct is_void<void5_> { static const bool value = true; };
        template<> struct is_void<void6_> { static const bool value = true; };
        template<> struct is_void<void7_> { static const bool value = true; };
        template<> struct is_void<void8_> { static const bool value = true; };
    }

#define MTL_VARIANT_TEMPLATE_PARMS_FORWARD \
    class T0, \
    class T1 = mtl::variant_detail::void1_, \
    class T2 = mtl::variant_detail::void2_, \
    class T3 = mtl::variant_detail::void3_, \
    class T4 = mtl::variant_detail::void4_, \
    class T5 = mtl::variant_detail::void5_, \
    class T6 = mtl::variant_detail::void6_, \
    class T7 = mtl::variant_detail::void7_, \
    class T8 = mtl::variant_detail::void8_ \
    /**/

#define MTL_VARIANT_TEMPLATE_PARMS \
    class T0, \
    class T1, \
    class T2, \
    class T3, \
    class T4, \
    class T5, \
    class T6, \
    class T7, \
    class T8 \
    /**/

#define MTL_VARIANT_TEMPLATE_ARGS       T0, T1, T2, T3, T4, T5, T6, T7, T8

#define MTL_VARIANT_MAKE_COPY_ASSIGN(N) \
    variant(const MTL_VARIANT_CAT(T,N)& t) : value_(t), which_(N) { } \
    variant& operator=(const MTL_VARIANT_CAT(T,N)& t) { variant(t).swap(*this); return *this; }

#define MTL_VARIANT_COPY_ASSIGN \
    MTL_VARIANT_MAKE_COPY_ASSIGN(0) \
    MTL_VARIANT_MAKE_COPY_ASSIGN(1) \
    MTL_VARIANT_MAKE_COPY_ASSIGN(2) \
    MTL_VARIANT_MAKE_COPY_ASSIGN(3) \
    MTL_VARIANT_MAKE_COPY_ASSIGN(4) \
    MTL_VARIANT_MAKE_COPY_ASSIGN(5) \
    MTL_VARIANT_MAKE_COPY_ASSIGN(6) \
    MTL_VARIANT_MAKE_COPY_ASSIGN(7) \
    MTL_VARIANT_MAKE_COPY_ASSIGN(8) \
    /**/

#define MTL_VARIANT_MAKE_WHICH_CASE(N) \
    case N: return apply(f, *value_.unsafe_get<MTL_VARIANT_CAT(T,N)>(), mtl::variant_detail::bool_<mtl::variant_detail::is_void<MTL_VARIANT_CAT(T,N)>::value>());

#define MTL_VARIANT_WHICH_CASE \
    MTL_VARIANT_MAKE_WHICH_CASE(0) \
    MTL_VARIANT_MAKE_WHICH_CASE(1) \
    MTL_VARIANT_MAKE_WHICH_CASE(2) \
    MTL_VARIANT_MAKE_WHICH_CASE(3) \
    MTL_VARIANT_MAKE_WHICH_CASE(4) \
    MTL_VARIANT_MAKE_WHICH_CASE(5) \
    MTL_VARIANT_MAKE_WHICH_CASE(6) \
    MTL_VARIANT_MAKE_WHICH_CASE(7) \
    MTL_VARIANT_MAKE_WHICH_CASE(8) \
    /**/

    // forward declaration
    template<MTL_VARIANT_TEMPLATE_PARMS_FORWARD>
    struct variant;

    template<MTL_VARIANT_TEMPLATE_PARMS>
    struct variant
    {
        variant() : value_(T0()), which_(0)
        {
        }

        variant(const variant& v) : value_(v.value_), which_(v.which_)
        {
        }

        variant& operator=(const variant& v)
        {
            variant(v).swap(*this);
            return *this;
        }

        MTL_VARIANT_COPY_ASSIGN

        int which() const { return which_; }

        void swap(variant& v)
        {
            value_.swap(v.value_);
            std::swap(which_, v.which_);
        }

        template<class F>
        typename F::result_type apply_visitor(F f) const
        {
            switch (which_)
            {
            MTL_VARIANT_WHICH_CASE
            }

            return noreturn<typename F::result_type>();
        }

        // operator== や operator< ぐらいは実装してもいいかもしれない

    private:
        template<class R>
        static R noreturn()
        {
            // ここに来ることは無い
            assert(false);
            // どうせここは通らないので何でもいい
            // 例外投げると最適化に影響しそうだから noreturn 属性とか使えるといいのだけど、
            // ポータブルじゃないから諦める
            throw 0;
        }

        // これは apply_visitor の中のマクロから呼ばれる。
        // T が voidN_ 型の場合は何もしないようにしないとコンパイルエラーになってしまう。
        template<class F, class T>
        static typename F::result_type apply(F f, T& t, mtl::variant_detail::bool_<false>)
        {
            return f(t);
        }
        template<class F, class T>
        static typename F::result_type apply(F f, T& t, mtl::variant_detail::bool_<true>)
        {
            return noreturn<typename F::result_type>();
        }

        variant_detail::any_value value_;
        int which_;
    };

    template<MTL_VARIANT_TEMPLATE_PARMS>
    void swap(variant<MTL_VARIANT_TEMPLATE_ARGS>& lhs, variant<MTL_VARIANT_TEMPLATE_ARGS>& rhs)
    {
        lhs.swap(rhs);
    }

    template<class U, MTL_VARIANT_TEMPLATE_PARMS>
    U* get(const variant<MTL_VARIANT_TEMPLATE_ARGS>* v)
    {
        if (!v) return 0;
        return v->apply_visitor(variant_detail::get_visitor<U>());
    }

    struct bad_get : std::exception
    {
        virtual const char* what() const throw() { return "bad_get"; }
    };

    template<class U, MTL_VARIANT_TEMPLATE_PARMS>
    U& get(const variant<MTL_VARIANT_TEMPLATE_ARGS>& v)
    {
        U* p = v.apply_visitor(variant_detail::get_visitor<U>());
        if (!p) throw bad_get();
        return *p;
    }

// cleanup
#undef MTL_VARIANT_CAT_II
#undef MTL_VARIANT_CAT_I
#undef MTL_VARIANT_CAT

#undef MTL_VARIANT_TEMPLATE_PARMS_FORWARD
#undef MTL_VARIANT_TEMPLATE_PARMS
#undef MTL_VARIANT_TEMPLATE_ARGS

#undef MTL_VARIANT_MAKE_COPY_ASSIGN
#undef MTL_VARIANT_COPY_ASSIGN

#undef MTL_VARIANT_MAKE_WHICH_CASE
#undef MTL_VARIANT_WHICH_CASE

}

#endif // MTL_VARIANT_HPP_INCLUDED

簡易テスト

#include "variant.hpp"
#include <cassert>
#include <string>
#include <iostream>

// 簡単なテスト

void test()
{
    using namespace mtl;
    variant<int, std::string> v;
    assert(v.which() == 0);
    assert(get<int>(v) == 0);
    assert(*get<int>(&v) == 0);
    assert(get<std::string>(&v) == 0);

    variant<int, std::string> v2(10);
    assert(v2.which() == 0);
    assert(get<int>(v2) == 10);

    variant<int, std::string> v3("hoge");
    assert(v3.which() == 1);
    assert(get<std::string>(v3) == "hoge");

    v = v2;
    assert(v.which() == 0);
    assert(get<int>(v) == 10);

    v2 = v3;
    assert(v2.which() == 1);
    assert(get<std::string>(v2) == "hoge");

    v3 = 100;
    assert(v3.which() == 0);
    assert(get<int>(v3) == 100);
}

// 稲葉さんのところの例をちょっと弄ってみたやつ

// 「2倍する」Visitor
struct do_double
{
    typedef void result_type;

    template<typename T>
    void operator()( T& t ) const { t = t + t; }
};

void test2()
{
    mtl::variant<int,double,std::string> v;
    v = -2;
    assert( v.which() == 0 );           // 中身がintなことを確認
    std::cout << mtl::get<int>(v) << std::endl; // intを取り出し

    v = 3.14;
    assert( v.which() == 1 ); // 中身がdoubleなことを確認
    std::cout << mtl::get<double>(v) << std::endl;        // operator<< の visitor に対応してないので get は省略できない

    v = "hoge";
    assert( v.which() == 2 );        // 中身がstringなことを確認
    v.apply_visitor( do_double() ); // Visitorを投げる
    std::cout << mtl::get<std::string>(v) << std::endl;
}

int main()
{
    test();
    test2();
}