﻿/******************************************************************************
 *
 * Copyright (c) 2008 Turku PET Centre
 *
 * This program is free software; you can redistribute it and/or modify it under
 * the terms of the GNU General Public License as published by the Free Software
 * Foundation; either version 2 of the License, or (at your option) any later
 * version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
 * Place, Suite 330, Boston, MA 02111-1307 USA.
 *
 * Turku PET Centre hereby disclaims all copyright interest in the program.
 * Juhani Knuuti
 * Director, Professor
 * 
 * Turku PET Centre, Turku, Finland, http://www.turkupetcentre.fi/
 * 
 ******************************************************************************/
using System;
using System.Collections.Generic;

namespace TPClib.Model
{
    /// <summary>
    /// Solvers and decompositions for systems of linear equations:
    /// Gauss-Jordan elimination, triangular substitution, Cholesky and Householder QR.
    /// </summary>
    public static class Linear
    {
        /// <summary>
        /// Solves the linear equation system Ax=b with Gauss-Jordan elimination.
        /// A may be non-square: the method solves the normal equations A'Ax=A'b
        /// (A' is the transpose of A), i.e. the least-squares problem.
        /// </summary>
        /// <remarks>
        /// A'A is an NxN matrix, where N is the number of columns of A, so A must have
        /// full column rank for the normal equations to be solvable.
        /// </remarks>
        /// <param name="a">Matrix A</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x (one element per column of A)</returns>
        /// <exception cref="LinearException">If A'A is detected to be singular.</exception>
        public static Vector GaussJordan(Matrix a, Vector b)
        {
            Matrix at = a.Transpose();
            Matrix ata = at * a;
            // The general solver works on (and overwrites) a matrix of right-hand sides,
            // so the single right-hand side is passed as a one-column matrix.
            Matrix rhs = (Matrix)(at * b);
            GaussJordan(ata, ata.Rows, rhs, 1);
            return rhs.GetColumn(0);
        }

        /// <summary>
        /// Solves the systems of linear equations AX=B with Gauss-Jordan elimination
        /// using full pivoting (pivot searched over all remaining rows and columns).
        /// </summary>
        /// <remarks>
        /// Both matrices are modified in place: the first <paramref name="m"/> columns of B
        /// are overwritten with the solution X, and the first n rows of A are overwritten
        /// with the inverse of A.
        /// </remarks>
        /// <param name="A">Coefficient matrix; must have exactly n columns and at least n rows.</param>
        /// <param name="n">Order of the system: number of rows and columns of A taking part in
        /// the elimination. Must equal the number of rows of B.</param>
        /// <param name="B">Right-hand side vectors, one vector per column; overwritten with the solutions.</param>
        /// <param name="m">Number of leading columns of B to solve (between 1 and B.Columns).</param>
        /// <exception cref="LinearException">On dimension mismatch, or if A is detected to be singular.</exception>
        public static void GaussJordan(Matrix A, int n, Matrix B, int m)
        {
            if (n < A.Columns) throw new LinearException("System is underdetermined");
            // Row and column indices up to n-1 of A are accessed below.
            if (n > A.Columns || n > A.Rows) throw new LinearException("Dimension mismatch");
            if (n != B.Rows) throw new LinearException("Dimension mismatch");
            if (m > B.Columns || m <= 0) throw new LinearException("Dimension mismatch");

            // Bookkeeping for the pivoting. C# arrays are zero-initialised, so every
            // ipiv entry starts as "not used as a pivot yet".
            int[] indxc = new int[n]; // column of the i:th pivot
            int[] indxr = new int[n]; // row in which the i:th pivot was found (before the swap)
            int[] ipiv = new int[n];  // number of times each index has been used as a pivot

            // Main loop over the columns to be reduced.
            for (int i = 0; i < n; i++)
            {
                double big = 0.0;
                int irow = -1;
                int icol = -1;

                // Search all not-yet-pivoted rows and columns for the element of largest magnitude.
                for (int j = 0; j < n; j++)
                {
                    if (ipiv[j] == 1) continue;

                    for (int k = 0; k < n; k++)
                    {
                        if (ipiv[k] == 0)
                        {
                            double abs = Math.Abs(A[j, k]);
                            if (abs >= big)
                            {
                                big = abs;
                                irow = j;
                                icol = k;
                            }
                        }
                        else if (ipiv[k] > 1)
                        {
                            throw new LinearException("Singular Matrix");
                        }
                    }
                }

                // "abs >= big" is false only for NaN, so no pivot means every candidate is NaN.
                if (icol < 0) throw new LinearException("No pivot found (matrix contains NaN)");

                ipiv[icol]++;

                // Move the pivot to the diagonal by interchanging rows irow and icol of A and B.
                // Columns are not physically interchanged; that permutation is undone at the end.
                if (irow != icol)
                {
                    SwapRows(A, irow, icol, n);
                    SwapRows(B, irow, icol, m);
                }

                indxr[i] = irow;
                indxc[i] = icol;

                if (A[icol, icol] == 0.0)
                {
                    throw new LinearException("gaussj: Singular Matrix-2");
                }

                // Divide the pivot row by the pivot element. Setting the diagonal to 1 first makes
                // it end up holding 1/pivot, which builds the inverse of A in place.
                double pivinv = 1.0 / A[icol, icol];
                A[icol, icol] = 1.0;
                for (int l = 0; l < n; l++)
                {
                    A[icol, l] *= pivinv;
                }
                for (int l = 0; l < m; l++)
                {
                    B[icol, l] *= pivinv;
                }

                // Eliminate the pivot column from every other row. Zeroing A[ll, icol] before the
                // subtraction leaves -dum/pivot there, again building the inverse in place.
                for (int ll = 0; ll < n; ll++)
                {
                    if (ll == icol) continue;

                    double dum = A[ll, icol];
                    A[ll, icol] = 0.0;
                    for (int l = 0; l < n; l++)
                    {
                        A[ll, l] -= A[icol, l] * dum;
                    }
                    for (int l = 0; l < m; l++)
                    {
                        B[ll, l] -= B[icol, l] * dum;
                    }
                }
            }

            // Unscramble the inverse with respect to the pivoting: interchange pairs of columns
            // of A in the reverse order the permutation was built up. B needs no unscrambling.
            for (int l = n - 1; l >= 0; l--)
            {
                if (indxr[l] == indxc[l]) continue;

                for (int k = 0; k < n; k++)
                {
                    double temp = A[k, indxr[l]];
                    A[k, indxr[l]] = A[k, indxc[l]];
                    A[k, indxc[l]] = temp;
                }
            }
        }

