Basic Image AlgorithmS Library 2.8.0

LeastSquares.cpp

00001 /* 
00002 This file is part of the BIAS library (Basic ImageAlgorithmS).
00003 
00004 Copyright (C) 2003-2009    (see file CONTACT for details)
00005   Multimediale Systeme der Informationsverarbeitung
00006   Institut fuer Informatik
00007   Christian-Albrechts-Universitaet Kiel
00008 
00009 
00010 BIAS is free software; you can redistribute it and/or modify
00011 it under the terms of the GNU Lesser General Public License as published by
00012 the Free Software Foundation; either version 2.1 of the License, or
00013 (at your option) any later version.
00014 
00015 BIAS is distributed in the hope that it will be useful,
00016 but WITHOUT ANY WARRANTY; without even the implied warranty of
00017 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00018 GNU Lesser General Public License for more details.
00019 
00020 You should have received a copy of the GNU Lesser General Public License
00021 along with BIAS; if not, write to the Free Software
00022 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
00023 */
00024 
00025 #include "LeastSquares.hh"
00026 #include "Lapack.hh"
00027 #include <Base/Common/BIASpragma.hh>
00028 
00029 using namespace BIAS;
00030 using namespace std;
00031 
00032 /////////////////////////////////////////////////////////////////////////
00033 ///           LeastSquaresBase
00034 /////////////////////////////////////////////////////////////////////////
00035 
00036 int LeastSquaresBase::Init(unsigned SolutionSize, bool ReduceToATA)
00037 {
00038   _ReduceToATA=ReduceToATA;
00039   _SolutionSize=SolutionSize;
00040   _ATA.newsize(_SolutionSize, _SolutionSize);
00041   _Weight.newsize(_SolutionSize, _SolutionSize);
00042   _ATb.newsize(_SolutionSize);
00043   return 0;
00044 }
00045 
00046 int LeastSquaresBase::Solve(Matrix<double> &A, Vector<double> &b, 
00047                             Vector<double> &x)
00048 {
00049   BIASERR("only the overloaded function in the derived class can be used");
00050   return -1;
00051 }
00052 
00053 int LeastSquaresBase::Solve(Matrix<double> &A, Vector<double> &x)
00054 {
00055   BIASERR("only the overloaded function in the derived class can be used");
00056   return -1;
00057 }
00058 
00059 int LeastSquaresBase::WeightedSolve(Matrix<double> &A, Vector<double> &b, 
00060                                     Vector<double> &weights, Vector<double> &x)
00061 {
00062   BIASERR("only the overloaded function in the derived class can be used");
00063   return -1;
00064 }
00065 
00066 int LeastSquaresBase::WeightedSolve(Matrix<double> &A, Vector<double> &weights,
00067                                     Vector<double> &x)
00068 {
00069   BIASERR("only the overloaded function in the derived class can be used");
00070   return -1;
00071 }
00072 
00073 
00074 ////////////////////////////////////////////////////////////////////////////
00075 ///             LeastSquaresLapack
00076 ///////////////////////////////////////////////////////////////////////////
00077 
00078 
00079 int LeastSquaresLapack::Solve(Matrix<double> &A, Vector<double> &b, 
00080                               Vector<double> &x)
00081 {
00082   BIASDOUT(D_LS_MATRIXES, "A: "<<A<<"\nb: "<<b);
00083   int res=0;
00084   if (_ReduceToATA){
00085     _AT=A.Transpose();
00086     _AT.Mult(A, _ATA);
00087     _AT.Mult(b, _ATb);
00088     BIASDOUT(D_LS_RED_MATRIXES, "_ATA: "<<_ATA<<"\n_ATb: "<<_ATb);
00089     x=Lapack_LLS_QR_linear_solve(_ATA, _ATb, res);
00090   } else {
00091     x=Lapack_LLS_QR_linear_solve(A, b, res);
00092   }
00093 
00094   return res;
00095 }
00096 
00097 int LeastSquaresLapack::Solve(Matrix<double> &A, Vector<double> &x)
00098 {
00099   BIASERR("unfinished");
00100   return -1;
00101 }
00102 
00103 int LeastSquaresLapack::WeightedSolve(Matrix<double> &A, Vector<double> &b, 
00104                                       Vector<double> &weights, 
00105                                       Vector<double> &x)
00106 {
00107   BIASDOUT(D_LS_MATRIXES, "A: "<<A<<"\nb: "<<b);
00108   int res=0;
00109   if (_ReduceToATA){
00110     Matrix<double> wA(A);
00111     Vector<double> wb(b);
00112     for (unsigned i=0; i<_SolutionSize; i++){
00113       wA.ScaleRow(i, weights[i]);
00114       wb[i]*=weights[i];
00115     }
00116     _AT=wA.Transpose();
00117     _AT.Mult(wA, _ATA);
00118     _AT.Mult(wb, _ATb);
00119     BIASDOUT(D_LS_RED_MATRIXES, "_ATA: "<<_ATA<<"\n_ATb: "<<_ATb);
00120     x=Lapack_LLS_QR_linear_solve(_ATA, _ATb, res);
00121   } else {
00122     _Weight.newsize(A.num_rows(), A.num_rows());
00123     for (int i=0; i<A.num_rows(); i++){
00124       _Weight[i][i]=1.0/weights[i];
00125     }
00126     x=Lapack_WLLS_solve(A, b, _Weight, res);
00127   }
00128   return res;
00129 }
00130 
00131 int LeastSquaresLapack::WeightedSolve(Matrix<double> &A, 
00132                                       Vector<double> &weights,
00133                                       Vector<double> &x)
00134 {
00135   BIASERR("unfinished");
00136   return -1;
00137 }
00138 
00139 
00140 ////////////////////////////////////////////////////////////////////////////
00141 ///             LeastSquaresSVD
00142 ///////////////////////////////////////////////////////////////////////////
00143 
00144 
00145 /** solve |Ax-b|=min using svd with mxn Matrix A
00146 
00147     |r|^2 = |Ax-b|^2 = (Ax-b)^T(Ax-b) 
00148     = x^T A^T A x - x^T A^T b - b^T A x + b^T b =
00149        use x^T A^T b = b^T A x (both are real)
00150     = x^T A^T A x - 2 x^T A^T b + b^T b =
00151        use SVD(A^T A) := W S W^T since A^T A is symmetric
00152     = x^T W S W^T x - 2 x^T A^T b + b^T b =
00153        set z:=W^T x -> x=Wz or x^T = z^T W^T
00154     = z^T S z - 2 z^T W^T A^T b + b^T b =
00155        set d:= W^T A^T b
00156     = z^T S z - 2 z^T d + b^T b =
00157     = sum_i(S_i*z_i^2 - 2 z_i d_i + b_i^2)
00158 
00159     d |r|^2 / dx = d|r|^2/dz * dz/dx == 0
00160     -> d|r|^2/dz=0
00161     -> z_i = d_i / S_i 
00162 
00163     x = W z
00164 
00165     solve least squares using svd:
00166     1. calculate SVD(A^T A) = W S W^T
00167     2. calculate d = W^T A^T b
00168     3. calculate z: z_i = d_i/S_i
00169     4. calculate x = W z
00170 
00171     @author woelk 04 2003 */
00172 int LeastSquaresSVD::Solve(Matrix<double> &A, Vector<double> &b, 
00173                            Vector<double> &x)
00174 {
00175   BIASDOUT(D_LS_MATRIXES, "A: "<<A<<"\nb: "<<b);
00176   _AT=A.Transpose();
00177   _AT.Mult(A, _ATA);
00178   _AT.Mult(b, _ATb);
00179   BIASDOUT(D_LS_RED_MATRIXES, "_ATA: "<<_ATA<<"\n_ATb: "<<_ATb);
00180   if (_svd.Compute(_ATA)!=0){
00181      return -2;
00182   }
00183   Vector<double> d(A.num_rows()), z(A.num_cols());
00184   Vector<double> S=_svd.GetS();
00185   d=_svd.GetVT() * _ATb;
00186   BIASDOUT(D_LS_SVD, "_ATA: "<<_ATA<<"\nd: "<<d<<"\nS: "<<S);
00187   for (int i=0; i<A.num_cols(); i++)
00188     z[i]=(fabs(S[i])>DEFAULT_DOUBLE_ZERO_THRESHOLD)?d[i]/S[i]:0.0;
00189   
00190   x = _svd.GetU() * z;
00191   return 0;
00192 }
00193 
00194 
00195 int LeastSquaresSVD::Solve(Matrix<double> &A, Vector<double> &x)
00196 {
00197   BIASDOUT(D_LS_MATRIXES, "A: "<<A);
00198   if (x.size()!=(int)_SolutionSize)
00199     x.newsize(_SolutionSize);
00200 
00201   if (_ReduceToATA){
00202     _AT=A.Transpose();
00203     _AT.Mult(A, _ATA);
00204     BIASDOUT(D_LS_RED_MATRIXES, "_ATA: "<<_ATA);
00205      if (General_singular_value_decomposition(_ATA, _S, _VT)!=0){
00206       BIASERR("error in svd"); 
00207       return -2;
00208     }
00209     for (unsigned i=0; i<_SolutionSize; i++){
00210       x[i]=_VT[_SolutionSize-1][i];
00211     }
00212   } else {
00213     if (General_singular_value_decomposition(A, _S, _VT)!=0){
00214       BIASERR("error in svd"); 
00215       return -2;
00216     }
00217     for (unsigned i=0; i<_SolutionSize; i++){
00218       x[i]=_VT[_SolutionSize-1][i];
00219     }
00220   }
00221   return 0;
00222 }
00223 
00224 
00225 int LeastSquaresSVD::WeightedSolve(Matrix<double> &A, Vector<double> &b, 
00226                                    Vector<double> &weights, Vector<double> &x)
00227 {
00228   BIASERR("unfinished");
00229   return -1;
00230 }
00231 
00232 int LeastSquaresSVD::WeightedSolve(Matrix<double> &A, Vector<double> &weights, 
00233                                    Vector<double> &x)
00234 {
00235   BIASDOUT(D_LS_MATRIXES, "A: "<<A<<"\nweights: "<<weights);
00236   if (x.size()!=(int)_SolutionSize)
00237     x.newsize(_SolutionSize);
00238 
00239   Matrix<double> wA(A);
00240   for (unsigned i=0; i<_SolutionSize; i++){
00241     wA.ScaleRow(i, weights[i]);
00242   }
00243 
00244   if (_ReduceToATA){
00245     _AT=wA.Transpose();
00246     _AT.Mult(wA, _ATA);
00247     BIASDOUT(D_LS_RED_MATRIXES, "_ATA: "<<_ATA);
00248     if (General_singular_value_decomposition(_ATA, _S, _VT)!=0){
00249       BIASERR("error in svd"); 
00250       return -2;
00251     }
00252     for (unsigned i=0; i<_SolutionSize; i++){
00253       x[i]=_VT[_SolutionSize-1][i];
00254     }
00255   } else {
00256     if (General_singular_value_decomposition(wA, _S, _VT)!=0){
00257       BIASERR("error in svd"); 
00258       return -2;
00259     }
00260     for (unsigned i=0; i<_SolutionSize; i++){
00261       x[i]=_VT[_SolutionSize-1][i];
00262     }
00263   }
00264   return 0;
00265 }
00266 
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends