Source

libscala / stm.scala

Full commit
package ls

import java.lang.ref.SoftReference
import java.util.concurrent.atomic.{AtomicReference, AtomicLong}
import java.util.concurrent.locks.ReentrantLock
import scala.collection.immutable.{LongMap, TreeSet}

/** Shared default instance of [[Stm]]: e.g. `Stm.TVar(x)` creates a ref and `Stm { ... }` runs a transaction. */
object Stm extends Stm

class Stm {
    /** Non-parameterized supertype of Ref[T], so a Transaction can track refs of any value type. */
    abstract class BaseRef {
        /** Key under which this ref's committed value is stored in `refs`. */
        val id: Long
    }

    /** Committed value of every registered Ref, keyed by Ref.id; always replaced wholesale via compareAndSet. */
    val refs: AtomicReference[LongMap[Any]] = new AtomicReference(LongMap.empty[Any])
    /** Source of unique Ref ids (incrementAndGet from 0, so the first id is 1). */
    val idGenerator: AtomicLong = new AtomicLong(0)
    /** This thread's current Transaction; a fresh one is created on first access after remove(). */
    val transactions: ThreadLocal[Transaction] = new ThreadLocal[Transaction] {
        override def initialValue(): Transaction = new Transaction()
    }

    /**
      * A transactional variable. Its committed value is kept in the shared `refs` map under `id`;
      * reads and writes go through the calling thread's current Transaction.
      *
      * @param initialValue value published to `refs` when the ref is constructed
      */
    class Ref[T](initialValue: T) extends BaseRef() {
        val id: Long = idGenerator.incrementAndGet

        // Publish the initial value under this ref's id.
        swapRefs(all => all.updated(id, initialValue))

        /** The calling thread's current transaction. */
        def trans = transactions.get
        /** Read the value without recording the read for commit-time validation. */
        def peek: T = trans.peek(this).asInstanceOf[T]
        /** Read the value and record the read, so commit() re-checks it. */
        def apply(): T = trans.get(this).asInstanceOf[T]
        /** Buffer a write in the current transaction. */
        def update(v: T) = trans.set(this, v)

        /** Drop this ref's entry from `refs` when the ref is finalized. */
        override def finalize = swapRefs(all => all - id)

        /** Atomically replace `refs` with change(refs), retrying whenever a concurrent writer wins the CAS. */
        @annotation.tailrec
        private def swapRefs(change: LongMap[Any] => LongMap[Any]): Unit = {
            val snapshot = refs.get
            if (!refs.compareAndSet(snapshot, change(snapshot))) swapRefs(change)
        }
    }
    
    /** Factory for transactional variables: `TVar(x)` is shorthand for `new Ref(x)`. */
    object TVar {
        def apply[T](initialValue: T): Ref[T] = new Ref(initialValue)
    }
    
    /** Try to commit this thread's transaction; returns whether it succeeded. */
    def commit(): Boolean = transactions.get.commit()

    /** Discard this thread's transaction (buffered writes and deferred actions included). */
    def abort(): Unit = transactions.remove()

    /** Defer func until this thread's transaction commits successfully. */
    def later(func: () => Unit): Unit = transactions.get.later(func)
    
    /**
      * Run the <b>side-effect-free function</b> 'func' in a transaction, retrying until it commits.
      *
      * An attempt fails when commit() returns false, or when 'func' or commit() throws
      * RetryTransaction. The attempt's transaction is then discarded and 'func' runs again on a
      * fresh transaction (fresh snapshot). Because 'func' may run several times it must be free
      * of side effects; register those with later() so they run only after a successful commit.
      *
      * @return the result of the attempt that committed
      * @throws Exception if no attempt has committed after 1000 retries
      */
    def apply[T](func: => T): T = {
        // One attempt: Some(result) if it committed, None if it has to be retried.
        def attempt(): Option[T] =
            try {
                val result = func
                if (commit()) Some(result) else None
            } catch {
                // Thrown by 'func' (e.g. a Cache/IndexCache that waited on a lock, or a ref missing
                // from the snapshot) or by commit(): both mean "start over".
                case _: RetryTransaction =>
                    abort() // drop this attempt's transaction so the retry starts from a new snapshot
                    None
            }

        @annotation.tailrec
        def run(retries: Int): T = attempt() match {
            case Some(result) => result
            // Give up before starting another attempt -- never after an attempt has committed.
            case None if retries >= 1000 => throw new Exception("Transaction Failure: " + (retries + 1))
            case None =>
                // Brief fixed pause (not true exponential backoff) so competing transactions can progress.
                Thread.sleep(0, 1)
                run(retries + 1)
        }

        try run(0) finally abort()
    }