        /// <summary>
        /// Interchanges the first <paramref name="count"/> elements of rows
        /// <paramref name="r1"/> and <paramref name="r2"/> of <paramref name="m"/>.
        /// </summary>
        private static void SwapRows(Matrix m, int r1, int r2, int count)
        {
            for (int l = 0; l < count; l++)
            {
                double temp = m[r1, l];
                m[r1, l] = m[r2, l];
                m[r2, l] = temp;
            }
        }

        /// <summary>
        /// Solves the linear system Ax=b by back substitution, where A is an upper triangular matrix.
        /// If A is an m*n matrix with n&gt;m, gives the solution where the last n-m variables are 0.
        /// If m&gt;n, only the first n rows of A and the first n elements of b are used.
        /// </summary>
        /// <remarks>
        /// Any variable whose computed value is NaN or infinite (e.g. because of a zero
        /// diagonal element) is set to 0 in the solution instead of raising an error.
        /// </remarks>
        /// <param name="a">Upper triangular matrix A</param>
        /// <param name="b">Constant vector b (length = number of rows of A)</param>
        /// <returns>Solution vector x (length = number of columns of A)</returns>
        /// <exception cref="LinearException">If b.Dim differs from the number of rows of A.</exception>
        public static Vector BackSub(Matrix a, Vector b)
        {
            if (b.Dim != a.Rows)
                throw new LinearException("Dimension mismatch");

            Vector x = new Vector(a.Columns);

            // Number of equations actually solved; variables beyond it stay 0.
            int n = Math.Min(a.Rows, a.Columns);

            // Solve from the bottom up.
            for (int i = n - 1; i >= 0; i--)
            {
                // Contribution of the already solved variables x[i+1] ... x[n-1].
                double s = Vector.Dot(Vector.Range(a.GetRow(i), i + 1, n - 1 - i), Vector.Range(x, i + 1, n - 1 - i));
                x[i] = (b[i] - s) / a[i, i];
                if (Double.IsInfinity(x[i]) || Double.IsNaN(x[i])) x[i] = 0;
            }

            return x;
        }

