/* -*-c++-*- */
/* osgEarth - Dynamic map generation toolkit for OpenSceneGraph
 * Copyright 2008-2013 Pelican Mapping
 * http://osgearth.org
 *
 * osgEarth is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>
 */
#ifndef OSGEARTH_THREADING_UTILS_H
#define OSGEARTH_THREADING_UTILS_H 1

#include <osgEarth/Common>
#include <OpenThreads/Condition>
#include <OpenThreads/Mutex>
#include <OpenThreads/ReentrantMutex>
#include <osg/ref_ptr>
#include <set>
#include <map>

#define USE_CUSTOM_READ_WRITE_LOCK 1
//#ifdef _DEBUG
//#  define TRACE_THREADS 1
//#endif

namespace osgEarth { namespace Threading
{   
    typedef OpenThreads::Mutex Mutex;
    typedef OpenThreads::ScopedLock<OpenThreads::Mutex> ScopedMutexLock;
    typedef OpenThreads::Thread Thread;

    /**
     * Gets the unique ID of the running thread. Use this instead of
     * OpenThreads::Thread::CurrentThread, which only works reliably on
     * threads created with the OpenThreads framework
     */
    extern OSGEARTH_EXPORT unsigned getCurrentThreadId();


#ifdef USE_CUSTOM_READ_WRITE_LOCK

    /**
     * Event with a toggled signal state.
     */
    class Event 
    {
    public:
        Event() : _set( false ) { }

        ~Event() { 
            reset(); 
            for( int i=0; i<255; ++i ) // workaround buggy broadcast
                _cond.signal();
        }

        inline bool wait() {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _m );
            return _set ? true : (_cond.wait( &_m ) == 0);
        }

        /** waits on a signal, and then automatically resets it before returning. */
        inline bool waitAndReset() {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _m );
            if ( _set ) {
                _set = false;
                return true;
            }
            else {
                bool value = _cond.wait( &_m ) == 0;
                _set = false;
                return value;
            }
        }

        inline void set() {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _m );
            if ( !_set ) {
                _set = true;
                _cond.broadcast(); // possible deadlock before OSG r10457 on windows
                //_cond.signal();
            }
        }

        inline void reset() {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _m );
            _set = false;
        }

        inline bool isSet() const {
            return _set;
        }

    protected:
        OpenThreads::Mutex _m;
        OpenThreads::Condition _cond;
        bool _set;
    };

    /** Same as an Event, but waits on multiple notifications before releasing its wait. */
    class MultiEvent 
    {
    public:
        MultiEvent( int num =1 ) : _set( num ), _num(num)  { }

        ~MultiEvent() {
            reset();
            for( int i=0; i<255; ++i ) // workaround buggy broadcast
                _cond.signal();
        }

        inline bool wait() {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _m );
            while( _set > 0 )
                if ( _cond.wait( &_m ) != 0 )
                    return false;
            return true;
        }

        /** waits on a signal, and then automatically resets it before returning. */
        inline bool waitAndReset() {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _m );
            while( _set > 0 )
                if ( _cond.wait( &_m ) != 0 )
                    return false;
            _set = _num;
            return true;
        }

        inline void notify() {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _m );
            if ( _set > 0 )
                --_set;
            if ( _set == 0 )
                _cond.broadcast(); // possible deadlock before OSG r10457 on windows
            //_cond.signal();
        }

        inline void reset( int num =0 ) {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _m );
            if ( num > 0 ) _num = num;
            _set = _num;
        }

    protected:
        OpenThreads::Mutex _m;
        OpenThreads::Condition _cond;
        int _set, _num;
    };

    /**
     * Custom read/write lock. The read/write lock in OSG can unlock mutexes from a different
     * thread than the one that locked them - this can hang the thread in Windows.
     *
     * Adapted from:
     * http://www.codeproject.com/KB/threads/ReadWriteLock.aspx
     */
    class ReadWriteMutex
    {
#ifdef TRACE_THREADS
        typedef std::set<unsigned> TracedThreads;
        TracedThreads _trace;
        OpenThreads::Mutex _traceMutex;
#endif

    public:
        ReadWriteMutex() :
          _readerCount(0)
        { 
            _noWriterEvent.set();
            _noReadersEvent.set();
        }

        void readLock()
        {

#ifdef TRACE_THREADS
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> ttLock(_traceMutex);
                if( _trace.find(getCurrentThreadId()) != _trace.end() )
                    OE_WARN << "TRACE: tried to double-lock" << std::endl;
            }
#endif
            for( ; ; )
            {
                _noWriterEvent.wait();             // wait for a writer to quit if there is one
                incrementReaderCount();            // register this reader
                if ( !_noWriterEvent.isSet() )     // double lock check, in case a writer snuck in while inrementing
                    decrementReaderCount();        // if it did, undo the registration and try again
                else
                    break;                         // otherwise, we're in
            }

#ifdef TRACE_THREADS
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> ttLock(_traceMutex);
                _trace.insert(getCurrentThreadId());
            }
