Source

polify / src / main / scala / updown / util / TopicModel.scala

package updown.util

import updown.data.{SentimentLabel, GoldLabeledTweet}
import cc.mallet.types._

/**
 * A single topic: a prior value paired with a map from string keys to weights.
 *
 * NOTE(review): the keys are presumably vocabulary words and the values their
 * weight within the topic -- confirm against the code that builds topics.
 *
 * @param prior        prior value associated with this topic
 * @param distribution weight recorded for each string key
 */
case class Topic(
  prior: Double,
  distribution: Map[String, Double]
)

/**
 * Base class for topic models built over gold-labeled tweets.
 *
 * Supplies helpers that convert tweets into MALLET instance lists and declares
 * the accessors that concrete models have to implement.
 */
abstract class TopicModel {
  /**
   * Converts tweets into a MALLET instance list backed by a freshly created data alphabet.
   *
   * Every tweet becomes one instance whose data is the tweet's feature sequence, whose
   * target is a one-entry feature vector storing the gold label as a double under the
   * key "label", and whose name is the tweet id.
   *
   * @param tweetList tweets to convert
   * @return the new data alphabet together with the populated instance list
   */
  protected def getInstanceList(tweetList: List[GoldLabeledTweet]): (Alphabet, InstanceList) = {
    val alphabet = new Alphabet()
    val labelAlphabet = new Alphabet()
    val instances = tweetList.map {
      case GoldLabeledTweet(id, _, features, goldLabel) =>
        val featureSequence = new FeatureSequence(alphabet, features.length)
        for (feature <- features) {
          featureSequence.add(feature)
        }
        val label = new FeatureVector(
          labelAlphabet,
          Array[Object]("label"), Array[Double](SentimentLabel.toDouble(goldLabel)))
        new Instance(featureSequence, label, id, null)
    }
    (alphabet, toInstanceList(alphabet, instances))
  }

  /**
   * Converts tweets into a MALLET instance list that reuses an existing data alphabet.
   *
   * Unlike the single-argument overload, the target of each instance is the tweet's
   * gold label itself rather than a feature vector wrapping it.
   *
   * @param tweetList tweets to convert
   * @param alphabet  data alphabet the feature sequences are built against
   * @return the populated instance list
   */
  protected def getInstanceList(tweetList: List[GoldLabeledTweet], alphabet: Alphabet): InstanceList = {
    val instances = tweetList.map {
      case GoldLabeledTweet(id, _, features, goldLabel) =>
        val featureSequence = new FeatureSequence(alphabet, features.length)
        for (feature <- features) {
          featureSequence.add(feature)
        }
        new Instance(featureSequence, goldLabel, id, null)
    }
    toInstanceList(alphabet, instances)
  }

  /** Creates an instance list over `alphabet` (no target alphabet) and adds `instances` in order. */
  private def toInstanceList(alphabet: Alphabet, instances: List[Instance]): InstanceList = {
    val instanceList = new InstanceList(alphabet, null)
    for (instance <- instances) {
      instanceList.add(instance)
    }
    instanceList
  }

  /** Returns the topics of this model. */
  def getTopics: List[Topic]

  /** Returns the topic priors as an array (ordering is defined by the implementation). */
  def getTopicPriors: Array[Double]

  /**
   * Returns a topic distribution per string id.
   * NOTE(review): the ids are presumably tweet ids -- confirm against implementations.
   */
  def getIdsToTopicDist: Map[String, Array[Double]]

  /** Returns, for each sentiment label, the list of topic distributions recorded for it. */
  def getLabelsToTopicDists: Map[SentimentLabel.Type, List[Array[Double]]]

  /**
   * Averages, per sentiment label, the topic distributions from [[getLabelsToTopicDists]].
   *
   * The distributions are summed element-wise and divided by their count. A label with no
   * distributions maps to an empty array instead of failing on the empty reduction.
   * Arrays of differing length are truncated to the shortest one by `zip`.
   *
   * @return the element-wise mean topic distribution for every label
   */
  def getLabelsToTopicDist: Map[SentimentLabel.Type, Array[Double]] =
    getLabelsToTopicDists.map {
      case (label, topicDists) =>
        val count = topicDists.length
        val mean = topicDists
          .reduceOption((a: Array[Double], b: Array[Double]) => (a zip b).map {
            case (x, y) => x + y
          })
          .map(sum => sum.map(_ / count))
          .getOrElse(Array.empty[Double])
        (label, mean)
    }

  /**
   * Infers a topic distribution for a single tweet.
   *
   * @param tweet the tweet to run inference on
   * @return the inferred distribution (layout is defined by the implementation)
   */
  def inferTopics(tweet: GoldLabeledTweet): Array[Double]

  /**
   * Saves the model under the given file name; the format is defined by the implementation.
   *
   * @param filename destination file name
   */
  def save(filename: String): Unit
}
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.