﻿/******************************************************************************
 *
 * Copyright (c) 2008 Turku PET Centre
 *
 * This program is free software; you can redistribute it and/or modify it under
 * the terms of the GNU General Public License as published by the Free Software
 * Foundation; either version 2 of the License, or (at your option) any later
 * version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
 * Place, Suite 330, Boston, MA 02111-1307 USA.
 *
 * Turku PET Centre hereby disclaims all copyright interest in the program.
 * Juhani Knuuti
 * Director, Professor
 * 
 * Turku PET Centre, Turku, Finland, http://www.turkupetcentre.fi/
 * 
 ******************************************************************************/
using System;
using System.Runtime.InteropServices;

namespace TPClib.Model
{
    /// <summary>
    /// Levenberg-Marquardt fit of a model that depends nonlinearly on its set of M
    /// unknown parameters. A chi^2 merit function is defined and the best-fit
    /// parameters are determined by minimizing it. Because the dependence is
    /// nonlinear, the minimization must proceed iteratively: given trial values for
    /// the parameters, each step tries to improve the trial solution. The procedure
    /// is repeated until chi^2 stops (or effectively stops) decreasing.
    /// </summary>
    /// <remarks>
    /// Equation numbers such as (15.5.8) appear to follow Numerical Recipes,
    /// section 15.5 (mrqmin/mrqcof) — NOTE(review): confirm the edition.
    /// </remarks>
    [ClassInterface(ClassInterfaceType.AutoDual), ComSourceInterfacesAttribute(typeof(Ifile))]
    public class FitLMA : Optimization
    {
        /// <summary>
        /// Relative step for forward-difference derivative approximation:
        /// 2^-26 = sqrt(2^-52), the square root of double-precision machine epsilon.
        /// (double.Epsilon is the smallest denormal, not machine epsilon, so it is
        /// deliberately not used here.)
        /// </summary>
        private const double RelativeDerivativeStep = 1.4901161193847656e-8;

        /// <summary>
        /// A set of data x points
        /// </summary>
        private double[] x;

        /// <summary>
        /// A set of data y points
        /// </summary>
        private double[] y;

        /// <summary>
        /// Individual standard deviations (measurement errors) of the y points.
        /// When not known they are set to 1.
        /// </summary>
        private double[] sigma;

        /// <summary>
        /// The number of fitted parameters.
        /// </summary>
        private int numfit;

        /// <summary>
        /// Chi squared of the currently accepted parameters; a trial step is
        /// accepted only if it produces a smaller value.
        /// </summary>
        private double prevChiSq;

        /// <summary>
        /// Chi squared of the most recent evaluation. After a rejected step it is
        /// reset to <see cref="prevChiSq"/>, so it always reflects the best value found.
        /// </summary>
        private double chisq;

        /// <summary>
        /// The trial parameters: current parameters plus the increment delta.
        /// </summary>
        private Vector paramTry;

        /// <summary>
        /// The increment delta between the current parameters and the trial
        /// parameters. It is also reused as the output buffer for the trial beta.
        /// </summary>
        private Vector paramDelta;

        /// <summary>
        /// Vector beta (15.5.8a) for the accepted parameters: for each fitted
        /// parameter, the sum over data of residual * derivative / sigma^2.
        /// </summary>
        private Vector beta;

        /// <summary>
        /// The accepted parameters; returned from every iteration.
        /// </summary>
        private Vector currentParams;

        /// <summary>
        /// Right-hand side of the linear solve: a matrix with one column and one row
        /// per fitted parameter. It is filled from beta, and after the Gauss-Jordan
        /// solve its first column holds the solution.
        /// </summary>
        private Matrix oneDa;

        /// <summary>
        /// Working matrix: the damped curvature matrix handed to the Gauss-Jordan
        /// solver. On convergence it is passed to CovarianceSort as the covariance
        /// matrix. NOTE(review): presumably GaussJordan replaces it in place by its
        /// inverse; confirm.
        /// </summary>
        private Matrix covariance;

        /// <summary>
        /// Curvature matrix alpha (15.5.8b) of the accepted parameters, as
        /// computed by <see cref="Mrqcof"/>.
        /// </summary>
        private Matrix curvature;