        /// <summary>
        /// Solves the linear system Ax=b by forward substitution, where A is a lower triangular matrix.
        /// If A is an m*n matrix with n&gt;m, gives the solution where the last n-m variables are 0.
        /// If m&gt;n, only the first n rows of A and the first n elements of b are used.
        /// </summary>
        /// <remarks>
        /// Any variable whose computed value is NaN or infinite (e.g. because of a zero
        /// diagonal element) is set to 0 in the solution instead of raising an error.
        /// </remarks>
        /// <param name="a">Lower triangular matrix A</param>
        /// <param name="b">Constant vector b (length = number of rows of A)</param>
        /// <returns>Solution vector x (length = number of columns of A)</returns>
        /// <exception cref="LinearException">If b.Dim differs from the number of rows of A.</exception>
        public static Vector ForwardSub(Matrix a, Vector b)
        {
            if (b.Dim != a.Rows)
                throw new LinearException("Dimension mismatch");

            Vector x = new Vector(a.Columns);

            // Number of equations actually solved; variables beyond it stay 0.
            int n = Math.Min(a.Rows, a.Columns);

            // Solve from the top down.
            for (int i = 0; i < n; i++)
            {
                // Contribution of the already solved variables x[0] ... x[i-1].
                double s = Vector.Dot(Vector.Range(a.GetRow(i), 0, i), Vector.Range(x, 0, i));
                x[i] = (b[i] - s) / a[i, i];
                if (Double.IsInfinity(x[i]) || Double.IsNaN(x[i])) x[i] = 0;
            }

            return x;
        }

        /// <summary>
        /// Solves a general system of linear equations with Cholesky decomposition.
        /// A must be an MxN matrix with M &gt;= N and of rank N.
        /// 
        /// If this is the case, A'A (where A' is the transpose of A) is symmetric and
        /// positive definite, and we can solve Ax=b in the least-squares sense by solving
        /// the equivalent system A'Ax=A'b with Cholesky decomposition (see CholeskySolve).
        /// </summary>
        /// <param name="a">Matrix A</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x</returns>
        /// <exception cref="LinearException">If A'A is not positive definite.</exception>
        public static Vector CholeskySolveGen(Matrix a, Vector b)
        {
            Matrix at = a.Transpose();
            return CholeskySolve(at * a, at * b);
        }

        /// <summary>
        /// Solves a linear equation system Ax=b, if A is symmetric and positive definite.
        /// (The normal equations of a linear least squares problem are of this form.)
        /// 
        /// Decomposing A = LL' gives LL'x=b. Let z=L'x; first solve Lz=b, where L is
        /// lower triangular, then solve L'x=z, where L' is upper triangular, to get x.
        /// </summary>
        /// <param name="a">Symmetric positive definite matrix A (only its diagonal and lower triangle are read)</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x</returns>
        /// <exception cref="LinearException">If A is not square or not positive definite.</exception>
        public static Vector CholeskySolve(Matrix a, Vector b)
        {
            Matrix lower = Cholesky(a);
            Vector z = ForwardSub(lower, b);
            return BackSub(lower.Transpose(), z);
        }

        /// <summary>
        /// Householder transformation matrix H = I - 2uu', where u = x/|x|.
        /// H reflects vectors across the hyperplane whose normal vector is x.
        /// </summary>
        /// <param name="x">Normal vector of the reflecting plane</param>
        /// <returns>Householder transformation matrix; the identity matrix if x has zero (or NaN) norm</returns>
        public static Matrix Householder(Vector x)
        {
            int n = x.Dim;
            double norm = Vector.Norm(x);
            Matrix q = Matrix.Identity(n);

            if (norm > 0.0)
            {
                // Unit normal as a column matrix, so that v * v' is the outer product.
                Matrix v = (Matrix)(x / norm);
                Matrix vt = v.Transpose();

                q = q - (v * vt) * 2;
            }
            return q;
        }

        /// <summary>
        /// Solves a system of linear equations with QR decomposition.
        /// 
        /// (Transpose of A is marked with A')
        /// From Ax=b decomposition gives QRx=b -> Rx=Q'b.
        /// R is upper triangular, so this is easily solved by back substitution.
        /// For an overdetermined system this gives the least-squares solution.
        /// </summary>
        /// <param name="a">Matrix A</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x</returns>
        public static Vector QRSolve(Matrix a, Vector b)
        {
            Matrix qt = QR(a);
            Matrix r = qt * a;
            Vector c = qt * b;
            return BackSub(r, c);
        }

