#include "ninox.h"

typedef struct {
   double real;
   double imag;
   } COMPLEX;

static int IdealLowPass(COMPLEX **c, int c_width, int c_height, double radius);
static int IdealHighPass(COMPLEX **c, int c_width, int c_height, double radius);
static int SwapDiagonal(COMPLEX **c, int c_width, int c_height);
static int LowPass1(COMPLEX **c, int c_width, int c_height, double radius);
static int LowPass2(COMPLEX **c, int c_width, int c_height, double radius);

// given v, return m=log2(v) and pwopm = 2^m
static int
Powerof2(int v, int *m, int *twopm)
	{
	if (v<=0) return 0;

	*m = (int) ceil(log2((double)v));
	*twopm = (int) pow(2.0, (double)*m);

	return 1;
	}

/*-------------------------------------------------------------------------
   Perform a 2D FFT inplace given a complex 2D array
   The direction dir, 1 for forward, -1 for reverse
   The size of the array (nx,ny)
   Return false if there are memory problems or
      the dimensions are not powers of 2
*/
static int
FFT2D(COMPLEX **c,int nx,int ny,int dir)
{
   int i,j;
   int m,twopm;
   double *real,*imag;

   /* Transform the rows */
   real = (double *)malloc(nx * sizeof(double));
   imag = (double *)malloc(nx * sizeof(double));
   if (real == NULL || imag == NULL)
      return(FALSE);
   if (!Powerof2(nx,&m,&twopm) || twopm != nx)
      return(FALSE);
   for (j=0;j<ny;j++) {
      for (i=0;i<nx;i++) {
         real[i] = c[i][j].real;
         imag[i] = c[i][j].imag;
      }
      FFT(dir,m,real,imag);
      for (i=0;i<nx;i++) {
         c[i][j].real = real[i];
         c[i][j].imag = imag[i];
      }
   }
   free(real);
   free(imag);

   /* Transform the columns */
   real = (double *)malloc(ny * sizeof(double));
   imag = (double *)malloc(ny * sizeof(double));
   if (real == NULL || imag == NULL)
      return(FALSE);
   if (!Powerof2(ny,&m,&twopm) || twopm != ny)
      return(FALSE);
   for (i=0;i<nx;i++) {
      for (j=0;j<ny;j++) {
         real[j] = c[i][j].real;
         imag[j] = c[i][j].imag;
      }
      FFT(dir,m,real,imag);
      for (j=0;j<ny;j++) {
         c[i][j].real = real[j];
         c[i][j].imag = imag[j];
      }
   }
   free(real);
   free(imag);

   return(TRUE);
}

/*-------------------------------------------------------------------------
   This computes an in-place complex-to-complex FFT
   x and y are the real and imaginary arrays of 2^m points.
   dir =  1 gives forward transform
   dir = -1 gives reverse transform

     Formula: forward
                  N-1
                  ---
              1   \          - j k 2 pi n / N
      X(n) = ---   >   x(k) e                    = forward transform
              N   /                                n=0..N-1
                  ---
                  k=0

      Formula: reverse
                  N-1
                  ---
                  \          j k 2 pi n / N
      X(n) =       >   x(k) e                    = forward transform
                  /                                n=0..N-1
                  ---
                  k=0
*/
int FFT(int dir,int m,double *x,double *y)
{
   long nn,i,i1,j,k,i2,l,l1,l2;
   double c1,c2,tx,ty,t1,t2,u1,u2,z;

   /* Calculate the number of points */
   nn = 1;
   for (i=0;i<m;i++)
      nn *= 2;

   /* Do the bit reversal */
   i2 = nn >> 1;
   j = 0;
   for (i=0;i<nn-1;i++) {
      if (i < j) {
         tx = x[i];
         ty = y[i];
         x[i] = x[j];
         y[i] = y[j];
         x[j] = tx;
         y[j] = ty;
      }
      k = i2;
      while (k <= j) {
         j -= k;
         k >>= 1;
      }
      j += k;
   }

   /* Compute the FFT */
   c1 = -1.0;
   c2 = 0.0;
   l2 = 1;
   for (l=0;l<m;l++) {
      l1 = l2;
      l2 <<= 1;
      u1 = 1.0;
      u2 = 0.0;
      for (j=0;j<l1;j++) {
         for (i=j;i<nn;i+=l2) {
            i1 = i + l1;
            t1 = u1 * x[i1] - u2 * y[i1];
            t2 = u1 * y[i1] + u2 * x[i1];
            x[i1] = x[i] - t1;
            y[i1] = y[i] - t2;
            x[i] += t1;
            y[i] += t2;
         }
         z =  u1 * c1 - u2 * c2;
         u2 = u1 * c2 + u2 * c1;
         u1 = z;
      }
      c2 = sqrt((1.0 - c1) / 2.0);
      if (dir == 1)
         c2 = -c2;
      c1 = sqrt((1.0 + c1) / 2.0);
   }

   /* Scaling for forward transform */
   if (dir == 1) {
      for (i=0;i<nn;i++) {
         x[i] /= (double)nn;
         y[i] /= (double)nn;
      }
   }

   return(TRUE);
}

