/// @file mrp.c
/// @brief Median Root Prior image reconstruction from PET sinogram.
/// @details Based on the program mrprec (Jun 1997) written by Sakari Alenius
///  for Sun UNIX workstations. Reference:
///  Alenius S, Ruotsalainen U 'Bayesian image reconstruction for emission
///  tomography based on median root prior', EJNM, vol. 24 no. 3, Mar 1997.
/// @author Vesa Oikonen
///
/*****************************************************************************/
#include "libtpcrec.h"
#include <stdlib.h>
#ifdef HAVE_OMP_H
#include <omp.h>
#endif
/*****************************************************************************/

/*****************************************************************************/
/** Median Root Prior (MRP) reconstruction using data in IMG struct.

    Each sinogram plane and frame is reconstructed separately with mrp();
    planes are processed in parallel when compiled with OpenMP. The
    reconstructed values are divided by the frame duration in seconds
    (durations shorter than 1 s are treated as 1 s), and image header
    information is copied from the sinogram.

    If the sinogram has no sample distance (<=0), scanner geometry is assumed
    from the number of bins (281 bins: GE Advance, otherwise ECAT 931); note
    that these defaults are written into the sinogram struct, too.

    @note Parameters shiftX, shiftY and rotation are currently NOT applied,
    because mrp() has no arguments for them; a warning is printed to stderr
    if any of them is non-zero and verbose>0.
    @sa imgFBP, mrp
    @return Returns 0 if ok, 1 in case of invalid arguments, 3 if image
    allocation fails, and 8 if reconstruction of any matrix fails (including
    failure to allocate the work buffers).
 */
