Skip to content

Commit 84bb42c

Browse files
committed
Implicit Weights
The weights are no longer implicit in LeastSquaresProblem.Evaluation. They are already included in the computed residuals and Jacobian. GN and LM multiplied the residuals by the weights immediately, so that was easy to remove. Created an AbstractEvaluation class which handles the derived quantitied (cost, rms, covariance,...) and two implementations. UnweightedEvaluation uses the straight forward formulas. DenseWeightedEvaluation delegates to an Evaluation and multiples the residuals and Jacobian by the square root of the weight matrix before returning them. Allowed me to remove the reference to the full weight matrix.
1 parent 157cdde commit 84bb42c

9 files changed

Lines changed: 319 additions & 212 deletions
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package org.apache.commons.math3.fitting.leastsquares;
2+
3+
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
4+
import org.apache.commons.math3.linear.ArrayRealVector;
5+
import org.apache.commons.math3.linear.DecompositionSolver;
6+
import org.apache.commons.math3.linear.QRDecomposition;
7+
import org.apache.commons.math3.linear.RealMatrix;
8+
import org.apache.commons.math3.util.FastMath;
9+
10+
/**
11+
* An implementation of {@link Evaluation} that is designed for extension. All of the
12+
* methods implemented here use the methods that are left unimplemented.
13+
* <p/>
14+
* TODO cache results?
15+
*
16+
* @version $Id$
17+
*/
18+
abstract class AbstractEvaluation implements Evaluation {
19+
20+
/** number of observations */
21+
private final int observationSize;
22+
23+
/**
24+
* Constructor.
25+
*
26+
* @param observationSize the number of observation. Needed for {@link
27+
* #computeRMS()}.
28+
*/
29+
AbstractEvaluation(final int observationSize) {
30+
this.observationSize = observationSize;
31+
}
32+
33+
/** {@inheritDoc} */
34+
public double[][] computeCovariances(double threshold) {
35+
// Set up the Jacobian.
36+
final RealMatrix j = this.computeJacobian();
37+
38+
// Compute transpose(J)J.
39+
final RealMatrix jTj = j.transpose().multiply(j);
40+
41+
// Compute the covariances matrix.
42+
final DecompositionSolver solver
43+
= new QRDecomposition(jTj, threshold).getSolver();
44+
return solver.getInverse().getData();
45+
}
46+
47+
/** {@inheritDoc} */
48+
public double[] computeSigma(double covarianceSingularityThreshold) {
49+
final double[][] cov = this.computeCovariances(covarianceSingularityThreshold);
50+
final int nC = cov.length;
51+
final double[] sig = new double[nC];
52+
for (int i = 0; i < nC; ++i) {
53+
sig[i] = FastMath.sqrt(cov[i][i]);
54+
}
55+
return sig;
56+
}
57+
58+
/** {@inheritDoc} */
59+
public double computeRMS() {
60+
final double cost = this.computeCost();
61+
return FastMath.sqrt(cost * cost / this.observationSize);
62+
}
63+
64+
/** {@inheritDoc} */
65+
public double computeCost() {
66+
final ArrayRealVector r = new ArrayRealVector(this.computeResiduals());
67+
return FastMath.sqrt(r.dotProduct(r));
68+
}
69+
70+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package org.apache.commons.math3.fitting.leastsquares;
2+
3+
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
4+
import org.apache.commons.math3.linear.RealMatrix;
5+
6+
/**
7+
* Applies a dense weight matrix to an evaluation.
8+
*
9+
* @version $Id$
10+
*/
11+
class DenseWeightedEvaluation extends AbstractEvaluation {
12+
13+
/** the unweighted evaluation */
14+
private final Evaluation unweighted;
15+
/** reference to the weight square root matrix */
16+
private final RealMatrix weightSqrt;
17+
18+
/**
19+
* Create a weighted evaluation from an unweighted one.
20+
*
21+
* @param unweighted the evalutation before weights are applied
22+
* @param weightSqrt the matrix square root of the weight matrix
23+
*/
24+
DenseWeightedEvaluation(final Evaluation unweighted,
25+
final RealMatrix weightSqrt) {
26+
// weight square root is square, nR=nC=number of observations
27+
super(weightSqrt.getColumnDimension());
28+
this.unweighted = unweighted;
29+
this.weightSqrt = weightSqrt;
30+
}
31+
32+
/* apply weights */
33+
34+
/** {@inheritDoc} */
35+
public RealMatrix computeJacobian() {
36+
return weightSqrt.multiply(this.unweighted.computeJacobian());
37+
}
38+
39+
/** {@inheritDoc} */
40+
public double[] computeResiduals() {
41+
return this.weightSqrt.operate(this.unweighted.computeResiduals());
42+
}
43+
44+
/* delegate */
45+
46+
/** {@inheritDoc} */
47+
public double[] getPoint() {
48+
return unweighted.getPoint();
49+
}
50+
51+
/** {@inheritDoc} */
52+
public double[] computeValue() {
53+
return unweighted.computeValue();
54+
}
55+
}

src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,10 @@ public Optimum optimize(final LeastSquaresProblem lsp) {
104104
throw new NullArgumentException();
105105
}
106106

107-
final RealMatrix weightMatrix = lsp.getWeight();
108-
final int nR = weightMatrix.getRowDimension(); // Number of observed data.
109-
110-
// Diagonal of the weight matrix.
111-
final double[] residualsWeights = new double[nR];
112-
for (int i = 0; i < nR; i++) {
113-
residualsWeights[i] = weightMatrix.getEntry(i, i);
114-
}
107+
final int nR = lsp.getObservationSize(); // Number of observed data.
108+
final int nC = lsp.getParameterSize();
115109

116110
final double[] currentPoint = lsp.getStart();
117-
final int nC = currentPoint.length;
118111

119112
// iterate until convergence is reached
120113
PointVectorValuePair current = null;
@@ -128,7 +121,7 @@ public Optimum optimize(final LeastSquaresProblem lsp) {
128121
final Evaluation value = lsp.evaluate(currentPoint);
129122
final double[] currentObjective = value.computeValue();
130123
final double[] currentResiduals = value.computeResiduals();
131-
final RealMatrix weightedJacobian = value.computeWeightedJacobian();
124+
final RealMatrix weightedJacobian = value.computeJacobian();
132125
current = new PointVectorValuePair(currentPoint, currentObjective);
133126

134127
// build the linear problem
@@ -137,21 +130,20 @@ public Optimum optimize(final LeastSquaresProblem lsp) {
137130
for (int i = 0; i < nR; ++i) {
138131

139132
final double[] grad = weightedJacobian.getRow(i);
140-
final double weight = residualsWeights[i];
141133
final double residual = currentResiduals[i];
142134

143135
// compute the normal equation
144-
final double wr = weight * residual;
136+
//residual is already weighted
145137
for (int j = 0; j < nC; ++j) {
146-
b[j] += wr * grad[j];
138+
b[j] += residual * grad[j];
147139
}
148140

149141
// build the contribution matrix for measurement i
150142
for (int k = 0; k < nC; ++k) {
151143
double[] ak = a[k];
152-
double wgk = weight * grad[k];
144+
//Jacobian/gradient is already weighted
153145
for (int l = 0; l < nC; ++l) {
154-
ak[l] += wgk * grad[l];
146+
ak[l] += grad[k] * grad[l];
155147
}
156148
}
157149
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package org.apache.commons.math3.fitting.leastsquares;
2+
3+
import org.apache.commons.math3.optim.ConvergenceChecker;
4+
import org.apache.commons.math3.optim.PointVectorValuePair;
5+
import org.apache.commons.math3.util.Incrementor;
6+
7+
/**
8+
* An adapter that delegates to another implementation of {@link LeastSquaresProblem}.
9+
*
10+
* @version $Id$
11+
*/
12+
public class LeastSquaresAdapter implements LeastSquaresProblem {
13+
14+
/** the delegate problem */
15+
private final LeastSquaresProblem problem;
16+
17+
/**
18+
* Delegate the {@link LeastSquaresProblem} interface to the given implementation.
19+
*
20+
* @param problem the delegate
21+
*/
22+
public LeastSquaresAdapter(final LeastSquaresProblem problem) {
23+
this.problem = problem;
24+
}
25+
26+
/** {@inheritDoc} */
27+
public double[] getStart() {
28+
return problem.getStart();
29+
}
30+
31+
/** {@inheritDoc} */
32+
public int getObservationSize() {
33+
return problem.getObservationSize();
34+
}
35+
36+
/** {@inheritDoc} */
37+
public int getParameterSize() {
38+
return problem.getParameterSize();
39+
}
40+
41+
/** {@inheritDoc} */
42+
public Evaluation evaluate(final double[] point) {
43+
return problem.evaluate(point);
44+
}
45+
46+
/** {@inheritDoc} */
47+
public Incrementor getEvaluationCounter() {
48+
return problem.getEvaluationCounter();
49+
}
50+
51+
/** {@inheritDoc} */
52+
public Incrementor getIterationCounter() {
53+
return problem.getIterationCounter();
54+
}
55+
56+
/** {@inheritDoc} */
57+
public ConvergenceChecker<PointVectorValuePair> getConvergenceChecker() {
58+
return problem.getConvergenceChecker();
59+
}
60+
}

0 commit comments

Comments
 (0)