Commits

Anonymous committed 06f45e2

Fixed so that a new ClassifierTrainer is created in each trial, preventing unintended initialization.

  • Participants
  • Parent commits 7ec68fb

Comments (0)

Files changed (1)

src/cc/mallet/classify/tui/Vectors2Classify.java

 
 import java.io.*;
 import java.util.*;
-import java.util.Random;
 import java.util.logging.*;
 import java.lang.reflect.*;
 
 
 public abstract class Vectors2Classify
 {
+  static BshInterpreter interpreter = new BshInterpreter();
+  
 	private static Logger logger = MalletLogger.getLogger(Vectors2Classify.class.getName());
 	private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Vectors2Classify.class.getName() + "-pl");
-	private static ArrayList<ClassifierTrainer> classifierTrainers = new ArrayList<ClassifierTrainer>();
-    private static boolean[][] ReportOptions = new boolean[3][4];
-    private static String[][] ReportOptionArgs = new String[3][4];  //arg in dataset:reportOption=arg
+	private static ArrayList<String> classifierTrainerStrings = new ArrayList<String>();
+  private static boolean[][] ReportOptions = new boolean[3][4];
+  private static String[][] ReportOptionArgs = new String[3][4];  //arg in dataset:reportOption=arg
 	// Essentially an enum mapping string names to enums to ints.
 	private static class ReportOption
 	{
 			}
 		};
 