#endif
        }

        void readUnlock()
        {
            decrementReaderCount();                // unregister this reader
            
#ifdef TRACE_THREADS
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> ttLock(_traceMutex);
                _trace.erase(getCurrentThreadId());
            }
#endif
        }

        void writeLock()
        {
#ifdef TRACE_THREADS
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> ttLock(_traceMutex);
                if( _trace.find(getCurrentThreadId()) != _trace.end() )
                    OE_WARN << "TRACE: tried to double-lock" << std::endl;
            }
#endif
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _lockWriterMutex ); // one at a time please
            _noWriterEvent.wait();    // wait for a writer to quit if there is one
            _noWriterEvent.reset();   // prevent further writers from joining
            _noReadersEvent.wait();   // wait for all readers to quit

#ifdef TRACE_THREADS
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> ttLock(_traceMutex);
                _trace.insert(getCurrentThreadId());
            }
#endif
        }

        void writeUnlock()
        {
            _noWriterEvent.set();

#ifdef TRACE_THREADS
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> ttLock(_traceMutex);
                _trace.erase(getCurrentThreadId());
            }
#endif
        }

    protected:

        void incrementReaderCount()
        {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _readerCountMutex );
            _readerCount++;            // add a reader
            _noReadersEvent.reset();   // there's at least one reader now so clear the flag
        }

        void decrementReaderCount()
        {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock( _readerCountMutex );
            _readerCount--;               // remove a reader
            if ( _readerCount <= 0 )      // if that was the last one, signal that writers are now allowed
                _noReadersEvent.set();
        }

    private:
        int    _readerCount;
        Mutex  _lockWriterMutex;
        Mutex  _readerCountMutex;
        Event  _noWriterEvent;
        Event  _noReadersEvent;
    };


    struct ScopedWriteLock
    {
        ScopedWriteLock( ReadWriteMutex& lock ) : _lock(lock) { _lock.writeLock(); }
        ~ScopedWriteLock() { _lock.writeUnlock(); }
    protected:
        ReadWriteMutex& _lock;
    };

    struct ScopedReadLock
    {
        ScopedReadLock( ReadWriteMutex& lock ) : _lock(lock) { _lock.readLock(); }
        ~ScopedReadLock() { _lock.readUnlock(); }
    protected:
        ReadWriteMutex& _lock;
    };

#else

    typedef OpenThreads::ReadWriteMutex  ReadWriteMutex;
    typedef OpenThreads::ScopedWriteLock ScopedWriteLock;
    typedef OpenThreads::ScopedReadLock  ScopedReadLock;