int imgMRP(
  /** Sinogram (input) data.
      Data must be normalization and attenuation corrected. */
  IMG *scn,
  /** Image (output) data; allocated here. */
  IMG *img,
  /** Image dimension (size, usually 128 or 256); must be an even number. */
  int imgDim,
  /** Zoom factor (for example 2.45 for the brain); 1=no zooming. */
  float zoom,
  /** Possible shifting in x dimension (mm); currently not applied. */
  float shiftX,
  /** Possible shifting in y dimension (mm); currently not applied. */
  float shiftY,
  /** Possible image rotation, -180 - +180 (in degrees); currently not
      applied. */
  float rotation,
  /** Nr of iterations, for example 150. */
  int maxIterNr,
  /** Number of iterations to skip before prior; usually 1. */
  int skipPriorNr,
  /** Beta, 0.01 - 0.9; usually 0.3 for emission, 0.9 for transmission. */
  float beta,
  /** Median filter mask dimension; 3 or 5 (9 or 21 pixels). */
  int maskDim,
  /** Number of Ordered Subset sets; 1, 2, 4, ... 128. */
  int osSetNr,
  /** Verbose level; if zero, then nothing is printed to stderr or stdout. */
  int verbose
) {
  if(verbose>1)
    printf("imgMRP(scn, img, %d, %g, %g, %g, %g, %d, %d, %g, %d, %d)\n",
           imgDim, zoom, shiftX, shiftY, rotation, maxIterNr, skipPriorNr, beta,
           maskDim, osSetNr);


  /* Check the arguments */
  if(scn==NULL || img==NULL) return(1);
  if(scn->status!=IMG_STATUS_OCCUPIED) return(1);
  if(imgDim<2 || imgDim>4096 || imgDim%2) return(1);
  if(zoom<0.01 || zoom>1000.) return(1);
  if(scn->dimx<=1 || scn->dimx>16384) return(1);
  if(scn->dimy<2) return(1);
  if(beta<0.0) return(1);
  if(maskDim!=3 && maskDim!=5) return(1);
  /* mrp() cannot shift or rotate the image; tell the user instead of
     silently ignoring the request */
  if(verbose>0 && (shiftX!=0.0 || shiftY!=0.0 || rotation!=0.0))
    fprintf(stderr,
      "Warning: shift and rotation are not applied in MRP reconstruction.\n");


  /*
   *  Allocate output image
   */
  if(verbose>1) printf("allocating memory for the image\n");
  imgEmpty(img); 
  if(imgAllocate(img, scn->dimz, imgDim, imgDim, scn->dimt)!=0) return(3);

  /* Set image "header" information */
  if(verbose>1) printf("setting image header\n");
  img->sizez=scn->sizez;
  /* Without sample distance, assume scanner geometry from the bin count.
     This must be done before the geometry is copied into the image header,
     otherwise the image would get the invalid values. */
  if(scn->sampleDistance<=0.0) {
    if(scn->dimx==281) { // GE Advance
      img->sizez=4.25;
      scn->sampleDistance=1.95730;
      scn->axialFOV=150.; scn->transaxialFOV=550.;
    } else { // ECAT 931
      img->sizez=6.75;
      scn->sampleDistance=3.12932;
      scn->axialFOV=108.; scn->transaxialFOV=600.829;
    }
  }
  img->type=IMG_TYPE_IMAGE;
  img->unit=CUNIT_CPS; /* (cnts/sec) */
  img->scanStart=scn->scanStart;
  img->axialFOV=scn->axialFOV;
  img->transaxialFOV=scn->transaxialFOV;
  strcpy(img->studyNr, scn->studyNr);
  img->sampleDistance=scn->sampleDistance;
  /* note that pixSize is in cm in the ECAT image */
  float pixSize=scn->sampleDistance*(float)scn->dimx/((float)imgDim*zoom);
  img->sizex=img->sizey=pixSize;
  int plane, frame;
  for(plane=0; plane<scn->dimz; plane++)
    img->planeNumber[plane]=scn->planeNumber[plane];
  strcpy(img->radiopharmaceutical, scn->radiopharmaceutical);
  img->isotopeHalflife=scn->isotopeHalflife;
  img->decayCorrection=IMG_DC_NONCORRECTED;
  for(frame=0; frame<scn->dimt; frame++) {
    img->start[frame]=scn->start[frame]; img->end[frame]=scn->end[frame];
    img->mid[frame]=0.5*(img->start[frame]+img->end[frame]);
  }
  img->isWeight=0;


  /*
   *  Reconstruct one matrix at a time
   */
  if(verbose>1) printf("reconstruct one matrix at a time...\n");
  size_t scnSize=(size_t)scn->dimx*(size_t)scn->dimy;
  size_t imgSize=(size_t)imgDim*(size_t)imgDim;
  /* Error code shared by all threads; accessed only inside the named
     critical section to avoid a data race */
  int failed=0;
#pragma omp parallel for private(frame)
  for(plane=0; plane<scn->dimz; plane++) {
    /* Work buffers are taken from the heap: they can be far too large for
       the stack of an OpenMP worker thread */
    float *scnData=malloc(scnSize*sizeof(float));
    float *imgData=malloc(imgSize*sizeof(float));
    if(scnData==NULL || imgData==NULL) {
      if(verbose>0) fprintf(stderr, "Error: out of memory.\n");
#pragma omp critical(imgMRP_failed)
      failed=3;
      free(scnData); free(imgData);
      continue;
    }

    for(frame=0; frame<scn->dimt; frame++) {
      /* Stop early if any thread has already failed */
      int stop;
#pragma omp critical(imgMRP_failed)
      stop=failed;
      if(stop) break;

      if(verbose>3) {
        printf("reconstructing plane %d frame %d\n", 
               scn->planeNumber[plane], frame+1);
        fflush(stdout);
      }

      /* Copy scan data into the array */
      size_t k=0;
      for(int i=0; i<scn->dimy; i++)
        for(int j=0; j<scn->dimx; j++)
          scnData[k++]=scn->m[plane][i][j][frame];
      /* Initiate image buffer */
      for(k=0; k<imgSize; k++) imgData[k]=0.0;

      /* Reconstruct */
      int ret=mrp(scnData, scn->dimx, scn->dimy, maxIterNr, osSetNr, maskDim,
                  zoom, beta, skipPriorNr, imgDim, imgData, verbose-3);
      if(ret!=0) {
        if(verbose>0) fprintf(stderr, "mrp() return value %d\n", ret);
#pragma omp critical(imgMRP_failed)
        failed=ret;
        break;
      }

      /* Copy the image array to matrix data; at the same time, correct for
         the frame length (at least 1 s) */
      float f=img->end[frame]-img->start[frame];
      if(f<1.0) f=1.0; 
      f=1.0/f;
      k=0;
      for(int i=0; i<img->dimy; i++)
        for(int j=0; j<img->dimx; j++)
          img->m[plane][i][j][frame]=imgData[k++]*f;
      if(verbose==1) {fprintf(stdout, "."); fflush(stdout);}
    } /* next frame */

    free(scnData); free(imgData);
  } /* next plane */
  if(verbose==1) {fprintf(stdout, "\n"); fflush(stdout);}
  if(failed) return(8);

  if(verbose>1) printf("imgMRP() done.\n");
  return(0);
}
/*****************************************************************************/