-	static CommandOption.Object trainerConstructor = new CommandOption.Object
-		(Vectors2Classify.class, "trainer", "ClassifierTrainer constructor",	true, new NaiveBayesTrainer(),
-		 "Java code for the constructor used to create a ClassifierTrainer.  "+
-		 "If no '(' appears, then \"new \" will be prepended and \"Trainer()\" will be appended."+
-		 "You may use this option mutiple times to compare multiple classifiers.", null)
-		{
-			public void parseArg (java.lang.String arg) {
-				// parse something like Maxent,gaussianPriorVariance=10,numIterations=20
-				//System.out.println("Arg = " + arg);
-
-                // first, split the argument at commas.
-				java.lang.String fields[] = arg.split(",");
-
-				//Massage constructor name, so that MaxEnt, MaxEntTrainer, new MaxEntTrainer()
-				// all call new MaxEntTrainer()
-				java.lang.String constructorName = fields[0];
-				if (constructorName.indexOf('(') != -1)     // if contains (), pass it though
-					super.parseArg(arg);
-				else {
-					if (constructorName.endsWith("Trainer")){
-						super.parseArg("new " + constructorName + "()"); // add parens if they forgot
-					}else{
-						super.parseArg("new "+constructorName+"Trainer()"); // make trainer name from classifier name
-					}
-				}
-
-				// find methods associated with the class we just built
-				Method methods[] =  this.value.getClass().getMethods();
-
-				// find setters corresponding to parameter names.
-				for (int i=1; i<fields.length; i++){
-					java.lang.String nameValuePair[] = fields[i].split("=");
-					java.lang.String parameterName  = nameValuePair[0];
-					java.lang.String parameterValue = nameValuePair[1];  //todo: check for val present!
-					java.lang.Object parameterValueObject;
-					try {
-						parameterValueObject = getInterpreter().eval(parameterValue);
-					} catch (bsh.EvalError e) {
-						throw new IllegalArgumentException ("Java interpreter eval error on parameter "+
-						                                    parameterName + "\n"+e);
-					}
-
-					boolean foundSetter = false;
-					for (int j=0; j<methods.length; j++){
-						//						System.out.println("method " + j + " name is " + methods[j].getName());
-						//						System.out.println("set" + Character.toUpperCase(parameterName.charAt(0)) + parameterName.substring(1));
-						if ( ("set" + Character.toUpperCase(parameterName.charAt(0)) + parameterName.substring(1)).equals(methods[j].getName()) &&
-							 methods[j].getParameterTypes().length == 1){
-							//							System.out.println("Matched method " + methods[j].getName());
-							//							Class[] ptypes = methods[j].getParameterTypes();
-							//							System.out.println("Parameter types:");
-							//							for (int k=0; k<ptypes.length; k++){
-							//								System.out.println("class " + k + " = " + ptypes[k].getName());
-							//							}
-
-							try {
-								java.lang.Object[] parameterList = new java.lang.Object[]{parameterValueObject};
-								//								System.out.println("Argument types:");
-								//								for (int k=0; k<parameterList.length; k++){
-								//									System.out.println("class " + k + " = " + parameterList[k].getClass().getName());
-								//								}
-								methods[j].invoke(this.value, parameterList);
-							} catch ( IllegalAccessException e) {
-								System.out.println("IllegalAccessException " + e);
-								throw new IllegalArgumentException ("Java access error calling setter\n"+e);
-							}  catch ( InvocationTargetException e) {
-								System.out.println("IllegalTargetException " + e);
-								throw new IllegalArgumentException ("Java target error calling setter\n"+e);
-							}
-							foundSetter = true;
-							break;
-						}
-					}
-					if (!foundSetter){
-		                System.out.println("Parameter " + parameterName + " not found on trainer " + constructorName);
-						System.out.println("Available parameters for " + constructorName);
-						for (int j=0; j<methods.length; j++){
-							if ( methods[j].getName().startsWith("set") && methods[j].getParameterTypes().length == 1){
-								System.out.println(Character.toLowerCase(methods[j].getName().charAt(3)) +
-								                   methods[j].getName().substring(4));
-							}
-						}
-
-						throw new IllegalArgumentException ("no setter found for parameter " + parameterName);
-					}
-				}
-
-			}
-			public void postParsing (CommandOption.List list) {
-				assert (this.value instanceof ClassifierTrainer);
-				//System.out.println("v2c PostParsing " + this.value);
-				classifierTrainers.add ((ClassifierTrainer)this.value);
-			}
-		};
+		
+		static CommandOption.String trainerConstructor = new CommandOption.String
+      (Vectors2Classify.class, "trainer", "ClassifierTrainer constructor",  true, "new NaiveBayesTrainer()",
+        "Java code for the constructor used to create a ClassifierTrainer.  "+
+        "If no '(' appears, then \"new \" will be prepended and \"Trainer()\" will be appended."+
+        "You may use this option mutiple times to compare multiple classifiers.", null) {
+      public void postParsing (CommandOption.List list) {
+        classifierTrainerStrings.add (this.value);
+      }};
 
 	static CommandOption.String outputFile = new CommandOption.String
 		(Vectors2Classify.class, "output-classifier", "FILENAME", true, "classifier.mallet",
 
 		// handle default trainer here for now; default argument processing doesn't  work
 		if (!trainerConstructor.wasInvoked()){
-			classifierTrainers.add (new NaiveBayesTrainer());
+			classifierTrainerStrings.add ("new NaiveBayesTrainer()");
 		}
 
 		if (!report.wasInvoked()){
 		int numTrials = numTrialsOption.value;
 		Random r = randomSeedOption.wasInvoked() ? new Random (randomSeedOption.value) : new Random ();
 
-		ClassifierTrainer[] trainers = new ClassifierTrainer[classifierTrainers.size()];
-		for (int i = 0; i < classifierTrainers.size(); i++) {
-			trainers[i] = classifierTrainers.get(i);
-			logger.fine ("Trainer specified = "+trainers[i].toString());
-		}
+		int numTrainers = classifierTrainerStrings.size();
 
-		double trainAccuracy[][] = new double[trainers.length][numTrials];
-		double testAccuracy[][] = new double[trainers.length][numTrials];
-		double validationAccuracy[][] = new double[trainers.length][numTrials];
+		double trainAccuracy[][] = new double[numTrainers][numTrials];
+		double testAccuracy[][] = new double[numTrainers][numTrials];
+		double validationAccuracy[][] = new double[numTrainers][numTrials];
 
-		String trainConfusionMatrix[][] = new String[trainers.length][numTrials];
-		String testConfusionMatrix[][] = new String[trainers.length][numTrials];
-		String validationConfusionMatrix[][] = new String[trainers.length][numTrials];
+		String trainConfusionMatrix[][] = new String[numTrainers][numTrials];
+		String testConfusionMatrix[][] = new String[numTrainers][numTrials];
+		String validationConfusionMatrix[][] = new String[numTrainers][numTrials];
 
 		double t = trainingProportionOption.value;
 		double v = validationProportionOption.value;
 		//			System.out.println();
 		//		}
 
+    String[] trainerNames = new String[numTrainers];
 		for (int trialIndex = 0; trialIndex < numTrials; trialIndex++) {
 			System.out.println("\n-------------------- Trial " + trialIndex + "  --------------------\n");
 			InstanceList[] ilists;
 
 
 			//System.out.println ("Training with "+ilists[0].size()+" instances");
-			long time[] = new long[trainers.length];
-			for (int c = 0; c < trainers.length; c++){
+			long time[] = new long[numTrainers];
+			for (int c = 0; c < numTrainers; c++){
 				time[c] = System.currentTimeMillis();
-				System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " with "+ilists[0].size()+" instances");
+        ClassifierTrainer trainer = getTrainer(classifierTrainerStrings.get(c));
+        trainer.setValidationInstances(ilists[2]);
+				System.out.println ("Trial " + trialIndex + " Training " + trainer + " with "+ilists[0].size()+" instances");
 				if (unlabeledProportionOption.value > 0)
 					ilists[0].hideSomeLabels(unlabeledIndices);
-				trainers[c].setValidationInstances(ilists[2]);
-				Classifier classifier = trainers[c].train (ilists[0]);
+				Classifier classifier = trainer.train (ilists[0]);
 				if (unlabeledProportionOption.value > 0)
 					ilists[0].unhideAllLabels();
 
-				System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " finished");
+				System.out.println ("Trial " + trialIndex + " Training " + trainer.toString() + " finished");
 				time[c] = System.currentTimeMillis() - time[c];
 				Trial trainTrial = new Trial (classifier, ilists[0]);
 				//assert (ilists[1].size() > 0);
 
 				if (outputFile.wasInvoked()) {
 					String filename = outputFile.value;
-					if (trainers.length > 1) filename = filename+trainers[c].toString();
+					if (numTrainers > 1) filename = filename+trainer.toString();
 					if (numTrials > 1) filename = filename+".trial"+trialIndex;
 					try {
 						ObjectOutputStream oos = new ObjectOutputStream
 
                 // raw output
 				if (ReportOptions[ReportOption.train][ReportOption.raw]){
-					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
+					System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
 					System.out.println(" Raw Training Data");
 					printTrialClassification(trainTrial);
 				}
 
 				if (ReportOptions[ReportOption.test][ReportOption.raw]){
-					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
+					System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
 					System.out.println(" Raw Testing Data");
 					printTrialClassification(testTrial);
 				}
 
 				if (ReportOptions[ReportOption.validation][ReportOption.raw]){
-					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
+					System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
 					System.out.println(" Raw Validation Data");
 					printTrialClassification(validationTrial);
 				}
 
 				//train
 				if (ReportOptions[ReportOption.train][ReportOption.confusion]){
-					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() +  " Training Data Confusion Matrix");
+					System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() +  " Training Data Confusion Matrix");
 					if (ilists[0].size()>0) System.out.println (trainConfusionMatrix[c][trialIndex]);
 				}
 
 				if (ReportOptions[ReportOption.train][ReportOption.accuracy]){
-					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data accuracy= "+ trainAccuracy[c][trialIndex]);
+					System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data accuracy= "+ trainAccuracy[c][trialIndex]);
 				}
 
 				if (ReportOptions[ReportOption.train][ReportOption.f1]){
 					String label = ReportOptionArgs[ReportOption.train][ReportOption.f1];
-					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data F1(" + label + ") = "+ trainTrial.getF1(label));
+					System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data F1(" + label + ") = "+ trainTrial.getF1(label));
 				}
 
 				//validation
 				if (ReportOptions[ReportOption.validation][ReportOption.confusion]){
-					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() +  " Validation Data Confusion Matrix");
+					System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() +  " Validation Data Confusion Matrix");
 					if (ilists[2].size()>0) System.out.println (validationConfusionMatrix[c][trialIndex]);
 				}
 
 				if (ReportOptions[ReportOption.validation][ReportOption.accuracy]){
-					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data accuracy= "+ validationAccuracy[c][trialIndex]);
+					System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data accuracy= "+ validationAccuracy[c][trialIndex]);
 				}
 
 				if (ReportOptions[ReportOption.validation][ReportOption.f1]){
 					String label = ReportOptionArgs[ReportOption.validation][ReportOption.f1];
-					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data F1(" + label + ") = "+ validationTrial.getF1(label));
+					System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data F1(" + label + ") = "+ validationTrial.getF1(label));
 				}
 
 				//test
 				if (ReportOptions[ReportOption.test][ReportOption.confusion]){
-					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Test Data Confusion Matrix");
+					System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() + " Test Data Confusion Matrix");
 					if (ilists[1].size()>0) System.out.println (testConfusionMatrix[c][trialIndex]);
 				}
 
 				if (ReportOptions[ReportOption.test][ReportOption.accuracy]){
-					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data accuracy= "+ testAccuracy[c][trialIndex]);
+					System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data accuracy= "+ testAccuracy[c][trialIndex]);
 				}
 
 				if (ReportOptions[ReportOption.test][ReportOption.f1]){
 					String label = ReportOptionArgs[ReportOption.test][ReportOption.f1];
-					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data F1(" + label + ") = "+ testTrial.getF1(label));
+					System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data F1(" + label + ") = "+ testTrial.getF1(label));
 				}
