/// @file trmrp.c
/// @brief Transmission image reconstruction using MRP.
/// @remark Based on the program fbprec (Feb 1998) written by Sakari Alenius
///  for Sun UNIX workstations.
/// @author Vesa Oikonen
///
/*****************************************************************************/
#include "libtpcrec.h"
/*****************************************************************************/
#ifdef HAVE_OMP_H
#include <omp.h>
#endif
/*****************************************************************************/

/*****************************************************************************/
/** Median Root Prior (MRP) reconstruction of one 2D data matrix given as an array of floats.

    Reconstructs an attenuation image from a blank and a transmission sinogram.
    The views are divided into os_sets ordered subsets (subset t contains views
    t, t+os_sets, t+2*os_sets, ...), which are processed in bit-reversed order.
    For each subset the current image is reprojected, the estimated transmission
    counts (blank*exp(-projection)) are compared to the measured counts, and the
    image is updated multiplicatively. After the first skip_prior iterations the
    update is regularized with the median root prior (if beta>0), either in OSL
    form or by modifying the numerator and denominator images directly.

    If views is not divisible by os_sets, the last views that do not fit into
    any subset are not used.

    @sa fbp, reprojection
    @return Returns 0 if ok, 1 if arguments are invalid, and 2 if the blank or
            transmission data do not contain positive counts.
 */