    /** Signals that the current transaction attempt should be abandoned and retried. */
    class RetryTransaction(s: String) extends RuntimeException(s)

    /** Throwing helpers. `RetryTransaction()` throws this shared, empty-message instance. */
    object RetryTransaction extends RetryTransaction("") {
        def apply(): Nothing =
            throw this
        def apply(s: String): Nothing =
            throw new RetryTransaction(s)
    }

    /**
      * State of one in-flight transaction on the current thread.
      *
      * Reads come from `initial`, the snapshot of `refs` taken when this transaction was created,
      * overlaid with this transaction's own buffered writes in `modified`. Writes become visible
      * to other threads only when commit() publishes them.
      */
    class Transaction {
        val initial = refs.get
        // LongMap rather than HashMap[Long, Any] because it merges quickly (`current ++ modified` in publish()).
        var modified = LongMap[Any]()
        // Refs read through get(); commit() re-validates exactly these. peek() does not add to it.
        val touched: collection.mutable.Set[BaseRef] = collection.mutable.HashSet[BaseRef]()
        val todoLater = new collection.mutable.ArrayBuffer[() => Unit]()

        private def unregisteredRef = RetryTransaction("Ref not registered when transaction started")

        /** Value of ref in this transaction (own writes first, then the snapshot), without recording the read. */
        def peek(ref: BaseRef) = modified.getOrElse(ref.id, initial.getOrElse(ref.id, unregisteredRef))
        /** Like peek, but records the read so commit() fails if the value changed meanwhile. */
        def get(ref: BaseRef) = { touched.add(ref); peek(ref) }
        /** Buffer a write; it is published only by a successful commit(). */
        def set(ref: BaseRef, value: Any) = modified = modified.updated(ref.id, value)

        /**
          * Publish buffered writes if no recorded read has changed since the snapshot.
          * Read-only transactions always succeed. In every case this thread's transaction slot is
          * cleared; deferred actions run only on success.
          *
          * @return true if the transaction committed
          * @throws RetryTransaction if a recorded ref is missing from the snapshot
          */
        def commit(): Boolean = {
            val success = modified.isEmpty || publish()
            transactions.remove()
            if (success) todoLater.foreach(action => action())
            success
        }

        /** Register an action to run after a successful commit. */
        def later(item: () => Unit): Unit = todoLater += item

        /** True if any recorded read differs between `current` and the snapshot (prevents write skew). */
        private def readsChanged(current: LongMap[Any]): Boolean = touched.exists { ref =>
            current(ref.id) != initial.getOrElse(ref.id, unregisteredRef)
        }

        /** Validate, then CAS the merged map in; only a lost CAS race (not a failed validation) is retried. */
        @annotation.tailrec
        private def publish(): Boolean = {
            val current = refs.get
            if (readsChanged(current)) false
            else if (refs.compareAndSet(current, current ++ modified)) true
            else publish()
        }
    }

    /** Interface shared by [[Cache]] and [[ConcurrentCache]]. */
    trait CacheT[K, V] {
        /** The cached value for key, if present. */
        def get(key:K): Option[V]

        /** The cached value for key; implementations may throw when it is absent. */
        def apply(key:K): V

        /** Store value under key. */
        def update(key:K, value:V): Unit

        /** Store every (key, value) pair. */
        def updateMany(updates: Iterable[(K,V)]): Unit

        /** Remove all entries. */
        def clear(): Unit
    }

