Commits

Gregory Druck  committed 9a8ece7

Fixed bugs in ConjugateGradient.

  • Participants
  • Parent commits 48877ef

Comments (0)

Files changed (1)

File src/cc/mallet/optimize/ConjugateGradient.java

 	Optimizable.ByGradientValue optimizable;
 	LineOptimizer.ByGradient lineMaximizer;
 
-	// xxx If this is too big, we can get inconsistent value and gradient in MaxEntTrainer
-	// Investigate!!!
-	double initialStepSize = 0.01;
+	double initialStepSize = 1;
 	double tolerance = 0.0001;
+	double gradientTolerance = 0.001;
 	int maxIterations = 1000;
 
 	// "eps" is a small number to recitify the special case of converging
     this.initialStepSize = initialStepSize;
     this.optimizable = function;
     this.lineMaximizer = new BackTrackLineSearch (function);
-    // Alternative: = new GradientBracketLineMaximizer (function);
+    //Alternative:
+    //this.lineMaximizer = new GradientBracketLineOptimizer (function);
   }
 
 	public ConjugateGradient (Optimizable.ByGradientValue function)
 		if (converged)
 			return true;
     int n = optimizable.getNumParameters();
-    double prevStepSize = initialStepSize;
-    boolean searchingGradient = true;
     if (xi == null) {
 			fp = optimizable.getValue ();
 			xi = new double[n];
 
 		for (int iterationCount = 0; iterationCount < numIterations; iterationCount++) {
 			logger.info ("ConjugateGradient: At iteration "+iterations+", cost = "+fp);
-			try {
-        prevStepSize = step;
-        step = lineMaximizer.optimize (xi, step);
-			} catch (IllegalArgumentException e) {
-				System.out.println ("ConjugateGradient caught "+e.toString());
-        TestOptimizable.testValueAndGradientCurrentParameters(optimizable);
-        TestOptimizable.testValueAndGradientInDirection(optimizable, xi);
-				//System.out.println ("Trying ConjugateGradient restart.");
-				//return this.maximize (maxable, numIterations);
-			}
-      if (step == 0) {
-        if (searchingGradient) {
-          System.err.println ("ConjugateGradient converged: Line maximizer got step 0 in gradient direction.  "
-                              +"Gradient absNorm="+MatrixOps.absNorm(xi));
-          converged = true;
-          return true;
-        } else
-          System.err.println ("Line maximizer got step 0.  Probably pointing up hill.  Resetting to gradient.  "
-                              +"Gradient absNorm="+MatrixOps.absNorm(xi));
-        // Copied from above (how to code this better?  I want GoTo)
-        fp = optimizable.getValue();
-        optimizable.getValueGradient (xi);
-        searchingGradient = true;
-        System.arraycopy (xi, 0, g, 0, n);
-        System.arraycopy (xi, 0, h, 0, n);
-        step = prevStepSize;
-        continue;
-      }
+			
+      step = lineMaximizer.optimize (xi, step);
       fret = optimizable.getValue();
+      optimizable.getValueGradient(xi);
+      
 			// This termination provided by "Numeric Recipes in C".
 			if (2.0*Math.abs(fret-fp) <= tolerance*(Math.abs(fret)+Math.abs(fp)+eps)) {
-        System.out.println ("ConjugateGradient converged: old value= "+fp+" new value= "+fret+" tolerance="+tolerance);
+			  logger.info("ConjugateGradient converged: old value= "+fp+" new value= "+fret+" tolerance="+tolerance);
         converged = true;
         return true;
       }
       fp = fret;
-			optimizable.getValueGradient(xi);
-			
-			logger.info ("Gradient infinityNorm = "+MatrixOps.infinityNorm(xi));
+
 			// This termination provided by McCallum
-			if (MatrixOps.infinityNorm(xi) < tolerance) {
-        System.err.println ("ConjugateGradient converged: maximum gradient component "+MatrixOps.infinityNorm(xi)
-                            +", less than "+tolerance);
+      double twoNorm = MatrixOps.twoNorm(xi);
+			if (twoNorm < gradientTolerance) {
+        logger.info("ConjugateGradient converged: gradient two norm " + twoNorm
+                            +", less than " + gradientTolerance);
         converged = true;
         return true;
       }
 
       dgg = gg = 0.0;
-			double gj, xj;
 			for (j = 0; j < xi.length; j++) {
-				gj = g[j];
-				gg += gj * gj;
-				xj = -xi[j];
-				dgg = (xj + gj) * xj;
+				gg += g[j] * g[j];
+				dgg += xi[j] * (xi[j] - g[j]);
 			}
-			if (gg == 0.0) {
-        System.err.println ("ConjugateGradient converged: gradient is exactly zero.");
-        converged = true;
-        return true; // In unlikely case that gradient is exactly zero, then we are done
-      }
       gam = dgg/gg;
 
-			double hj;
 			for (j = 0; j < xi.length; j++) {
-				xj = xi[j];
-				g[j] = xj;
-				hj = h[j];
-				hj = xj + gam * hj;
-				h[j] = hj;
+				g[j] = xi[j];
+				h[j] = xi[j] + gam * h[j];
 			}
 			assert (!MatrixOps.isNaN(h));
-			MatrixOps.set (xi, h);
-      searchingGradient = false;
+			
+      // gdruck
+      // Mallet line search algorithms stop search whenever
+      // a step is found that increases the value significantly.  
+			// ConjugateGradient assumes that line maximization finds something close
+      // to the maximum in that direction.  In tests, sometimes the
+      // direction suggested by CG was downhill.  Consequently, here I am
+      // setting the search direction to the gradient if the slope is
+      // negative or 0.
+			if (MatrixOps.dotProduct(xi, h) > 0) {
+	      MatrixOps.set (xi, h);
+			}
+			else {
+			  logger.warning("Reverting back to GA");
+			  MatrixOps.set (h, xi);
+			}
 
       iterations++;
 			if (iterations > maxIterations) {
-				System.err.println("Too many iterations in ConjugateGradient.java");
+				logger.info("Too many iterations in ConjugateGradient.java");
 				converged = true;
 				return true;
-				//throw new IllegalStateException ("Too many iterations.");
 			}
 
-      if (eval != null)
+      if (eval != null) {
         eval.evaluate (optimizable, iterations);
+      }
     }
 		return false;
 	}