        /// <summary>
        /// Checks whether matrix m is positive definite by trying to form its Cholesky decomposition.
        /// Symmetry is not verified: the decomposition reads only the diagonal and the lower triangle.
        /// </summary>
        /// <param name="m">Matrix m</param>
        /// <returns>True, if the matrix has a Cholesky decomposition (and therefore is positive definite);
        /// false if it is null, not square, not positive definite, or contains NaN values</returns>
        public static bool IsPositiveDefinite(Matrix m)
        {
            // Preserve the historical behaviour of answering false for a null matrix.
            if (ReferenceEquals(m, null)) return false;

            try
            {
                Cholesky(m);
            }
            catch (LinearException)
            {
                // Cholesky signals "not square" and "not positive definite" with LinearException;
                // anything else is an unexpected error and is not hidden here.
                return false;
            }
            return true;
        }

        /// <summary>
        /// Cholesky decomposition m = LL' of a non-singular, positive definite matrix.
        /// Returns the lower triangular factor L.
        /// </summary>
        /// <remarks>
        /// Only the diagonal and the lower triangle of m are read; symmetry is not checked.
        /// </remarks>
        /// <param name="m">Non-singular, symmetric positive definite matrix</param>
        /// <returns>Lower triangular matrix L of the Cholesky decomposition</returns>
        /// <exception cref="LinearException">If m is not square, not positive definite, or contains NaN values.</exception>
        public static Matrix Cholesky(Matrix m)
        {
            if (m.Rows != m.Columns) throw new LinearException("Matrix is not square!");

            int n = m.Rows;
            Matrix r = new Matrix(n);

            // Column j of L (diagonal element and everything below it) is computed in iteration j.
            for (int j = 0; j < n; j++)
            {
                // Sum of squares of the already computed elements r[j,0] ... r[j,j-1].
                double sum1 = 0;
                for (int k = 0; k < j; k++)
                {
                    double rjk = r[j, k];
                    sum1 += rjk * rjk;
                }

                // A zero or negative value under the square root means the matrix is not
                // positive definite. The negated comparison also rejects NaN (every comparison
                // with NaN is false), which would otherwise propagate silently into the result.
                double d = m[j, j] - sum1;
                if (!(d > 0.0)) throw new LinearException("Matrix is not positive definite!");

                r[j, j] = Math.Sqrt(d);

                // Elements of column j below the diagonal.
                for (int i = j + 1; i < n; i++)
                {
                    double sum2 = 0;
                    for (int k = 0; k < j; k++)
                    {
                        sum2 += r[i, k] * r[j, k];
                    }
                    r[i, j] = (m[i, j] - sum2) / r[j, j];
                }
            }

            // r is now the lower triangular factor L.
            return r;
        }

        /// <summary>
        /// QR decomposition of a matrix with Householder reflections.
        /// Returns the transpose of the orthogonal matrix Q.
        /// The upper triangular matrix R can be obtained by calculating R=Q'A (Q' is the transpose of Q).
        /// The signs of R's diagonal elements are chosen for numerical stability and may be negative.
        /// 
        /// Linear systems can now be easily solved: Ax=b -> QRx=b -> Rx=Q'b.
        /// </summary>
        /// <param name="a">Matrix to decompose</param>
        /// <returns>Matrix Q'</returns>
        public static Matrix QR(Matrix a)
        {
            int n = a.Rows;

            // Number of Householder reflections: one per column, but the last row has
            // nothing below the diagonal to eliminate.
            int t = Math.Min(a.Rows - 1, a.Columns);

            // Start with the NxN identity matrix (i.e. the identity transformation).
            Matrix qt = Matrix.Identity(n);

            // Reduce one column at a time.
            for (int k = 0; k < t; k++)
            {
                // Column k from element [k,k] down.
                Vector x = a.GetColumn(k);
                x = Vector.Range(x, k, x.Length - k);

                // Choose sign(alpha) = -sign(x[0]), so that u[0] = x[0] - alpha adds two numbers
                // of the same sign. The opposite choice cancels catastrophically when x is nearly
                // parallel to the first unit vector.
                double alpha = Vector.Norm(x);
                if (x[0] >= 0) alpha = -alpha;

                // Normal vector of the hyperplane reflecting x onto alpha * e1.
                Vector u = x - alpha * Vector.Unit(x.Dim, 0);

                // Householder matrix for the trailing block, expanded to NxN.
                Matrix qi = Matrix.Expand(Householder(u), k, k);

                // Accumulate the transformation: qt = Q(k) * ... * Q(0).
                qt = qi * qt;

                // Apply the transformation to A so that the elements below [k,k] become 0.
                // Only the local reference is reassigned; presumably operator* returns a new
                // matrix, so the caller's matrix is not modified.
                a = qi * a;
            }

            // Result is
            // Q(t-1) * Q(t-2) * ... * Q(0) = [ Q'(0) * ... * Q'(t-1) ]' = Q'
            return qt;
        }

