// Source: libscala / tl2.scala (full commit)
package ls

import java.util.concurrent.locks.{ReentrantLock}
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}


/** Shared default STM instance; bring its API into scope with `import ls.Tl2._`. */
object Tl2 extends Tl2 {}

/**
 * A small software-transactional-memory implementation in the style of TL2: a global
 * version clock plus, per variable, a versioned committed value and a write lock.
 *
 * Each thread has one current transaction (see `transactions`), read-write by default and
 * switchable with `withRO()` / `withRW()`. Run transactional code with `tx { ... }`, which
 * re-runs the body when a conflict is detected.
 */
class Tl2 {
    // Global version clock: transactions snapshot it, commits advance it.
    private val clock = new AtomicLong(1)

    /** Per-thread current transaction; a read-write one unless `withRO()` was called. */
    val transactions = new ThreadLocal[Transaction] {
        override def initialValue() = new ReadWriteTransaction()
    }
    private def t = transactions.get

    /** Signals a conflict; `tx` catches it, aborts the attempt and re-runs the body. */
    class RetryTransaction(s: String) extends RuntimeException(s)
    object RetryTransaction extends RetryTransaction("") {
        def apply(s: String) = new RetryTransaction(s)
    }

    /**
     * A transactional variable. Its committed value lives in an immutable `Ref` stamped with
     * the clock of the commit that wrote it; `lock` is taken by committing transactions.
     */
    class TVar[E](initial: E) extends Comparable[TVar[E]] {
        /** Committed snapshot: the value plus the version clock it was written at. */
        case class Ref(clock: Long, value: E) { @volatile var owner: Transaction = null }

        private val ref = new AtomicReference(new Ref(Tl2.this.clock.get, initial))
        protected[Tl2] val lock = new ReentrantLock()

        /** Blocks until this variable's lock is acquired, then records `t` as its owner. */
        def writeLock(t: Transaction): Unit = {
            lock.lock()
            ref.get.owner = t
        }

        /**
         * Clears the owner mark and releases the lock. The mark is cleared while the lock is
         * still held: clearing it after `unlock` could erase the mark of a transaction that
         * acquired the lock in between.
         */
        def writeUnLock(t: Transaction): Unit = {
            ref.get.owner = null
            lock.unlock()
        }

        /**
         * Throws [[RetryTransaction]] if `r` was written after `t`'s clock snapshot, or is
         * currently owned by a transaction other than `t`.
         */
        protected[Tl2] def clockCheck(t: Transaction, r: Ref): Unit = {
            if (r.clock > t.clock) throw RetryTransaction(" TVar updated by another Transaction ")
            if (r.owner != null && r.owner != t) throw RetryTransaction(" Locked by another Transaction ")
        }

        /** Validates the current committed snapshot against `t`. */
        protected[Tl2] def clockCheck(t: Transaction): Unit = clockCheck(t, ref.get)

        /** Reads the committed value after validating it against `t`'s snapshot. */
        protected[Tl2] def getCurrent(t: Transaction): E = {
            val r = ref.get
            clockCheck(t, r)
            r.value
        }

        /** Installs a new committed value stamped `newClock` (called by commit under the lock). */
        protected[Tl2] def setCurrent(newValue: Any, newClock: Long): Unit = {
            val old = ref.get
            assert(newClock >= old.clock) // TL2 GV4 id sharing makes = possible?
            ref.set(new Ref(newClock, newValue.asInstanceOf[E]))
        }

        /** Transactional read through the calling thread's current transaction. */
        def apply(): E = t.get(this)

        /** Transactional write through the calling thread's current transaction. */
        def update(v: E): Unit = t.set(this, v)

        /** Unique id; orders TVars so write locks are always acquired in the same order. */
        val id: Long = idSource.incrementAndGet()
        def compareTo(other: TVar[E]): Int = java.lang.Long.compare(id, other.id)
    }
    // Source of unique TVar ids (see TVar.id).
    protected val idSource = new AtomicLong(1)

    /** Creates a new transactional variable holding `initial`. */
    def TVar[E](initial: E): TVar[E] = new TVar[E](initial)

    /** One thread's transactional context; `clock` is its snapshot of the global clock. */
    protected trait Transaction {
        var clock: Long = Tl2.this.clock.get

        def get[E](ref: TVar[E]): E
        def set[E](ref: TVar[E], value: E): Unit
        def prepare(): Unit
        def commit(): Unit
        def abort(): Unit
    }

    /** Read-only transaction: each read is validated against the snapshot; writes are rejected. */
    protected class ReadOnlyTransaction extends Transaction {
        def get[E](ref: TVar[E]): E = ref.getCurrent(this)
        def set[E](ref: TVar[E], value: E): Unit = throw new RuntimeException("Read-only Transaction")
        def prepare(): Unit = {}
        // Nothing to publish: just take a fresh snapshot for the next transaction.
        def commit(): Unit = { this.clock = Tl2.this.clock.get }
        def abort(): Unit = commit()
    }

    /**
     * Read-write transaction: buffers writes, records reads, and publishes the writes at
     * commit time after locking them and re-validating the reads (TL2 steps 3-6).
     */
    protected class ReadWriteTransaction extends Transaction {
        /** Buffered writes, kept sorted by TVar id so write locks are taken in a fixed order. */
        val writes = new java.util.TreeMap[TVar[_], Any]()
        /** TVars read from committed state during this attempt; re-validated in prepare(). */
        val reads = new java.util.HashSet[TVar[_]]()
        /** True while an attempt is live; false after commit() or abort(). */
        var begun = true