        /// <summary>
        /// index is true if parameter is fitted
        /// </summary>
        private bool[] isFitted;

        /// <summary>
        /// Counts the entries of <paramref name="table"/> that are true.
        /// </summary>
        /// <param name="table">flags to count</param>
        /// <returns>the number of true entries</returns>
        private static int CountTrues(bool[] table)
        {
            int c = 0;
            foreach (bool b in table)
            {
                if (b) c++;
            }
            return c;
        }

        /// <summary>
        /// Function to fit the data; returns the model value and fills in the
        /// partial derivatives with respect to the parameters.
        /// </summary>
        private PartialDerivate fittedFunction;

        /// <summary>
        /// Marquardt damping factor. It is multiplied by 0.1 after an accepted step
        /// and by 10 after a rejected one. Once it drops below
        /// <see cref="lambdaLimit"/>, the algorithm is treated as converged.
        /// </summary>
        private double lambda;

        /// <summary>
        /// Convergence threshold for <see cref="lambda"/>.
        /// </summary>
        private double lambdaLimit = 1e-10;

        /// <summary>
        /// Current value of the damping factor lambda. The algorithm is treated as
        /// fully converged once this value is below the internal limit (1e-10).
        /// </summary>
        public double Convergence
        {
            get { return lambda; }
        }


        /// <summary>
        /// Not supported by FitLMA: the fitter is configured through its public
        /// constructors instead. Always throws.
        /// </summary>
        /// <param name="function">optimized function (unused)</param>
        /// <param name="info">info of method (unused)</param>
        /// <param name="c">parameter ceiling values (unused)</param>
        /// <param name="f">parameter floor values (unused)</param>
        /// <param name="initial">initial parameter values (unused)</param>
        /// <exception cref="NotImplementedException">always</exception>
        protected override void InitMethod(
            RealFunction function,
            OptimizationInfo info,
            Vector c,
            Vector f,
            Vector initial)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Constructs the LMA fitter and initializes the algorithm data structures.
        /// Use this constructor when analytical partial derivatives are known.
        /// All parameters are fitted, and every measurement error is set to 1.
        /// </summary>
        /// <param name="datax">Data x values</param>
        /// <param name="datay">Data y values</param>
        /// <param name="initial">initial target parameters</param>
        /// <param name="fittedFunc">Function used to fit the data, with analytical derivatives.</param>
        /// <exception cref="OptimizationException">if any argument is null or lengths differ</exception>
        public FitLMA(double[] datax, double[] datay, double[] initial, PartialDerivate fittedFunc)
        {
            if (datax == null || datay == null || initial == null || fittedFunc == null)
                throw new OptimizationException("No null parameters allowed.");

            Init(datax, datay, null, initial, null, fittedFunc);
        }

        /// <summary>
        /// Constructs the LMA fitter and initializes the algorithm data structures.
        /// Use this constructor when analytical partial derivatives are not known;
        /// they are approximated by forward differences.
        /// All parameters are fitted, and every measurement error is set to 1.
        /// </summary>
        /// <param name="datax">Data x values</param>
        /// <param name="datay">Data y values</param>
        /// <param name="initial">initial target parameters</param>
        /// <param name="yFit">Function used to fit the data.</param>
        /// <exception cref="OptimizationException">if any argument is null or lengths differ</exception>
        public FitLMA(double[] datax, double[] datay, double[] initial, ParameterFunction yFit)
        {
            if (datax == null || datay == null || initial == null || yFit == null)
                throw new OptimizationException("No null parameters allowed.");

            // Forward-difference approximation of the partial derivatives.
            // The step scales with the parameter magnitude. A fixed absolute step is
            // lost to rounding when |a[i]| is large (a[i] + h == a[i] gives a zero
            // derivative) and amplifies cancellation error when it is small.
            PartialDerivate p = delegate(double t, Vector a, Vector dyda)
            {
                // Model value at the unperturbed parameters. It is shared by all
                // difference quotients and returned as the function value.
                double f0 = yFit(t, a);

                for (int i = 0; i < dyda.Length; i++)
                {
                    Vector shifted = new Vector(a);
                    double ai = a[i];
                    shifted[i] = ai + RelativeDerivativeStep * Math.Max(Math.Abs(ai), 1.0);
                    // Divide by the step that is actually representable, so the
                    // rounding of ai + h does not bias the quotient.
                    double h = shifted[i] - ai;
                    dyda[i] = (yFit(t, shifted) - f0) / h;
                }

                return f0;
            };

            Init(datax, datay, null, initial, null, p);
        }