/*****************************************************************************/
/** Multiply an image pixel-by-pixel with a coefficient array.

    For each of the n pixels, computes coef*img and stores it in oimg;
    negative products are replaced with zero. The output array may be the
    very same array as coef or img, because each pixel is read before the
    same pixel is written.
*/
void mrpUpdate(
  /** Coefficient array, n elements. */
  float *coef, 
  /** Source image array, n elements. */
  float *img, 
  /** Output image array, n elements; may be the same array as coef or img. */
  float *oimg, 
  /** Number of elements in each array. */
  int n
) {
  int pix=0;
  while(pix<n) {
    float product=coef[pix]*img[pix];
    oimg[pix]=(product<0.0) ? 0.0 : product;
    pix++;
  }
}
/*****************************************************************************/

/*****************************************************************************/
/** Calculate correction factors for EM-ML reconstruction.
    These factors will be back-projected over the image in order to get 
    the ML-coefficients.

    For the rays*(views/os_sets) values of one ordered subset, divides the
    measured projection (sinogram) by the re-projected ray sum and multiplies
    by the number of OS sets. If the divisor is close to zero (|proj|<=1E-40),
    the factor is set to 0. The factor is limited to the range [0,100].
    Nothing is done if os_sets<1.

    @author Sakari Alenius
*/
void mrpProjectionCorrection(
  /** Measured sinogram data of one subset. */
  float *measured, 
  /** Re-projected data of one subset. */
  float *proj, 
  /** Correction matrix (output), at least rays*(views/os_sets) values. */
  float *correct, 
  /** Number of OS sets. */
  int os_sets,
  /** Sinogram rays. */
  int rays,
  /** Sinogram views (all subsets together). */
  int views
) {
  if(os_sets<1) return; /* would divide by zero */
  /* Only whole views belong to a subset; rays*views/os_sets would overrun
     the subset buffers when views is not divisible by os_sets */
  int n=rays*(views/os_sets);
  for(int i=0; i<n; i++) {
    if(fabs(proj[i])>1.0E-40) {
      correct[i]=(float)os_sets*measured[i]/proj[i];
      if(correct[i]>100.) correct[i]=100.;
      else if(correct[i]<0.0) correct[i]=0.0;
    } else {
      correct[i]=0.0;
    }
  }
}
/*****************************************************************************/

/*****************************************************************************/
/** Median Root Prior (MRP) reconstruction of one 2D data matrix given
    as an array of floats. 

    The sinogram views are first rearranged into ordered subsets, and the
    rays are interpolated so that bin width equals pixel width. Starting from
    a uniform disk matching the total counts, iter OS-EM iterations are run;
    the median root prior is applied (when beta>0) in the iterations after
    the first skip_prior ones.

    Large work arrays are allocated from the heap, because this function is
    called from OpenMP worker threads (see imgMRP) whose stacks may be small.
    The output image array is used as the work image during iterations.

    @sa trmrp, reprojection
    @return Returns 0 if ok, 1 or 2 in case of invalid arguments (including
    views not divisible by os_sets), and 3 if memory allocation fails.
 */