        // Set by prepare(); get/set are rejected while it is true (until the next begin()).
        private var prepared = false
        // True exactly while this transaction holds the lock of every TVar in `writes`.
        private var locksHeld = false
        // Version stamped on the values published by commit().
        private var newClock = 0L

        /** Starts a fresh attempt: empties the read/write sets and re-snapshots the clock. */
        private def begin(): Unit = {
            writes.clear()
            reads.clear()
            begun = true
            prepared = false
            this.clock = Tl2.this.clock.get
        }

        /** Lazily starts an attempt, then runs `func` unless this transaction is prepared. */
        private def ifGood[T](func: => T): T = {
            if (!begun) begin()
            if (prepared) throw new RuntimeException("Prepared!")
            func
        }

        def get[E](ref: TVar[E]): E = ifGood {
            // containsKey rather than a null check: a buffered write of null must still win.
            if (writes.containsKey(ref)) writes.get(ref).asInstanceOf[E]
            else {
                reads.add(ref)
                ref.getCurrent(this)
            }
        }

        def set[E](ref: TVar[E], value: E): Unit = ifGood { writes.put(ref, value); () }

        /** Acquires every write lock, in the TreeMap's id order so committers cannot deadlock. */
        private def lockWriteSet(): Unit = {
            val it = writes.keySet.iterator()
            while (it.hasNext) it.next().writeLock(this)
            locksHeld = true
        }

        /** Releases every lock taken by lockWriteSet(). */
        private def releaseWriteLocks(): Unit = {
            val it = writes.keySet.iterator()
            while (it.hasNext) it.next().writeUnLock(this)
            locksHeld = false
        }

        /**
         * Locks the write-set, advances the global clock and validates the read-set.
         * A no-op once the attempt is over or while the locks are already held, so a repeated
         * call cannot take the (reentrant) locks twice.
         * @throws RetryTransaction if a read is stale or locked by another transaction; the
         *         write locks are released before it propagates.
         */
        def prepare(): Unit = if (begun && !locksHeld) {
            // 3. Lock the write-set.
            lockWriteSet()
            prepared = true

            // 4. Increment the global version clock (TL2-GV4 style: the CAS result is ignored).
            newClock = Tl2.this.clock.get + 1
            Tl2.this.clock.compareAndSet(newClock - 1, newClock)

            // 5. Validate the read-set.
            try {
                val it = reads.iterator()
                while (it.hasNext) it.next().clockCheck(this)
            } catch {
                case e: RetryTransaction =>
                    releaseWriteLocks()
                    throw e
            }
        }

        /**
         * Prepares if needed, publishes every buffered write stamped `newClock`, releases the
         * locks and marks the transaction dead. A no-op on a dead transaction, so a second
         * commit() cannot re-publish or unlock locks it no longer holds.
         */
        def commit(): Unit = if (begun) {
            if (!locksHeld) prepare()
            // 6. Commit and release the locks (released even if publishing fails).
            try {
                val it = writes.entrySet.iterator()
                while (it.hasNext) {
                    val entry = it.next()
                    entry.getKey.setCurrent(entry.getValue, newClock)
                }
            } finally releaseWriteLocks()
            begun = false
        }

        /** Ends the attempt, releasing any write locks still held from prepare(). */
        def abort(): Unit = {
            if (locksHeld) releaseWriteLocks()
            begun = false
        }
    }

    /** Commits the calling thread's current transaction. */
    def commit(): Unit = t.commit()
    /** Prepares (locks and validates) the calling thread's current transaction. */
    def prepare(): Unit = t.prepare()
    /** Aborts the calling thread's current transaction. */
    def abort(): Unit = t.abort()

    /**
     * Total retries counted by `tx` on this instance. It is incremented without
     * synchronization, so concurrent increments may be lost; treat it as a rough statistic.
     */
    var allretries = 0

    /**
     * Evaluates `func` in the calling thread's current transaction and commits it.
     * On [[RetryTransaction]] (from a read or from commit-time validation) the attempt is
     * aborted, the thread sleeps 1 ms and `func` is re-evaluated, up to 1000 attempts.
     * Any other exception aborts the attempt and propagates.
     * @throws Exception when all attempts ended in a retry.
     */
    def tx[T](func: => T): T = {
        val localt = t
        var retries = 0
        while (retries < 1000) {
            var committed = false
            try {
                val result = func
                localt.commit()
                committed = true
                return result
            } catch {
                case _: RetryTransaction =>
                    retries += 1
                    allretries += 1
            } finally {
                if (!committed) localt.abort()
            }
            Thread.sleep(1)
        }
        throw new Exception("Transaction Failure: " + retries)
    }

    /** Switches the calling thread to a read-only transaction (if needed); returns this. */
    def withRO(): Tl2 = {
        transactions.get match {
            case _: ReadOnlyTransaction =>
            case _ => transactions.set(new ReadOnlyTransaction)
        }
        this
    }

    /** Switches the calling thread to a read-write transaction (if needed); returns this. */
    def withRW(): Tl2 = {
        transactions.get match {
            case _: ReadWriteTransaction =>
            case _ => transactions.set(new ReadWriteTransaction)
        }
        this
    }
}