/********************************************************************************
*                                                                               *
*  TPClib 0.9 Medical imaging library                                           *
*  Copyright (C) 2011 Turku PET Centre                                          *
*                                                                               *
*  This library is free software: you can redistribute it and/or modify it      *
*  under the terms of the GNU Lesser General Public License (LGPL) as           *
*  published by the Free Software Foundation, either version 2.1 of the         *
*  License, or (at your option) any later version.                              *
*                                                                               *
*  This library is distributed in the hope that it will be useful, but          *
*  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY   *
*  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public      *
*  License for more details.                                                    *
*                                                                               *
*  You should have received a copy of the GNU Lesser General Public License     *
*  along with this program.  If not, see <http://www.gnu.org/licenses/>.        *
*                                                                               *
********************************************************************************/

using System;
using TPClib.Common;

namespace TPClib.Modeling
{
    /// <summary>
    /// Collection of functions for solving systems of linear equations:
    /// Gauss-Jordan elimination, triangular substitution, and solvers based on
    /// Cholesky and QR (Householder) decompositions.
    /// </summary>
    public static class LinearEquations
    {
        /// <summary>
        /// Gauss-Jordan method for general, non-square matrices.
        /// Solves the linear equation system Ax=b via the normal equations A'Ax=A'b
        /// (A' is the transpose of A).
        /// </summary>
        /// <param name="a">Matrix A; must have at least as many rows as columns</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x</returns>
        /// <exception cref="LinearException">
        /// Thrown if A has fewer rows than columns, or if A'A is singular.
        /// </exception>
        public static Vector GaussJordan(Matrix a, Vector b)
        {
            if (a.Rows < a.Columns) throw new LinearException("System is underdetermined");
            Matrix at = Matrix.Transpose(a);
            Matrix r = at * a;

            // The right-hand side must be a column matrix with r.Rows rows (one system,
            // m = 1). Transposing it into a row matrix would make the elimination routine
            // reject every system with more than one unknown ("Dimension mismatch"), and
            // the result is read back below as column 0.
            Matrix c = (Matrix)(at * b);
            GaussJordan(ref r, r.Rows, ref c, 1);
            return c.GetColumnVector(0);
        }