    /**
      * A transactional [[CacheT]] whose values are held through SoftReferences, so the JVM may
      * reclaim them; a reclaimed entry reads as missing. All state lives in the single TVar `map`,
      * so operations read and write it through the calling thread's transaction.
      */
    class Cache[K, V] extends CacheT[K, V] {
        val map = TVar(Map[K,SoftReference[V]]())

        /** Look key up in the snapshot `_map`; if its value was reclaimed, drop the stale entry from `map`. */
        private def get(key:K, _map: Map[K,SoftReference[V]]): Option[V] = {
            _map.get(key) match {
                case None => None
                case Some(ref) => {
                    val value = ref.get()
                    if (value != null) Some(value) else { map() = _map - key; None }
                }
            }
        }

        def get(key:K): Option[V] = get(key, map())

        /** @throws NoSuchElementException if key is absent or its value was reclaimed */
        def apply(key:K): V = {
            get(key) match {
               case None => throw new NoSuchElementException(key.toString)
               case Some(value) => value
            }
        }

        /** Store value under key in the snapshot `_map`, skipping the write if an equal live value is cached. */
        private def update(key:K, value:V, _map: Map[K,SoftReference[V]]): Unit = {
            // Work on the caller's snapshot `_map`; do not shadow it with a second map() read.
            val needsUpdate = get(key, _map) match {
                case None => true
                case Some(oldVal) => oldVal != value
            }
            if (needsUpdate) map() = _map.updated(key, new SoftReference(value))
        }
        def update(key:K, value:V): Unit = update(key, value, map())

        /** Add all pairs, overwriting existing entries (no equality check, unlike update). */
        def updateMany(updates: Iterable[(K,V)]): Unit = {
            val softEntries = for (e <- updates) yield (e._1, new SoftReference(e._2))
            map() = map() ++ softEntries
        }

        def clear() = map() = Map[K,SoftReference[V]]()

        val lock = new java.util.concurrent.locks.ReentrantLock //keep thundering herds at bay

        /**
          * Return the cached value for key, or compute it with func and cache it.
          *
          * Only the thread holding `lock` runs func. A thread that finds the lock taken waits for
          * the holder to release it and then throws RetryTransaction, signalling that its
          * transaction should be retried.
          *
          * @return the cached or freshly fetched value, or None if func found nothing
          */
        def getOrFetch(key: K, func: => Option[V]): Option[V] = {
            get(key, map()) match {
                case Some(value) => Some(value)
                case None =>
                    if (lock.tryLock) {
                        try func match {
                            case Some(value) => { update(key, value); Some(value) }
                            case None => None
                        } finally lock.unlock
                    } else { lock.lock; lock.unlock; RetryTransaction("cache locked => cache miss") }
            }
        }
    }

    /**
      * A [[CacheT]] striped across `concurrency` independent [[Cache]]s, chosen by key hash.
      * Each stripe has its own TVar and fetch lock, so operations on different stripes do not
      * conflict with each other.
      *
      * @param concurrency number of stripes; must be positive
      * @throws IllegalArgumentException if concurrency <= 0
      */
    class ConcurrentCache[K, V](concurrency: Int) extends CacheT[K,V]{
        def this() = this(256)
        require(concurrency > 0, "concurrency must be positive: " + concurrency)

        val caches = Array.fill(concurrency)(new Cache[K, V]())

        /** Stripe for key. hashCode may be negative, so the remainder is shifted into [0, concurrency). */
        def index(key: K): Int = {
            val remainder = key.hashCode % concurrency
            if (remainder < 0) remainder + concurrency else remainder
        }

        def get(key:K): Option[V] = caches(index(key)).get(key)
        def apply(key:K): V = caches(index(key))(key)
        def update(key:K, value:V) = caches(index(key)).update(key, value)

        /** Write each pair to its stripe; stripes that receive no pairs are left untouched. */
        def updateMany(updates: Iterable[(K,V)]): Unit =
            // One groupBy pass instead of rescanning `updates` per stripe; skipping empty stripes
            // also keeps their TVars out of the transaction's read and write sets.
            for ((stripe, entries) <- updates.groupBy(entry => index(entry._1)))
                caches(stripe).updateMany(entries)