        /// <summary>
        /// Shared initialization for the constructors. It validates the data, fills
        /// in default deviations (1) and fitted flags (all true), allocates the
        /// working matrices and evaluates the initial curvature, beta and chi^2.
        /// </summary>
        /// <param name="datax">data x values</param>
        /// <param name="datay">data y values; same length as datax</param>
        /// <param name="dev">standard deviations, or null for all ones</param>
        /// <param name="initial">initial parameter values</param>
        /// <param name="fitted">per-parameter fitted flags, or null for all fitted</param>
        /// <param name="fittedFunc">model function with partial derivatives</param>
        /// <exception cref="OptimizationException">if table lengths are inconsistent</exception>
        private void Init(double[] datax, double[] datay, double[] dev,
            double[] initial, bool[] fitted, PartialDerivate fittedFunc)
        {
            if (dev == null)
            {
                dev = new double[datax.Length];
                for (int i = 0; i < dev.Length; i++)
                    dev[i] = 1;
            }
            if (fitted == null)
            {
                fitted = new bool[initial.Length];
                for (int i = 0; i < fitted.Length; i++)
                    fitted[i] = true;
            }

            if (datay.Length != datax.Length || dev.Length != datax.Length)
                throw new OptimizationException("Data x, y and deviation table must have same length.");
            if (initial.Length != fitted.Length)
                throw new OptimizationException("Initial parameter and fitted flag tables must have same length.");

            this.x = datax;
            this.y = datay;
            this.sigma = dev;
            this.isFitted = fitted;
            this.currentParams = initial;
            
            base.InitialParams = currentParams;
            // The merit function value is maintained by Step(), so the base class
            // just reads the latest chi^2; the argument is ignored.
            base.TargetFunction = delegate(Vector dump) {
                return this.chisq;
            };
            
            this.fittedFunction = fittedFunc;
            this.covariance = new Matrix(initial.Length); // NxN
            this.covariance.Fill(0);
            this.curvature = new Matrix(initial.Length); // NxN
            this.curvature.Fill(0.01);
            this.lambda = 0.001;
            this.beta = new Vector(initial.Length);
            this.paramDelta = new Vector(initial.Length);
            this.numfit = CountTrues(isFitted);
            // Creates a matrix with one column.
            this.oneDa = new Matrix(numfit, 1);
            // Evaluates the linearized fitting matrix (curvature) and vector beta,
            // and also calculates chi squared.
            this.chisq = Mrqcof(currentParams, ref curvature, ref beta);
            // Saves the originals.
            this.prevChiSq = chisq;
            this.paramTry = new Vector(initial);
        }

