Java で Boost.Coroutine 的なクラスを作ってみた

以前に BREW で Boost.Coroutine を使えるようにした関係で、それを Java に移植するのが難しくなってしまったので、Java で Coroutine っぽく使えるようなクラスを作ってみました。


当然スタック切り替えなんてできるわけが無いので、内部的にはスレッドを使っています。
Coroutine#coro() が呼び出されたらスレッドを動かし、yield() されるか終了するまで wait() するだけです。
とりあえずの実装ということで、Boost.Coroutine の coroutine に相当するクラスしか作ってません。
あと CoroutineFunction#coro() の中での例外は全て握りつぶされます。もちろん例外有りバージョンの呼び出しなんて未実装です。

public interface Disposable
{
    void dispose();
}
public interface CoroutineFunction
{
    // どんな例外も許可する。
    void coro(Coroutine.Self self)
        throws Exception;
}
final public class CoroutineInterruptedException extends RuntimeException
{
    // ユーザが生成することはできないようにしておく
    CoroutineInterruptedException() { }
    CoroutineInterruptedException(String s) { super(s); }
}
final public class Coroutine implements Disposable
{
    private Thread th;
    final private CoroutineFunction func;
    private Self self;
    private int state;
    
    public Coroutine(CoroutineFunction func)
    {
        this.func = func;
    }
    
    final static public class Self
    {
        final private Coroutine coro;
        
        Self(Coroutine coro)
        {
            this.coro = coro;
        }
        
        public void yield()
        {
            synchronized (coro)
            {
                coro.notifyState(1);
                while (coro.state == 1)
                {
                    try
                    {
                        coro.wait();
                    }
                    catch (InterruptedException e)
                    {
                        throw new CoroutineInterruptedException();
                    }
                }
            }
        }
        
        public void exit()
        {
            throw new CoroutineInterruptedException();
        }
        
        void run()
        {
            try
            {
                coro.func.coro(this);
                yield();
            }
            catch (CoroutineInterruptedException e)
            {
                // ここは正常な終了
            }
            catch (Exception e)
            {
                // ほんとはこの e を保持しておいて Coroutine#coro() を呼び出した側に
                // 例外が発生したことを伝えるべきなんだけど、面倒なので必要になったら実装する
            }
            finally
            {
                coro.notifyState(2);
            }
        }
        
    }
    
    // Runnable を外部に公開したくないのでこのクラス経由で呼び出す
    final private class Trampoline implements Runnable
    {
        final private Coroutine.Self self;
        Trampoline(Coroutine.Self self)
        {
            this.self = self;
        }
        public void run()
        {
            self.run();
        }
    }
    
    public void coro()
    {
        if (!isAlive())
        {
            return;
        }
        
        synchronized (this)
        {
            notifyState(0);
            
            if (th == null)
            {
                self = new Self(this);
                th = new Thread(new Trampoline(self));
                th.start();
            }
            
            while (state == 0)
            {
                waitNothrow(this);
            }
        }
    }
    
    public boolean isAlive()
    {
        return getState() != 2;
    }
    
    public void exit()
    {
        if (!isAlive())
        {
            return;
        }
        
        th.interrupt();
        synchronized (this)
        {
            while (state != 2)
            {
                waitNothrow(this);
            }
        }
    }
    
    public void dispose()
    {
        exit();
        if (func instanceof Disposable)
        {
            ((Disposable)func).dispose();
        }
    }
    
    private static void waitNothrow(Object obj)
    {
        try
        {
            obj.wait();
        }
        catch (InterruptedException e)
        {
            // ここに来ることは無いはず
            // とりあえず例外を投げておく
            throw new IllegalThreadStateException();
        }
    }
    
    private void notifyState(int n)
    {
        synchronized (this)
        {
            state = n;
            notifyAll();
        }
    }
    
    private int getState()
    {
        synchronized (this)
        {
            return state;
        }
    }
}
0
1
2
3
4
5
6
7
8
9