int mrp(
  /** Pointer to float array containing rays*views sinogram values.
      Data must be normalization and attenuation corrected. */
  float *sinogram,
  /** Nr of rays (bins or columns) in sinogram data. */
  int rays,
  /** Nr of views (rows) in sinogram data; must be divisible by os_sets. */
  int views,
  /** Nr of iterations; at least 1. */
  int iter,
  /** Length of ordered subset process order array; 1, 2, 4, ... 128. */
  int os_sets,
  /** Mask dimension; 3 or 5 (9 or 21 pixels). */
  int maskdim,
  /** Reconstruction zoom. */
  float zoom,
  /** Beta. */
  float beta,
  /** Number of iteration before prior; usually 1. */
  int skip_prior,
  /** Image x and y dimensions; must be an even number, preferably the same as
      number of rays, but there is no reason for it to be any larger than that.
   */
  int dim,
  /** Pointer to pre-allocated image data; size must be at least dim*dim. */
  float *image,
  /** Verbose level; if zero, then nothing is printed to stderr or stdout. */
  int verbose
) {
  if(verbose>0)
    printf("mrp(%d, %d, %d, %d, %d, %g, %g, %d, %d)\n",
           rays, views, iter, os_sets, maskdim, zoom, beta, skip_prior, dim);
  if(sinogram==NULL || image==NULL) return(1);
  if(rays<2 || views<2 || iter<1 || os_sets<1 || dim<2) return(1);
  if(dim%2 || zoom<0.05) return(2);
  if(maskdim!=3 && maskdim!=5) return(2);
  /* Every ordered subset must contain the same number of whole views;
     otherwise the subset rearrangement leaves data uninitialized */
  if(views%os_sets) return(2);

  int i, k;
  int imgsize=dim*dim;
  int halfdim=dim/2;
  int views_in_set=views/os_sets;

  /* Set scale */
  int recrays=(int)((float)dim*zoom);
  if(recrays<2) return(2); /* would divide by zero below */
  float bp_zoom=zoom*(float)dim/(float)recrays;
  float scale=((float)recrays/(float)rays)*bp_zoom*bp_zoom/(float)views;
  if(verbose>1) {
    printf("  recrays := %d\n", recrays);
    printf("  bp_zoom := %g\n", bp_zoom);
    printf("  scale := %g\n", scale);
  }

  /* Make the Ordered Subset process order (bit-reversed sequence) */
  int seq[os_sets]; set_os_set(os_sets, seq);
  if(verbose>2) {
    printf("os_sets :=");
    for(i=0; i<os_sets; i++) printf(" %d", seq[i]);
    printf("\n"); fflush(stdout);
  }

  /* Allocate the work arrays from the heap */
  int scnsize=recrays*views;
  float *scnset=malloc((size_t)rays*(size_t)views*sizeof(float));
  float *sino=malloc((size_t)scnsize*sizeof(float));
  float *current_proj=malloc(((size_t)recrays*views_in_set+1)*sizeof(float));
  float *correction=malloc((size_t)recrays*views_in_set*sizeof(float));
  float *mlcoefs=malloc((size_t)imgsize*sizeof(float));
  float *medcoefs=malloc((size_t)imgsize*sizeof(float));
  if(scnset==NULL || sino==NULL || current_proj==NULL || correction==NULL ||
     mlcoefs==NULL || medcoefs==NULL)
  {
    if(verbose>0) fprintf(stderr, "Error: out of memory.\n");
    free(scnset); free(sino); free(current_proj); free(correction);
    free(mlcoefs); free(medcoefs);
    return(3);
  }

  /* Arrange the sinogram into ordered subsets, and interpolate so that
     bin width equals pixel width */
  for(int s=0; s<os_sets; s++) {
    for(int j=0; j<views_in_set; j++) {
      memcpy(scnset + ((size_t)s*views_in_set + j)*rays,
             sinogram + ((size_t)j*os_sets + s)*rays, rays*sizeof(float));
    }
  }
  recInterpolateSinogram(scnset, sino, rays, recrays, views);
  free(scnset); scnset=NULL;

  /* Get some statistics from the sinogram matrix */
  int nonzeroNr;
  float counts;
  nonzeroNr=recGetStatistics(sino, scnsize, &counts, NULL, NULL, 1);
  if(verbose>1) {
    printf("  total_counts := %g\n", counts);
    printf("  non-zeroes := %d/%d\n", nonzeroNr, scnsize);
  }


  /* Make an initial image: an uniform disk enclosed by rays with a value
     matching the total count. The output array serves as the work image. */
  if(verbose>1) printf("creating initial %dx%d image\n", dim, dim);
  float *current_img=image;
  for(i=0; i<imgsize; i++) current_img[i]=0.0;
  float init=counts/(M_PI*(float)((halfdim-1)*(halfdim-1)));
  k=0;
  for(int j=halfdim-1; j>=-halfdim; j--) {
    for(i=-halfdim; i<halfdim; i++) {
      if((int)hypot((double)i, (double)j) < halfdim-1)
        current_img[k]=init;
      k++;
    }
  }


  /* Pre-compute the sine tables for back-projection */
  if(verbose>1) printf("computing sine table\n");
  float sinB[3*views/2];
  recSinTables(views, sinB, NULL, 0.0);
  for(i=0; i<3*views/2; i++) sinB[i]/=bp_zoom;

  /* Iterations; exactly iter of them */
  if(verbose>1) {printf("iterations\n"); fflush(stdout);}
  for(int itercount=1; itercount<=iter; itercount++, skip_prior--) {
    if(verbose>3) {printf("  iteration %d\n", itercount); fflush(stdout);}

    for(int s=0; s<os_sets; s++) {
      if(verbose>4) {
        printf("    os_set %d; seq=%d\n", 1+s, seq[s]); 
        fflush(stdout);
      }
      /* Image reprojection */
      for(i=0; i<recrays*views_in_set; i++) current_proj[i]=0.0;
      for(i=0; i<views_in_set; i++) {
        int view=seq[s]+i*os_sets;
        viewReprojection(current_img, current_proj+i*recrays, view, dim, 
          views, recrays, sinB, sinB, 0.0, 0.0, bp_zoom);
      }
      /* Calculate correction = measured / re-projected */
      mrpProjectionCorrection(sino+seq[s]*recrays*views_in_set, current_proj,
                              correction, os_sets, recrays, views);
      /* Make the ML coefficients */
      for(i=0; i<imgsize; i++) mlcoefs[i]=0.0;
      for(i=0; i<views_in_set; i++)
        viewBackprojection(correction+i*recrays, mlcoefs, dim, 
              seq[s]+i*os_sets, views, recrays, sinB, sinB, 0.0, 0.0, bp_zoom);
      /* Make the prior coefficients */
      if(skip_prior<=0 && beta>0.0) {
        if(verbose>3) {printf("    applying prior\n"); fflush(stdout);}
        float maxv, maxm;
        fMinMaxFin(current_img, imgsize, NULL, &maxv);
        do_prior(current_img, beta, medcoefs, dim, 
                   1.0E-08*maxv, maskdim, &maxm);
        if(verbose>3) {
          printf("      max value in current image := %g\n", maxv);
          printf("      max median coefficient := %g\n", maxm);
        }
        /* Adjust ML coefficients */
        mrpUpdate(medcoefs, mlcoefs, mlcoefs, imgsize);
      }

      /* Calculate next image */
      mrpUpdate(mlcoefs, current_img, current_img, imgsize);

    } // next set
    if(verbose>3) {printf("  -> next iteration maybe\n"); fflush(stdout);}
  }
  if(verbose>2) {printf("  iterations done.\n"); fflush(stdout);}

  /* current_img is the output array itself; just apply the scale */
  for(i=0; i<imgsize; i++) image[i]*=scale;

  free(sino); free(current_proj); free(correction);
  free(mlcoefs); free(medcoefs);

  if(verbose>1) {printf("mrp() done.\n"); fflush(stdout);}
  return(0);
} // mrp()
/*****************************************************************************/

/*****************************************************************************/