        /// <summary>
        /// Solves the systems of linear equations AX=B
        /// with Gauss-Jordan elimination (with full pivoting).
        /// On return, the solution matrix X overwrites B, and the upper-left n*n part
        /// of A is replaced by its inverse.
        /// </summary>
        /// <param name="A">Coefficient matrix; must have exactly n columns and at least n rows
        /// (only the upper-left n*n part is used)</param>
        /// <param name="n">Number of unknowns, i.e. the size of the square system</param>
        /// <param name="B">Right-hand side vectors, one vector per column; must have n rows</param>
        /// <param name="m">Number of right-hand side columns of B to solve (1..B.Columns)</param>
        /// <exception cref="LinearException">
        /// Thrown on dimension mismatch, or if the matrix is singular or contains no
        /// usable pivot element (e.g. NaN values).
        /// </exception>
        public static void GaussJordan(ref Matrix A, int n, ref Matrix B, int m)
        {
            if (n < A.Columns) throw new LinearException("System is underdetermined");
            // The elimination reads A[0..n-1, 0..n-1]; reject sizes that would index
            // outside A rather than failing inside the loops.
            if (n > A.Columns || n > A.Rows) throw new LinearException("Dimension mismatch");
            if (n != B.Rows) throw new LinearException("Dimension mismatch");
            if (m > B.Columns || m <= 0) throw new LinearException("Dimension mismatch");

            // Bookkeeping for the pivoting: indxr/indxc record the pivot row/column chosen
            // at each step; ipiv counts how often each index has been used as pivot.
            int[] indxc = new int[n];
            int[] indxr = new int[n];
            int[] ipiv = new int[n];

            // Main loop over the columns to be reduced.
            for (int i = 0; i < n; i++)
            {
                // Search for the largest remaining element (full pivoting).
                // The indices are reset every step so a failed search can never reuse
                // the pivot of the previous step.
                double big = 0.0;
                int irow = -1;
                int icol = -1;

                for (int j = 0; j < n; j++)
                {
                    if (ipiv[j] == 1)
                    {
                        continue;
                    }

                    for (int k = 0; k < n; k++)
                    {
                        if (ipiv[k] == 0)
                        {
                            double abs = Math.Abs(A[j, k]);
                            if (abs >= big)
                            {
                                big = abs;
                                irow = j;
                                icol = k;
                            }
                        }
                        else if (ipiv[k] > 1)
                        {
                            throw new LinearException("Singular Matrix");
                        }
                    }
                }

                // For finite input a candidate always exists (abs >= 0); only NaN entries
                // can leave the search empty.
                if (icol < 0)
                {
                    throw new LinearException("Singular Matrix: no valid pivot element (NaN values in matrix)");
                }

                ipiv[icol]++;

                // Move the pivot onto the diagonal by swapping rows irow and icol of both
                // A and B. Columns are not physically swapped here; the column permutation
                // is recorded and undone after the main loop.
                if (irow != icol)
                {
                    for (int l = 0; l < n; l++)
                    {
                        double temp = A[irow, l];
                        A[irow, l] = A[icol, l];
                        A[icol, l] = temp;
                    }
                    for (int l = 0; l < m; l++)
                    {
                        double temp = B[irow, l];
                        B[irow, l] = B[icol, l];
                        B[icol, l] = temp;
                    }
                }

                indxr[i] = irow;
                indxc[i] = icol;

                if (A[icol, icol] == 0.0)
                {
                    throw new LinearException("gaussj: Singular Matrix-2");
                }

                // Divide the pivot row by the pivot element. The diagonal is set to 1
                // before scaling so that it ends up holding 1/pivot; this is what builds
                // the inverse of A in place.
                double pivinv = 1.0 / A[icol, icol];
                A[icol, icol] = 1.0;

                for (int l = 0; l < n; l++)
                {
                    A[icol, l] *= pivinv;
                }
                for (int l = 0; l < m; l++)
                {
                    B[icol, l] *= pivinv;
                }

                // Reduce all other rows (the pivot row itself is skipped).
                for (int ll = 0; ll < n; ll++)
                {
                    if (ll == icol)
                    {
                        continue;
                    }

                    double dum = A[ll, icol];
                    A[ll, icol] = 0.0;

                    for (int l = 0; l < n; l++)
                    {
                        A[ll, l] -= A[icol, l] * dum;
                    }
                    for (int l = 0; l < m; l++)
                    {
                        B[ll, l] -= B[icol, l] * dum;
                    }
                }
            }

            // Undo the column interchanges of the inverse by swapping column pairs in the
            // reverse order in which the permutation was built up. B needs no unscrambling.
            for (int l = n - 1; l >= 0; l--)
            {
                if (indxr[l] == indxc[l])
                {
                    continue;
                }

                for (int k = 0; k < n; k++)
                {
                    double temp = A[k, indxr[l]];
                    A[k, indxr[l]] = A[k, indxc[l]];
                    A[k, indxc[l]] = temp;
                }
            }
        }

        /// <summary>
        /// Solves the linear system Ax=b, where A is an upper triangular matrix.
        /// If A is an m*n matrix with n>m, gives the solution where the last n-m
        /// variables are 0. If m>n, only the first n equations are used.
        /// Unknowns whose computation yields NaN or Infinity (e.g. a zero on the
        /// diagonal) are set to 0.
        /// </summary>
        /// <param name="a">Matrix A</param>
        /// <param name="b">Constant vector b; its length must equal the number of rows of A</param>
        /// <returns>Solution vector x (length = number of columns of A)</returns>
        /// <exception cref="LinearException">Thrown if b.Length != A.Rows.</exception>
        public static Vector BackSub(Matrix a, Vector b)
        {
            if (b.Length != a.Rows)
                throw new LinearException("Dimension mismatch");

            Vector x = new Vector(a.Columns);

            // Solve only the leading min(Rows, Columns) unknowns; in the underdetermined
            // case the remaining ones keep their initial value.
            int n = Math.Min(a.Rows, a.Columns);

            // Solve from the bottom up: x[i] = (b[i] - sum_{k>i} a[i,k] x[k]) / a[i,i].
            for (int i = n - 1; i >= 0; i--)
            {
                double s = 0.0;
                for (int k = i + 1; k < n; k++)
                {
                    s += a[i, k] * x[k];
                }

                double xi = (b[i] - s) / a[i, i];
                x[i] = (Double.IsInfinity(xi) || Double.IsNaN(xi)) ? 0.0 : xi;
            }

            return x;
        }