int trmrp(
  /** Float array containing rays*views blank sinogram values.
      Data must be normalization-corrected. */
  float *bla,
  /** Float array containing rays*views transmission sinogram values.
      Data must be normalization-corrected. */
  float *tra,
  /** Image x and y dimensions. */
  int dim,
  /** Pointer to pre-allocated image data; size must be at least dim*dim;
      log-transformed attenuation correction factors will be written in here. */
  float *image,
  /** Nr of iterations; at least 1. */
  int iter,
  /** Nr of ordered subsets, i.e. length of ordered subset process order array;
      1, 2, 4, ... 128; must not exceed views. */
  int os_sets,
  /** Nr of rays (bins or columns) in sinogram data. */
  int rays,
  /** Nr of views (rows) in sinogram data. */
  int views,
  /** Mask dimension; 3 or 5 (9 or 21 pixels). */
  int maskdim,
  /** Reconstruction zoom. */
  float zoom,
  /** Beta; prior is applied only if beta>0. */
  float beta,
  /** Axial field-of-view in mm (found in transmission mainheader in cm); must be > 0. */
  float axial_fov,
  /** Sample distance in mm (found in transmission subheader in cm); must be > 0. */
  float sample_distance,
  /** Number of iteration before prior; usually 1. */
  int skip_prior,
  /** Use OSL-type (0 or 1). */
  int osl,
  /** Possible shifting in x dimension (mm). */
  float shiftX,
  /** Possible shifting in y dimension (mm). */
  float shiftY,
  /** Possible image rotation, -180 - +180 (in degrees). */
  float rotation,
  /** Verbose level; if zero, then nothing is printed to stderr or stdout. */
  int verbose
) {
  if(verbose>0) printf("trmrp()\n");

  if(bla==NULL || tra==NULL || image==NULL) return(1);
  if(rays<2 || views<2 || dim<2 || iter<1 || os_sets<1) return(1);
  /* Every subset must contain at least one view (array sizes below) */
  if(os_sets>views) return(1);
  if(maskdim!=3 && maskdim!=5) return(1);
  if(zoom<0.05) return(1);
  /* Both are used as divisors below; the negated tests also reject NaN */
  if(!(sample_distance>0.0) || !(axial_fov>0.0)) return(1);


  /* Set scale */
  int recrays=(int)((float)dim*zoom);
  /* recrays is a divisor and an array dimension below */
  if(recrays<1) return(1);
  float bp_zoom=zoom*(float)dim/(float)recrays;
  float scale=((float)recrays/(float)rays)*bp_zoom*bp_zoom/(float)views;
  if(verbose>1) {
    printf("  recrays := %d\n", recrays);
    printf("  bp_zoom := %g\n", bp_zoom);
    printf("  scale := %g\n", scale);
  }

  int views_in_set=views/os_sets;
  if(verbose>1) printf("  views_in_set := %d\n", views_in_set);
  /* Nr of views that belong to some subset */
  const int used_views=views_in_set*os_sets;
  /* Nr of values in one subset of the interpolated sinograms */
  const int setsize=recrays*views_in_set;

  /* Make the Ordered Subset process order (bit-reversed sequence) */
  int seq[os_sets]; set_os_set(os_sets, seq);
  if(verbose>2) {
    printf("os_sets :=");
    for(int i=0; i<os_sets; i++) printf(" %d", seq[i]);
    printf("\n");
  }

  /* Arrange transmission and blank sinograms, and interpolate so that
     bin width equals pixel width */
  float recbla[recrays*views], rectra[recrays*views];
  {
    float blaset[rays*views], traset[rays*views];
    /* arrange: row s*views_in_set+j is original view j*os_sets+s, so that the
       views of subset s are in consecutive rows */
    for(int s=0; s<os_sets; s++) {
      for(int j=0; j<views_in_set; j++) {
        memcpy(blaset+(s*views_in_set+j)*rays, bla+(j*os_sets+s)*rays, rays*sizeof(float));
        memcpy(traset+(s*views_in_set+j)*rays, tra+(j*os_sets+s)*rays, rays*sizeof(float));
      }
    }
    /* Views that do not belong to any subset are not used; zero them so that
       the interpolation does not read indeterminate values */
    if(used_views<views) {
      size_t n=(size_t)(views-used_views)*(size_t)rays;
      memset(blaset+used_views*rays, 0, n*sizeof(float));
      memset(traset+used_views*rays, 0, n*sizeof(float));
    }
    /* interpolate */
    recInterpolateSinogram(blaset, recbla, rays, recrays, views);
    recInterpolateSinogram(traset, rectra, rays, recrays, views);
  }


  /* Sum of blank and transmission over all used views; accumulated in double
     to limit rounding error with large sinograms */
  double bsum=0.0, tsum=0.0;
  for(int i=0; i<used_views*recrays; i++) {bsum+=recbla[i]; tsum+=rectra[i];}
  if(verbose>2) {
    printf("  blank_sum := %g\n", bsum);
    printf("  transmission_sum := %g\n", tsum);
  }
  /* The ratio is log-transformed below, thus both sums must be positive */
  if(!(bsum>0.0) || !(tsum>0.0)) return(2);

  /* Make an initial image: an uniform disk enclosed by rays with a value
     matching the total count */
  if(verbose>1) printf("creating initial %dx%d image\n", dim, dim);
  int imgsize=dim*dim;
  for(int i=0; i<imgsize; i++) image[i]=0.0;
  float init=-logf((float)(tsum/bsum))*sample_distance/axial_fov;
  if(verbose>2) printf("  init := %g\n", init);

  /* Row r and column c correspond to coordinates y=(dim-1)/2-r and x=c-dim/2;
     every row has dim pixels also when dim is odd */
  float current_img[imgsize];
  for(int r=0; r<dim; r++) {
    int y=(dim-1)/2-r;
    for(int c=0; c<dim; c++) {
      int x=c-dim/2;
      current_img[r*dim+c] = ((int)hypot((double)x, (double)y) < dim/2-1) ? init : 0.0f;
    }
  }

  /* Pre-compute the sine tables for back-projection */
  if(verbose>1) printf("computing sine table\n");
  float sinB[3*views/2], sinBrot[3*views/2];
  recSinTables(views, sinB, sinBrot, rotation);
  for(int i=0; i<3*views/2; i++) {sinB[i]/=bp_zoom; sinBrot[i]/=bp_zoom;}

  /* Calculate pixel size */
  float pixsize;
  pixsize=sample_distance*(float)rays/(zoom*(float)dim);
  if(verbose>2) printf("  pixsize := %g\n", pixsize);
  /* ... and convert shifts from mm to pixels */
  shiftX/=pixsize;
  shiftY/=pixsize;


  /* Iterations */
  if(verbose>1) {printf("iterations\n"); fflush(stdout);}
  float muproj[setsize+1];
  float atnproj[setsize];
  float projdiff[setsize];
  float numerator[imgsize], denominator[imgsize];
  float med_img[imgsize], oslcoefs[imgsize];
  /* itercount runs from 1 to iter, so that exactly iter iterations are done */
  for(int itercount=1; itercount<=iter; itercount++) {
    if(verbose>3) {printf("  iteration %d\n", itercount); fflush(stdout);}

    if(verbose>1) {
      float mi, ma;
      fMinMaxFin(current_img, imgsize, &mi, &ma);
      printf("    min=%g max=%g iter=%i\n", mi, ma, itercount);
      for(int i=0; i<imgsize; i++)
        if(!isfinite(current_img[i])) {
          printf("  inf in current image! index=%d, iter=%d\n", i, itercount);
          break;
        }
    }

    for(int s=0; s<os_sets; s++) {
      if(verbose>4) {printf("    os_set %d; seq=%d\n", 1+s, seq[s]); fflush(stdout);}
      /* Measured data of the views seq[s]+i*os_sets, which were arranged into
         consecutive rows starting from row seq[s]*views_in_set */
      const float *setbla=recbla+seq[s]*setsize;
      const float *settra=rectra+seq[s]*setsize;

      /* Image reprojection; each view writes only its own row of muproj */
      for(int i=0; i<setsize; i++) muproj[i]=0.0;
#pragma omp parallel for
      for(int i=0; i<views_in_set; i++) {
        int view=seq[s]+i*os_sets;
        if(verbose>8 && (i==0 || i==views_in_set-1)) printf("      reprojecting view %d\n", view);
        viewReprojection(current_img, muproj+i*recrays, view, dim,
          views, recrays, sinB, sinBrot, shiftX, shiftY, bp_zoom);
      }

      /* Estimated transmission counts, and their difference to measured counts */
#pragma omp parallel for
      for(int i=0; i<setsize; i++) {
        float mup=muproj[i]*scale;
        float est_count=expf(-mup)*setbla[i];
        atnproj[i]=est_count*mup;
        projdiff[i]=est_count-settra[i];
      }

      /* Make numerator and denominator images.
         The images are zeroed here and then receive the backprojections of all
         views of the subset, i.e. viewBackprojection() accumulates into the
         given image; processing views in parallel into the same image would be
         a data race. Instead, the two images are built in separate sections. */
      for(int i=0; i<imgsize; i++) {numerator[i]=0.0; denominator[i]=0.0;}
#pragma omp parallel sections
      {
#pragma omp section
        {
          for(int i=0; i<views_in_set; i++)
            viewBackprojection(projdiff+i*recrays, numerator, dim,
                  seq[s]+i*os_sets, views, recrays, sinB, sinBrot, shiftX, shiftY, bp_zoom);
        }
#pragma omp section
        {
          for(int i=0; i<views_in_set; i++)
            viewBackprojection(atnproj+i*recrays, denominator, dim,
                  seq[s]+i*os_sets, views, recrays, sinB, sinBrot, shiftX, shiftY, bp_zoom);
        }
      }
      if(verbose>4) {
        float mi, ma;
        fMinMaxFin(numerator, imgsize, &mi, &ma);
        printf("    numerator_range := %g - %g\n", mi, ma);
        fMinMaxFin(denominator, imgsize, &mi, &ma);
        printf("    denominator_range := %g - %g\n", mi, ma);
      }

      /* Apply the prior */
      if(skip_prior<=0 && beta>0.0) {
        if(verbose>3) {printf("    applying prior\n"); fflush(stdout);}
        if(osl) {
          float maxv, maxm;
          fMinMaxFin(current_img, imgsize, NULL, &maxv);
          do_prior(current_img, beta, oslcoefs, dim, 1.0E-08*maxv, maskdim, &maxm);
          if(verbose>6) {
            printf("      max value in current image := %g\n", maxv);
            printf("      max median coefficient := %g\n", maxm);
          }
        } else {
          /* The median mask does not cover the first and last pixels; there the
             pixel value itself is used as median, so that the prior term
             beta*(current-median) is zero */
          memcpy(med_img, current_img, (size_t)imgsize*sizeof(float));
          float *iptr, *mptr;
          if(maskdim==3) {
            iptr=current_img+dim+1;
            mptr=med_img+dim+1;
            for(int i=2*(dim+1); i<imgsize; i++) *mptr++=med9(iptr++, dim);
          } else {
            iptr=current_img+2*dim+2;
            mptr=med_img+2*dim+2;
            for(int i=2*(2*dim+2); i<imgsize; i++) *mptr++=med21(iptr++, dim);
          }
          for(int i=0; i<imgsize; i++) {
            numerator[i]*=med_img[i];
            numerator[i]-=beta*(current_img[i]-med_img[i]);
            denominator[i]*=med_img[i];
            denominator[i]+=beta*current_img[i];
          }
        }
      }

      /* Calculate the next image */
      if(verbose>3) {printf("    calculating next image\n"); fflush(stdout);}
      int use_osl=(osl && skip_prior<=0 && beta>0.0);
      if(verbose>4) {
        if(use_osl) printf("      convex-OSL\n"); else printf("      convex\n");
      }
      for(int i=0; i<imgsize; i++) {
        float f=fmaf(numerator[i], 1.0/denominator[i], 1.0);
        if(use_osl) f*=oslcoefs[i];
        f*=current_img[i];
        /* Keep the previous value where the update is not finite,
           e.g. because of zero denominator */
        if(isfinite(f)) current_img[i]=f;
      }
      if(verbose>4) {printf("    -> next os_set maybe\n"); fflush(stdout);}
    } // next set
    skip_prior--;
    if(verbose>3) {printf("  -> next iteration maybe\n"); fflush(stdout);}
  }
  if(verbose>2) {printf("  iterations done.\n"); fflush(stdout);}

  for(int i=0; i<imgsize; i++) image[i]=current_img[i]*scale;

  if(verbose>1) {printf("trmrp() done.\n"); fflush(stdout);}
  return(0);
}
/*****************************************************************************/

/*****************************************************************************/