        // Disabled: unfinished draft of a non-negative least squares (NNLS) solver.
        // It is commented out (not compiled), and its first statement throws NotImplementedException.
/*
        /// <summary>
        /// Non-Negative Linear Least Squares (NNLS) solver.
        /// </summary>
        /// <param name="a">Matrix a</param>
        /// <param name="b">Constant vector b</param>
        /// <returns>Solution vector x</returns>
        public static Vector NNLS(Matrix a, Vector b)
        {
            throw new NotImplementedException();

            // If A and b are of different dimensions, throw exception
            if (b.Dim != a.Rows)
                throw new LinearException("Dimension mismatch");
            // If system is underdetermined, throw an exception
            if (a.Rows < a.Columns)
                throw new LinearException("Underdetermined system, not solvable");

            int n = a.Columns;
            // x is the estimated solution
            Vector x = new Vector(n);
            Vector w;

            // P is the set of solved variables, Z of the unsolved
            List<int> p = new List<int>();
            List<int> z = new List<int>();

            // Add all indices of x to z
            for (int i = 0; i < x.Length; i++)
            {
                z.Add(i);
            }

            while (z.Count != 0)
            {
                // Find the gradient at the current estimation
                w = a.Transpose() * (b - a * x);

                // Check, if we have reached a solution:
                bool solution = true;
                double maxw = Double.MinValue;
                int maxindex = 0;

                // Go through indices in Z. If they are all < 0, x is solution.
                // Otherwise, find the maximal element in w.
                foreach (int c in z)
                {
                    // If at least one element of w, with index in z,
                    // has a positive value, solution is not yet found
                    if (w[c] > 0)
                    {
                        solution = false;
                        if (w[c] > maxw)
                        {
                            maxw = w[c];
                            maxindex = c;
                        }
                    }
                }

                // If we have a solution, stop
                if (solution) break;

                // No solution yet. Move the maximal element to P.
                p.Add(maxindex);
                z.Remove(maxindex);

                // Calculate matrix ap with zeros on all rows with index in z.
                Matrix ap = new Matrix(a);
                for (int i = 0; i < n; i++)
                {
                    if (z.Contains(i))
                    {
                        for (int j = 0; j < ap.Rows; j++)
                        {
                            ap[j, i] = 0;
                        }
                    }
                }
                // Solve the equations APy=b
                Vector y = QRSolve(ap, b);

                // We get solutions for all indices in p,
                // others are set to 0
                foreach (int c in z)
                {
                    y[c] = 0.0;
                }

                //
                // Inner loop.
                //
                while (Vector.Min(y) < 0.0)
                {
                    // Otherwise, calculate a new x, with at least one (and possible more)
                    // zero elements. These are moved to p.
                    double alpha = Double.MaxValue;
                    int alpha_index = -1;
                    for (int i = 0; i < y.Length; i++)
                    {
                        if (y[i] < 0.0)
                        {
                            double r = x[i] / (x[i] - y[i]);
                            if (r < alpha)
                            {
                                alpha = r;
                                alpha_index = i;
                            }
                        }
                    }

                    x = x + alpha * (y - x);

                    foreach (int c in p)
                    {
                        if (x[c] == 0.0)
                        {
                            z.Add(c);
                        }
                    }
                    foreach (int c in z)
                    {
                        p.Remove(c);
                    }
                    

                    // Calculate matrix ap with zeros on all rows with index in z.
                    ap = new Matrix(a);
                    for (int i = 0; i < n; i++)
                    {
                        if (z.Contains(i))
                        {
                            for (int j = 0; j < ap.Rows; j++)
                            {
                                ap[j, i] = 0;
                            }
                        }
                    }
                    // Solve the equations APy=b
                    y = QRSolve(ap, b);

                    // We get solutions for all indices in p,
                    // others are set to 0
                    foreach (int c in z)
                    {
                        y[c] = 0;
                    }

                    // New trial solution calculated, go through the inner loop again.
                } // End of inner loop
                x = y;
            } // End of outer loop

            // Finished, return solution
            return x;
        }
 */
    }
}