-
+				
+				if (trialIndex == 0) trainerNames[c] = trainer.toString();
 
 			}  // end for each trainer
 		}  // end for each trial
 
         // New reporting
 		//"[train|test|validation]:[accuracy|f1|confusion|raw]"
-		for (int c=0; c < trainers.length; c++) {
-			System.out.println ("\n"+trainers[c].toString());
+		for (int c=0; c < numTrainers; c++) {
+			System.out.println ("\n"+trainerNames[c].toString());
 			if (ReportOptions[ReportOption.train][ReportOption.accuracy])
 				System.out.println ("Summary. train accuracy mean = "+ MatrixOps.mean (trainAccuracy[c])+
 									" stddev = "+ MatrixOps.stddev (trainAccuracy[c])+
 
 	private static void printTrialClassification(Trial trial)
 	{
-		for (Classification c : trial) {
-			Instance instance = c.getInstance();
-			System.out.print(instance.getName() + " " + instance.getTarget() + " ");
-			Labeling labeling = c.getLabeling();
-			for (int j = 0; j < labeling.numLocations(); j++){
-				System.out.print(labeling.getLabelAtRank(j).toString() + ":" + labeling.getValueAtRank(j) + " ");
-			}
-			System.out.println();
-		}
+	  for (Classification c : trial) {
+	    Instance instance = c.getInstance();
+	    System.out.print(instance.getName() + " " + instance.getTarget() + " ");
+	    Labeling labeling = c.getLabeling();
+	    for (int j = 0; j < labeling.numLocations(); j++){
+	      System.out.print(labeling.getLabelAtRank(j).toString() + ":" + labeling.getValueAtRank(j) + " ");
+	    }
+	    System.out.println();
+	  }
 	}
 
+	private static Object createTrainer(String arg) {
+	  try {
+	    return interpreter.eval (arg);
+	  } catch (bsh.EvalError e) {
+	    throw new IllegalArgumentException ("Java interpreter eval error\n"+e);
+	  }
+	}
 
+	private static ClassifierTrainer getTrainer(String arg) {
+	  // parse something like Maxent,gaussianPriorVariance=10,numIterations=20
+
+	  // first, split the argument at commas.
+	  java.lang.String fields[] = arg.split(",");
+
+	  //Massage constructor name, so that MaxEnt, MaxEntTrainer, new MaxEntTrainer()
+	  // all call new MaxEntTrainer()
+	  java.lang.String constructorName = fields[0];
+	  Object trainer;
+	  if (constructorName.indexOf('(') != -1) // if contains (), pass it though
+	    trainer = createTrainer(arg);
+	  else {
+	    if (constructorName.endsWith("Trainer")){
+	      trainer = createTrainer("new " + constructorName + "()"); // add parens if they forgot
+	    }else{
+	      trainer = createTrainer("new "+constructorName+"Trainer()"); // make trainer name from classifier name
+	    }
+	  }
+
+	  // find methods associated with the class we just built
+	  Method methods[] =  trainer.getClass().getMethods();
+
+	  // find setters corresponding to parameter names.
+	  for (int i=1; i<fields.length; i++){
+	    java.lang.String nameValuePair[] = fields[i].split("=");
+	    java.lang.String parameterName  = nameValuePair[0];
+	    java.lang.String parameterValue = nameValuePair[1];  //todo: check for val present!
+	    java.lang.Object parameterValueObject;
+	    try {
+	      parameterValueObject = interpreter.eval(parameterValue);
+	    } catch (bsh.EvalError e) {
+	      throw new IllegalArgumentException ("Java interpreter eval error on parameter "+
+	          parameterName + "\n"+e);
+	    }
+
+	    boolean foundSetter = false;
+	    for (int j=0; j<methods.length; j++){
+	      // System.out.println("method " + j + " name is " + methods[j].getName());
+	      // System.out.println("set" + Character.toUpperCase(parameterName.charAt(0)) + parameterName.substring(1));
+	      if ( ("set" + Character.toUpperCase(parameterName.charAt(0)) + parameterName.substring(1)).equals(methods[j].getName()) &&
+	          methods[j].getParameterTypes().length == 1){
+	        // System.out.println("Matched method " + methods[j].getName());
+	        // Class[] ptypes = methods[j].getParameterTypes();
+	        // System.out.println("Parameter types:");
+	        // for (int k=0; k<ptypes.length; k++){
+	        // System.out.println("class " + k + " = " + ptypes[k].getName());
+	        // }
+
+	        try {
+	          java.lang.Object[] parameterList = new java.lang.Object[]{parameterValueObject};
+	          // System.out.println("Argument types:");
+	          // for (int k=0; k<parameterList.length; k++){
+	          // System.out.println("class " + k + " = " + parameterList[k].getClass().getName());
+	          // }
+	          methods[j].invoke(trainer, parameterList);
+	        } catch ( IllegalAccessException e) {
+	          System.out.println("IllegalAccessException " + e);
+	          throw new IllegalArgumentException ("Java access error calling setter\n"+e);
+	        }  catch ( InvocationTargetException e) {
+	          System.out.println("IllegalTargetException " + e);
+	          throw new IllegalArgumentException ("Java target error calling setter\n"+e);
+	        }
+	        foundSetter = true;
+	        break;
+	      }
+	    }
+	    if (!foundSetter){
+	      System.out.println("Parameter " + parameterName + " not found on trainer " + constructorName);
+	      System.out.println("Available parameters for " + constructorName);
+	      for (int j=0; j<methods.length; j++){
+	        if ( methods[j].getName().startsWith("set") && methods[j].getParameterTypes().length == 1){
+	          System.out.println(Character.toLowerCase(methods[j].getName().charAt(3)) +
+	              methods[j].getName().substring(4));
+	        }
+	      }
+
+	      throw new IllegalArgumentException ("no setter found for parameter " + parameterName);
+	    }
+	  }
+	  assert (trainer instanceof ClassifierTrainer);
+	  return ((ClassifierTrainer)trainer);
+	}
 }