struct Image *
do_fft(struct Image *img)
	{
	int new_width,new_height;
	int x,y,o,o1,padx,pady,mx,my;
	double max;
	struct Image *new;
	unsigned short *udata;
	COMPLEX **c;

	// the output image must be a power of 2 in both dimensions
	new_width = (int) pow(2.0, ceil(log2((double)img->width)));
	new_height = (int) pow(2.0, ceil(log2((double)img->height)));

	c = ZeroMalloc(sizeof(COMPLEX *) * new_width);
	for(x=0; x<new_width; ++x) 
	   c[x] = ZeroMalloc(sizeof(COMPLEX) * new_height);

	Print("FFT image (%d x %d x %d)\n",new_width,new_height,img->depth);

	// Copy the original image into the centre of this space
	padx = (new_width - img->width) / 2;
	pady = (new_height - img->height) / 2;

	if (img->depth > 16) {
	   Print("FFT not supported on images other than 8/16 bpp monochrome\n");
	   exit(0);
	   }

	udata = (unsigned short *)img->data;
	for(y=0; y<img->height; ++y) {
	   int o = y * img->width;
	   for(x=0; x<img->width; ++x) {
		if (img->depth==8) c[padx+x][pady+y].real = ((double)img->data[o++] * 256);
		else c[padx+x][pady+y].real = (double)udata[o++];
	        }
	   }

	// Perform the FFT
	if (!  FFT2D(c,new_width,new_height,1)) {
	   Print("FFT2D failed!\n");
	   exit(0);
	   }

	/* re-arrange to put the low frequency in the centre */
	SwapDiagonal(c,new_width,new_height);
	
	// Copy the data
	udata = (unsigned short *)ZeroMalloc(new_width * new_height * 2);
	for(y=0; y<new_height; ++y) {
	   int o = y * new_width;
	   for(x=0; x<new_width; ++x,++o) {
		int v = sqrt(c[x][y].real*c[x][y].real + c[x][y].imag*c[x][y].imag);
		if (v>65535) v=65535;
		udata[o] = (unsigned short)v;
		}
	   }
	new = CreateFitsImage("fft_mag.fit", new_width, new_height, 16, (unsigned char *)udata, -1, -1.0);
	WriteImage(new);

	LowPass1(c,new_width,new_height,(double)100.0);

	/* Now put it back where it belongs */
	SwapDiagonal(c,new_width,new_height);

	if (! FFT2D(c,new_width,new_height,-1)) {
	   Print("Reverse FFT failed\n");
	   exit(1);
	   }

	for(y=0; y<new_height; ++y) {
	   int o = y * new_width;
	   for(x=0; x<new_width; ++x,++o) {
		int v = sqrt(c[x][y].real*c[x][y].real + c[x][y].imag*c[x][y].imag);
		if (v>65535) v=65535;
		udata[o] = (unsigned short)v;
		}
	   }

	new = CreateFitsImage("fft_rev.fit", new_width, new_height, 16, (unsigned char *)udata, -1, -1.0);
	WriteImage(new);

	return(new);
	}

/*
 * Transpose
 *			2  1
 *
 *			3  4
 *
 * into
 *
 *			4  3
 *
 *			1  2
 *
 * to put the low frequency vertex at the origin
 */

static int
SwapDiagonal(COMPLEX **c, int c_width, int c_height)
	{
	int my = c_height/2;
	int mx = c_width/2;
	int x,y;

	for(y=0; y<my; ++y)
	   for(x=0; x<c_width; ++x) {
		int x1 = (x+mx)&(c_width-1);
		int y1 = (y+my) & (c_height-1);

		double r = c[x][y].real;
		double i = c[x][y].imag;

		c[x][y].real = c[x1][y1].real;
		c[x][y].imag = c[x1][y1].imag;

		c[x1][y1].real = r;
		c[x1][y1].imag = i;
		}
	   }
	
static int
IdealHighPass(COMPLEX **c, int c_width, int c_height, double radius)
	{
	int x,y;
	int cx,cy;

	cx = c_width/2;
	cy = c_height/2;

	for(y=-radius; y<=radius; ++y)
	   for(x=-radius; x<=radius; ++x) {
		double d = sqrt(x*x + y*y);
		if (d<=radius) {
		   c[x+cx][y+cy].real *= d/radius;
		   }
		}
	}

static int
IdealLowPass(COMPLEX **c, int c_width, int c_height, double radius)
	{
	int x,y;
	int cx,cy;

	cx = c_width/2;
	cy = c_height/2;

	for(y=-cy; y<cy; ++y)
	   for(x=-cx; x<cx; ++x) {
		double d = sqrt(x*x + y*y);
		if (d>radius) {
		   c[x+cx][y+cy].real  = 0;
		   c[x+cx][y+cy].imag  = 0;
		   }
		}
	}

static int
LowPass1(COMPLEX **c, int c_width, int c_height, double radius)
	{
	int x,y;
	int cx,cy;

	cx = c_width/2;
	cy = c_height/2;

	for(y=-cy; y<cy; ++y)
	   for(x=-cx; x<cx; ++x) {
		double d = sqrt(x*x + y*y);
		if (d>radius) {
		   double p = 2 - d/radius; if(p<0) p=0;
		   c[x+cx][y+cy].real *= p;
		   c[x+cx][y+cy].imag *= p;
		   }
		}
	}

static int
LowPass2(COMPLEX **c, int c_width, int c_height, double radius)
	{
	int x,y;
	int cx,cy;

	cx = c_width/2;
	cy = c_height/2;

	for(y=-cy; y<cy; ++y)
	   for(x=-cx; x<cx; ++x) {
		double d = sqrt(x*x + y*y);
		if (d>radius) {
		   double p = 2 - d/radius; if(p<0) p=0;
		   p *= p;
		   c[x+cx][y+cy].real *= p;
		   c[x+cx][y+cy].imag *= p;
		   }
		}
	}