#endif

    /** Template for per-thread data storage */
    template<typename T>
    struct PerThread
    {
        T& get() {
            ScopedMutexLock lock(_mutex);
            return _data[getCurrentThreadId()];
        }
        //const T& get() const {
        //    ScopedMutexLock lock(_mutex);
        //    return _data[getCurrentThreadId()];
        //}
    private:
        std::map<unsigned,T> _data;
        Mutex                _mutex;
    };

    /** Template for thread safe per-object data storage */
    template<typename KEY, typename DATA>
    struct PerObjectMap
    {
        DATA& get(KEY k)
        {
            {
                osgEarth::Threading::ScopedReadLock readLock(_mutex);
                typename std::map<KEY,DATA>::iterator i = _data.find(k);
                if ( i != _data.end() )
                    return i->second;
            }
            {
                osgEarth::Threading::ScopedWriteLock lock(_mutex);
                typename std::map<KEY,DATA>::iterator i = _data.find(k);
                if ( i != _data.end() )
                    return i->second;
                else
                    return _data[k];
            }
        }

        void remove(KEY k)
        {
            osgEarth::Threading::ScopedWriteLock exclusive(_mutex);
            _data.erase( k );
        }

    private:
        std::map<KEY,DATA>                  _data;
        osgEarth::Threading::ReadWriteMutex _mutex;
    };

    /** Template for thread safe per-object data storage */
    template<typename KEY, typename DATA>
    struct PerObjectRefMap
    {
        DATA* get(KEY k)
        {
            osgEarth::Threading::ScopedReadLock lock(_mutex);
            typename std::map<KEY,osg::ref_ptr<DATA > >::const_iterator i = _data.find(k);
            if ( i != _data.end() )
                return i->second.get();

            return 0L;
        }

        DATA* getOrCreate(KEY k, DATA* newDataIfNeeded)
        {
            osg::ref_ptr<DATA> _refReleaser = newDataIfNeeded;
            {
                osgEarth::Threading::ScopedReadLock lock(_mutex);
                typename std::map<KEY,osg::ref_ptr<DATA> >::const_iterator i = _data.find(k);
                if ( i != _data.end() )
                    return i->second.get();
            }

            {
                osgEarth::Threading::ScopedWriteLock lock(_mutex);
                typename std::map<KEY,osg::ref_ptr<DATA> >::iterator i = _data.find(k);
                if ( i != _data.end() )
                    return i->second.get();

                _data[k] = newDataIfNeeded;
                return newDataIfNeeded;
            }
        }

        void remove(KEY k)
        {
            osgEarth::Threading::ScopedWriteLock exclusive(_mutex);
            _data.erase( k );
        }

        void remove(DATA* data)
        {
            osgEarth::Threading::ScopedWriteLock exclusive(_mutex);
            for( typename std::map<KEY,osg::ref_ptr<DATA> >::iterator i = _data.begin(); i != _data.end(); ++i )
            {
                if ( i->second.get() == data )
                {
                    _data.erase( i );
                    break;
                }
            }
        }

    private:
        std::map<KEY,osg::ref_ptr<DATA> >    _data;
        osgEarth::Threading::ReadWriteMutex  _mutex;
    };

    /** Template for thread safe per-object data storage */
    template<typename KEY, typename DATA>
    struct PerObjectObsMap
    {
        DATA* get(KEY k)
        {
            osgEarth::Threading::ScopedReadLock lock(_mutex);
            typename std::map<KEY, osg::observer_ptr<DATA> >::const_iterator i = _data.find(k);
            if ( i != _data.end() )
                return i->second.get();

            return 0L;
        }

        DATA* getOrCreate(KEY k, DATA* newDataIfNeeded)
        {
            osg::ref_ptr<DATA> _refReleaser = newDataIfNeeded;
            {
                osgEarth::Threading::ScopedReadLock lock(_mutex);
                typename std::map<KEY,osg::observer_ptr<DATA> >::const_iterator i = _data.find(k);
                if ( i != _data.end() )
                    return i->second.get();
            }

            {
                osgEarth::Threading::ScopedWriteLock lock(_mutex);
                typename std::map<KEY,osg::observer_ptr<DATA> >::iterator i = _data.find(k);
                if ( i != _data.end() )
                    return i->second.get();

                _data[k] = newDataIfNeeded;
                return newDataIfNeeded;
            }
        }

        void remove(KEY k)
        {
            osgEarth::Threading::ScopedWriteLock exclusive(_mutex);
            _data.erase( k );
        }

        void remove(DATA* data)
        {
            osgEarth::Threading::ScopedWriteLock exclusive(_mutex);
            for( typename std::map<KEY,osg::observer_ptr<DATA> >::iterator i = _data.begin(); i != _data.end(); ++i )
            {
                if ( i->second.get() == data )
                {
                    _data.erase( i );
                    break;
                }
            }
        }

    private:
        std::map<KEY,osg::observer_ptr<DATA> >    _data;
        osgEarth::Threading::ReadWriteMutex       _mutex;
    };


    /** Template for thread safe observer set */
    template<typename T>
    struct ThreadSafeObserverSet
    {
        typedef void (*Functor)(T*);
        typedef void (*ConstFunctor)(const T*);

        void iterate( const Functor& f )
        {
            osgEarth::Threading::ScopedWriteLock lock(_mutex);
            for( typename std::set<T>::iterator i = _data.begin(); i != _data.end(); )
            {
                if ( i->valid() )
                {
                    f( i->get() );
                    ++i;
                }
                else
                {
                    i = _data.erase( i );
                }
            }
        }

        void iterate( const ConstFunctor& f )
        {
            osgEarth::Threading::ScopedReadLock lock(_mutex);
            for( typename std::set<T>::iterator i = _data.begin(); i != _data.end(); )
            {
                if ( i->valid() )
                {
                    f( i->get() );
                    ++i;
                }
                else
                {
                    i = _data.erase( i );
                }
            }
        }

        void insert(T* obj)
        {
            osgEarth::Threading::ScopedWriteLock lock(_mutex);
            _data.insert( obj );
        }

        void remove(T* obj)
        {
            osgEarth::Threading::ScopedWriteLock lock(_mutex);
            _data.erase( obj );
        }

    private:
        std::set<osg::observer_ptr<T> >      _data;
        osgEarth::Threading::ReadWriteMutex  _mutex;
    };

} } // namepsace osgEarth::Threading


#endif // OSGEARTH_THREADING_UTILS_H