        /// <summary>
        /// Performs one Levenberg-Marquardt iteration. The damped linearized system
        /// (curvature with its diagonal scaled by 1 + lambda) is solved for an
        /// increment delta. If the trial parameters current + delta lower chi^2,
        /// they are accepted and lambda is decreased; otherwise lambda is increased.
        /// Once lambda is below its limit, the covariance matrix is sorted out
        /// instead and no trial is made.
        /// </summary>
        /// <returns>the best parameters found so far</returns>
        protected override Vector Step()
        {            
            // Alters the linearized fitting matrix by augmenting its diagonal elements.
            for (int j = 0; j < this.numfit; j++)
            {
                for (int k = 0; k < this.numfit; k++)
                {
                    covariance[j, k] = curvature[j, k];
                }
                covariance[j, j] = curvature[j, j] * (1.0 + lambda);

                // Copies beta to the Nx1 right-hand-side matrix.
                oneDa[j, 0] = beta[j];
            }

            // Matrix solution.
            LinearEquations.GaussJordan(covariance, numfit, oneDa, 1);
            // The first column is the solution.
            paramDelta = oneDa.GetColumn(0);

            // Once converged, evaluate the covariance matrix.
            if (lambda < lambdaLimit)
            {
                Matrix.CovarianceSort(covariance, covariance.Columns, isFitted);
                // Spread out alpha to its full size too.
                Matrix.CovarianceSort(curvature, covariance.Columns, isFitted);

                return currentParams;
            }

            // Build the trial parameters: a_try = a + delta for the fitted parameters.
            for (int j = 0, l = 0; l < covariance.Columns; l++)
            {
                if (isFitted[l])
                {
                    paramTry[l] = currentParams[l] + paramDelta[j];
                    j++;
                }
            }

            // Evaluate chi^2, curvature and beta at the trial parameters.
            chisq = Mrqcof(paramTry, ref covariance, ref paramDelta);

            // Success, accept the new solution.
            if (chisq < this.prevChiSq)
            {
                this.lambda *= 0.1;
                this.prevChiSq = chisq;
                curvature.Copy(covariance);
                beta = new Vector(paramDelta);
                currentParams = new Vector(paramTry);
            }
            // Failure, increase lambda and keep the previous solution.
            else
            {
                this.lambda *= 10.0;
                chisq = this.prevChiSq;
            }

            return currentParams;
        }

        /// <summary>
        /// Computes the linearized fitting matrix alpha (15.5.8b), the vector
        /// beta (15.5.8a) and chi^2 for the parameters <paramref name="a"/>.
        /// Only the top-left numfit x numfit block of alpha and the first numfit
        /// entries of beta are written. Entries are indexed by position among the
        /// fitted parameters.
        /// </summary>
        /// <remarks>
        /// <paramref name="a"/> is not changed.
        /// </remarks>
        /// <param name="a">parameters at which to evaluate</param>
        /// <param name="alpha">receives the symmetric curvature matrix</param>
        /// <param name="betaOut">receives the beta vector</param>
        /// <returns>the new chisq value</returns>
        private double Mrqcof(Vector a, ref Matrix alpha, ref Vector betaOut)
        {
            // Zero the lower triangle of alpha, including the diagonal, and beta.
            // The upper triangle is filled from the lower one at the end.
            for (int row = 0; row < this.numfit; row++)
            {
                for (int col = 0; col <= row; col++)
                {
                    alpha[row, col] = 0.0;
                }

                betaOut[row] = 0.0;
            }

            // the return value
            double result = 0.0;

            // Partial derivatives of the model at point x[i].
            Vector dyda = new Vector(a.Length);

            // Summation loop over all data.
            for (int i = 0; i < x.Length; i++)
            {
                // the model value
                double ymod = fittedFunction(x[i], a, dyda);

                // the weight: sigma^-2
                double sig2i = 1.0 / (sigma[i] * sigma[i]);

                // the residual
                double dy = y[i] - ymod;

                // Summation over all parameters. j is the index of parameter l
                // among the fitted parameters.
                for (int j = 0, l = 0; l < a.Length; l++)
                {
                    if (!isFitted[l])
                        continue;

                    // weighted first derivative
                    double wt = dyda[l] * sig2i;

                    // Lower triangle including the diagonal (m <= l); k is the
                    // index of parameter m among the fitted parameters.
                    for (int m = 0, k = 0; m <= l; m++)
                    {
                        if (isFitted[m])
                        {
                            // approximate second derivative: 15.5.8b
                            alpha[j, k] += wt * dyda[m];
                            k++;
                        }
                    }

                    // 15.5.8a: add to the beta sum
                    betaOut[j] += dy * wt;
                    j++;
                }

                // add to Chi^2
                result += dy * dy * sig2i;
            }

            // Fill in the symmetric side.
            //  x 0 0        x a b
            //  a y 0  ==>   a y c
            //  b c z        b c z
            for (int row = 1; row < this.numfit; row++)
            {
                for (int col = 0; col < row; col++)
                {
                    alpha[col, row] = alpha[row, col];
                }
            }

            return result;
        }
    }
}