        /// <summary>
        /// Solves the linear system Ax=b, where A is a lower triangular matrix.
        /// If A is an m*n matrix with n>m, gives the solution where the last n-m
        /// variables are 0. If m>n, only the first n equations are used.
        /// Unknowns whose computation yields NaN or Infinity (e.g. a zero on the
        /// diagonal) are set to 0.
        /// </summary>
        /// <param name="a">Matrix A</param>
        /// <param name="b">Constant vector b; its length must equal the number of rows of A</param>
        /// <returns>Solution vector x (length = number of columns of A)</returns>
        /// <exception cref="LinearException">Thrown if b.Length != A.Rows.</exception>
        public static Vector ForwardSub(Matrix a, Vector b)
        {
            if (b.Length != a.Rows)
                throw new LinearException("Dimension mismatch");

            Vector x = new Vector(a.Columns);

            // Solve only the leading min(Rows, Columns) unknowns; in the underdetermined
            // case the remaining ones keep their initial value.
            int n = Math.Min(a.Rows, a.Columns);

            // Solve from the top down: x[i] = (b[i] - sum_{k<i} a[i,k] x[k]) / a[i,i].
            for (int i = 0; i < n; i++)
            {
                double s = 0.0;
                for (int k = 0; k < i; k++)
                {
                    s += a[i, k] * x[k];
                }

                double xi = (b[i] - s) / a[i, i];
                x[i] = (Double.IsInfinity(xi) || Double.IsNaN(xi)) ? 0.0 : xi;
            }

            return x;
        }

        /// <summary>
        /// Solves a general system of linear equations with Cholesky decomposition.
        /// A must be an MxN matrix with M >= N and of rank N.
        ///
        /// In that case A'A (where A' is the transpose of A) is symmetric and
        /// positive definite, and Ax=b is solved in the least-squares sense through the
        /// equivalent system A'Ax=A'b with Cholesky decomposition (see CholeskySolve).
        /// </summary>
        /// <param name="a">Matrix A</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x</returns>
        /// <exception cref="LinearException">
        /// Thrown if A has fewer rows than columns, or if A'A is not positive definite.
        /// </exception>
        public static Vector CholeskySolveGen(Matrix a, Vector b)
        {
            if (a.Rows < a.Columns)
                throw new LinearException("Rows must be greater or equal than columns for Cholesky decomposition.");
            Matrix at = Matrix.Transpose(a);
            return CholeskySolve(at * a, at * b);
        }

        /// <summary>
        /// Solves a linear equation system Ax=b, where A is symmetric and positive definite.
        /// (The normal equations of a linear least squares problem are of this form.)
        ///
        /// Decomposing A = LL' gives LL'x=b. Let z=L'x: first solve Lz=b, where L is
        /// lower triangular, then solve L'x=z, where L' is upper triangular.
        /// </summary>
        /// <param name="a">Symmetric, positive definite matrix A</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x</returns>
        /// <exception cref="LinearException">
        /// Thrown if A is not square or not positive definite, or on dimension mismatch.
        /// </exception>
        public static Vector CholeskySolve(Matrix a, Vector b)
        {
            Matrix l = Cholesky(a);
            Matrix lt = Matrix.Transpose(l);
            Vector z = ForwardSub(l, b);
            Vector x = BackSub(lt, z);

            return x;
        }

        /// <summary>
        /// Householder transformation matrix H = I - 2vv', where v = x/|x|.
        /// H reflects vectors across the hyperplane whose normal is x.
        /// For a zero vector the identity matrix is returned.
        /// </summary>
        /// <param name="x">Normal vector of the reflecting plane</param>
        /// <returns>Householder transformation matrix (size x.Length * x.Length)</returns>
        public static Matrix Householder(Vector x)
        {
            int n = x.Length;
            double norm = Vector.Norm(x);
            Matrix q = Matrix.Identity(n);

            if (norm > 0.0)
            {
                // Unit normal as a column matrix; v * v' is its outer product.
                Matrix v = (Matrix)x / norm;
                Matrix vt = Matrix.Transpose(v);

                q = q - (v * vt) * 2;
            }
            return q;
        }

        /// <summary>
        /// Solves a system of linear equations with QR decomposition.
        ///
        /// (Transpose of A is marked with A')
        /// From Ax=b the decomposition gives QRx=b -> Rx=Q'b.
        /// R is upper triangular, so this is easily solved by back substitution.
        /// For an overdetermined system this yields the least-squares solution.
        /// </summary>
        /// <param name="a">Matrix A</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x</returns>
        public static Vector QRSolve(Matrix a, Vector b)
        {
            Matrix qt = QR(a);
            Matrix r = qt * a;
            Vector c = qt * b;
            Vector x = BackSub(r, c);
            return x;
        }

