Commits

David Mimno  committed e25b927

generics for topic words

  • Participants
  • Parent commits 02cf343

Comments (0)

Files changed (1)

File src/cc/mallet/topics/ParallelTopicModel.java

 	 *   contains IDSorter objects with integer keys into the alphabet.
 	 *   To get direct access to the Strings, use getTopWords().
 	 */
-	public TreeSet[] getSortedWords () {
+	public ArrayList<TreeSet<IDSorter>> getSortedWords () {
 	
-		TreeSet[] topicSortedWords = new TreeSet[ numTopics ];
+		ArrayList<TreeSet<IDSorter>> topicSortedWords = new ArrayList<TreeSet<IDSorter>>(numTopics);
 
 		// Initialize the tree sets
 		for (int topic = 0; topic < numTopics; topic++) {
-			topicSortedWords[topic] = new TreeSet<IDSorter>();
+			topicSortedWords.add(new TreeSet<IDSorter>());
 		}
 
 		// Collect counts
 				int topic = topicCounts[index] & topicMask;
 				int count = topicCounts[index] >> topicBits;
 
-				topicSortedWords[topic].add(new IDSorter(type, count));
+				topicSortedWords.get(topic).add(new IDSorter(type, count));
 
 				index++;
 			}
 	
 	public Object[][] getTopWords(int numWords) {
 
-		TreeSet[] topicSortedWords = getSortedWords();
+		ArrayList<TreeSet<IDSorter>> topicSortedWords = getSortedWords();
 		Object[][] result = new Object[ numTopics ][];
 
 		for (int topic = 0; topic < numTopics; topic++) {
 			
-			TreeSet<IDSorter> sortedWords = topicSortedWords[topic];
+			TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
 			
 			// How many words should we report? Some topics may have fewer than
 			//  the default number of words with non-zero weight.
 
 		StringBuilder out = new StringBuilder();
 
-		TreeSet[] topicSortedWords = getSortedWords();
+		ArrayList<TreeSet<IDSorter>> topicSortedWords = getSortedWords();
 
 		// Print results for each topic
 		for (int topic = 0; topic < numTopics; topic++) {
-			TreeSet<IDSorter> sortedWords = topicSortedWords[topic];
+			TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
 			int word = 1;
 			Iterator<IDSorter> iterator = sortedWords.iterator();
 
 	}
 	
 	public void topicXMLReport (PrintWriter out, int numWords) {
-		TreeSet[] topicSortedWords = getSortedWords();
+		ArrayList<TreeSet<IDSorter>> topicSortedWords = getSortedWords();
 		out.println("<?xml version='1.0' ?>");
 		out.println("<topicModel>");
 		for (int topic = 0; topic < numTopics; topic++) {
 			out.println("  <topic id='" + topic + "' alpha='" + alpha[topic] +
 						"' totalTokens='" + tokensPerTopic[topic] + "'>");
 			int word = 1;
-			Iterator<IDSorter> iterator = topicSortedWords[topic].iterator();
+			Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();
 			while (iterator.hasNext() && word < numWords) {
 				IDSorter info = iterator.next();
 				out.println("    <word rank='" + word + "'>" +
 		out.println("<?xml version='1.0' ?>");
 		out.println("<topics>");
 
-		TreeSet[] topicSortedWords = getSortedWords();
+		ArrayList<TreeSet<IDSorter>> topicSortedWords = getSortedWords();
 		double[] probs = new double[alphabet.size()];
 		for (int ti = 0; ti < numTopics; ti++) {
 			out.print("  <topic id=\"" + ti + "\" alpha=\"" + alpha[ti] +
 
 			// Print words
 			int word = 1;
-			Iterator<IDSorter> iterator = topicSortedWords[ti].iterator();
+			Iterator<IDSorter> iterator = topicSortedWords.get(ti).iterator();
 			while (iterator.hasNext() && word < numWords) {
 				IDSorter info = iterator.next();
 				pout.println("    <word weight=\""+(info.getWeight()/tokensPerTopic[ti])+"\" count=\""+Math.round(info.getWeight())+"\">"