        def clear() = caches.foreach(_.clear())
    }

    /** An index entry, ordered by key first and by id among entries with equal keys. */
    case class Item[K <% Ordered[K], ID <% Ordered[ID]](key: K, id: ID) extends Ordered[Item[K,ID]] {
        def compare(other: Item[K, ID]): Int =
            if (key != other.key) key compare other.key
            else id compare other.id
    }

    /**
      * Transactional ordered index from keys to ids, kept as a TreeSet of Item(key, id) inside a
      * single TVar. A key may map to several ids.
      *
      * @param min id used in Item(key, min) as the seek bound for a key's first entry.
      *            NOTE(review): entries whose id sorts below `min` are skipped by fromKey/forKey/apply
      *            and included by untilKey, so callers presumably pass the smallest possible id -- confirm.
      */
    class IndexTree[K <% Ordered[K], ID <% Ordered[ID]](min: ID) {
        val tree = TVar(TreeSet.empty[Item[K, ID]])

        /** Items from the first entry for key (inclusive) to the end of the index. */
        def fromKey (key: K) = {
            val lowerBound = Item(key, min)
            tree().from(lowerBound)
        }
        /** Items whose key equals key. */
        def forKey(key: K) = fromKey(key).takeWhile(item => item.key == key)
        /** Ids stored under key. */
        def apply (key: K) = forKey(key).map(item => item.id)

        /** Items strictly before the first entry for key. */
        def untilKey (key: K) = {
            val upperBound = Item(key, min)
            tree().until(upperBound)
        }
        /** Ids of every item from key onward. */
        def from (key: K) = fromKey(key).map(item => item.id)
        /** Ids of every item before key. */
        def until (key: K) = untilKey(key).map(item => item.id)

        def update(entries: TraversableOnce[Item[K,ID]]): Unit = {
            val current = tree()
            tree() = current ++ entries
        }
        def update(value: Item[K, ID]): Unit = {
            val current = tree()
            tree() = current + value
        }
        def update(key: K, id: ID): Unit = update(Item(key, id))

        /** Remove every entry for key. */
        def -= (key: K): Unit = {
            val current = tree()
            tree() = current -- forKey(key)
        }
        def clear(): Unit = tree() = TreeSet.empty[Item[K, ID]]

        /** Swap oldval for current in a single write. */
        def replace(oldval: Item[K, ID], current: Item[K, ID]): Unit = {
            val snapshot = tree()
            tree() = (snapshot - oldval) + current
        }
    }
    
    /**
      * Lazily builds an [[IndexTree]] from `getter` and keeps it behind a SoftReference,
      * rebuilding it whenever the reference has been cleared.
      *
      * @param min    passed through to IndexTree as its lower-bound id
      * @param getter produces the entries of a freshly built index
      */
    class IndexCache[K <% Ordered[K], ID <% Ordered[ID]] (min: ID)
                                                         (getter: () => TraversableOnce[Item[K,ID]]) {
        private val ref = TVar(new SoftReference[IndexTree[K,ID]](null))
        private val lock = new ReentrantLock()

        /**
          * Return the cached index, building it if absent or reclaimed.
          *
          * Only the thread holding `lock` builds. A thread that finds the lock taken waits for the
          * builder to release it and then throws RetryTransaction so its transaction is retried.
          *
          * @return the cached or freshly built index
          */
        def apply(): IndexTree[K, ID] = {
            val oldCache: IndexTree[K,ID] = ref().get
            if (oldCache != null) oldCache
            else if (lock.tryLock) {
                try {
                    val newCache = new IndexTree[K, ID](min)
                    newCache.update(getter())
                    ref() = new SoftReference(newCache)
                    newCache
                } finally lock.unlock
            } else { // avoids busy-waiting; the retry is expected to see the index the lock holder built
                lock.lock; lock.unlock; RetryTransaction("cache entry was locked for update")
            }
        }
    }
}