        /// <summary>
        /// Checks whether matrix m is positive definite by trying to form its Cholesky decomposition.
        /// Only the lower triangle of m is examined; symmetry is not checked.
        /// </summary>
        /// <param name="m">Matrix m</param>
        /// <returns>True, if the matrix has a Cholesky decomposition (and therefore is positive definite);
        /// false if it is not square or the decomposition fails</returns>
        public static bool IsPositiveDefinite(Matrix m)
        {
            try
            {
                Cholesky(m);
            }
            catch (LinearException)
            {
                return false;
            }
            return true;
        }

        /// <summary>
        /// Cholesky decomposition A = LL' of a symmetric, positive definite matrix.
        /// Returns the lower triangular factor L. Only the lower triangle of the input
        /// is read; the matrix is assumed to be symmetric.
        /// </summary>
        /// <param name="m">Symmetric, positive definite matrix</param>
        /// <returns>Lower triangular matrix L of the Cholesky decomposition</returns>
        /// <exception cref="LinearException">
        /// Thrown if the matrix is not square, or not positive definite (including the
        /// case where NaN values make the decomposition undefined).
        /// </exception>
        public static Matrix Cholesky(Matrix m)
        {
            if (m.Rows != m.Columns) throw new LinearException("Matrix is not square!");

            int n = m.Rows;
            Matrix r = new Matrix(n);

            // Build L column by column, starting from the first.
            for (int j = 0; j < n; j++)
            {
                // Diagonal: L[j,j]^2 = A[j,j] - sum_{k<j} L[j,k]^2
                double sum1 = m[j, j];
                for (int k = 0; k < j; k++)
                {
                    double rjk = r[j, k];
                    sum1 -= rjk * rjk;
                }

                // A non-positive value means the matrix is not positive definite.
                // Written as !(sum1 > 0) so that NaN is rejected too (NaN <= 0 is false).
                if (!(sum1 > 0.0)) throw new LinearException("Matrix is not positive definite!");

                double diagonal = Math.Sqrt(sum1);
                r[j, j] = diagonal;

                // Elements below the diagonal in column j:
                // L[i,j] = (A[i,j] - sum_{k<j} L[i,k] L[j,k]) / L[j,j]
                for (int i = j + 1; i < n; i++)
                {
                    double sum2 = m[i, j];
                    for (int k = 0; k < j; k++)
                    {
                        sum2 -= r[i, k] * r[j, k];
                    }
                    r[i, j] = sum2 / diagonal;
                }
            }

            // r is now lower triangular
            return r;
        }

        /// <summary>
        /// QR decomposition of a matrix using Householder reflections.
        /// Returns the transpose of matrix Q.
        /// The upper triangular matrix R can be obtained by calculating R=Q'A (Q' is the transpose of Q).
        ///
        /// Linear systems can now be easily solved: Ax=b -> QRx=b -> Rx=Q'b.
        /// </summary>
        /// <param name="a">Matrix to decompose (the caller's matrix is not modified)</param>
        /// <returns>Matrix Q' (size a.Rows * a.Rows)</returns>
        public static Matrix QR(Matrix a)
        {
            // Number of columns to reduce. A column whose sub-diagonal part is empty
            // (the last row) needs no reflection, hence at most Rows - 1.
            int t = Math.Min(a.Rows - 1, a.Columns);
            int n = a.Rows;

            // Start with the NxN identity matrix (i.e. the identity transformation).
            Matrix qt = Matrix.Identity(n);

            // Go through the columns.
            for (int k = 0; k < t; k++)
            {
                // Take the current column from element [k,k] down.
                Vector x = a.GetColumnVector(k);
                x = Vector.Range(x, k, x.Length - k);

                // The reflection maps x to alpha * e1 with |alpha| = |x|. Choose
                // sign(alpha) = -sign(x[0]) so that u[0] = x[0] - alpha adds two numbers of
                // equal sign; the opposite choice cancels catastrophically when x is
                // nearly parallel to e1.
                double norm = Vector.Norm(x);
                double alpha = (x[0] < 0) ? norm : -norm;

                // Normal vector of the reflecting hyperplane.
                Vector u = x - alpha * Vector.Unit(x.Length, 0);

                // Householder matrix for the trailing part, expanded to NxN
                // (presumably identity in the leading k*k block; see Matrix.Expand).
                Matrix qi = Matrix.Expand(Householder(u), k);

                // Accumulate the transformation with the previous ones.
                qt = qi * qt;

                // Apply the transformation to A: elements below [k,k] become 0.
                a = qi * a;
            }

            // Result is
            // Q(t-1) * Q(t-2) * ... * Q(0) = [ Q'(0) * ... * Q'(t-1) ]' = Q'
            return qt;
        }
    }
}
