#include <limits.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include "fft.h"

// Forward declarations of file-local helpers defined further down.
// NOTE: FFT()'s first argument is the flags word (FFT_FORWARD, FFT_REVERSE,
// FFT_NORMALISE, ...), despite being named "dir" in this prototype.
static int Powerof2(int v, int *m, int *twopm);
static int FFT(int dir,int m, int nn, FLOAT *r, FLOAT *i);

// Given v > 0, find the smallest power of two that is >= v.
// On success returns 1, storing m = ceil(log2(v)) in *m and 2^m in *twopm,
// so v is an exact power of two iff *twopm == v afterwards.
// Returns 0 (leaving *m and *twopm untouched) if v <= 0, or if the next
// power of two above v would not fit in an int.
static int
Powerof2(int v, int *m, int *twopm)
        {
        int nn = 1;
        int mm = 0;

        if (v <= 0) return 0;

        // INT_MAX/2 + 1 is the largest power of two an int can hold; for any
        // larger v the doubling below would overflow (undefined behaviour).
        if (v > INT_MAX / 2 + 1) return 0;

        while (nn < v) { nn <<= 1; ++mm; }

        *m = mm;
        *twopm = nn;

        return 1;
        }

/*-------------------------------------------------------------------------
   Perform a 2D FFT in place on a complex 2D array.

   c holds c->width x c->height samples stored row by row in c->real[] and
   c->imag[] (row j starts at offset j * c->width).

   flags are passed to the 1D FFT() for every row and every column:

	FFT_FORWARD
	FFT_REVERSE
	FFT_NORMALISE

   FFT_REAL (if set) only applies to the row pass; it is cleared for the
   column pass because the rows are complex once transformed. With
   FFT_NORMALISE each pass divides by its own length, giving an overall
   scale of 1/(width*height).

   Returns FALSE, leaving the data untouched, if either dimension is not a
   power of 2. Memory failures do not return: VMalloc() exits the program,
   as does FFT() for a dimension smaller than 4 or above 2^MAX_P2.

   NOTE: the column pass uses scratch buffers cached in static storage
   between calls, so this function is not reentrant.
*/

int
FFT2D(COMPLEX *c,int flags)
{
   int width = c->width;
   int height = c->height;
   int i,j,o;
   int mw,mh,twopm;
   // Scratch buffers for the column pass; kept between calls and reused
   // while the array dimensions stay the same.
   static FLOAT *dwn_r=NULL;
   static FLOAT *dwn_i=NULL;
   static int across=0,down=0;

   // Validate BOTH dimensions before modifying anything, so a bad height
   // cannot leave the array with only its rows transformed.
   if (!Powerof2(width,&mw,&twopm) || twopm != width) return(FALSE);
   if (!Powerof2(height,&mh,&twopm) || twopm != height) return(FALSE);

   // Row pass: one 1D FFT of length width per row.
   for (j=o=0;j<height;j++,o+=width) {
      FFT(flags,mw,width,c->real+o,c->imag+o);
      }

   // (Re)allocate the scratch buffers whenever EITHER dimension changes.
   // Checking only the height would overrun them when just the width grows.
   if (dwn_r==NULL || dwn_i == NULL || across != width || down != height) {
	if (dwn_r) VFree(dwn_r);
	if (dwn_i) VFree(dwn_i);

	dwn_r = VMalloc(width * height * S_FLOAT);
	dwn_i = VMalloc(width * height * S_FLOAT);

	across = width;
	down = height;
	}

   // Transpose so each original column becomes a contiguous row of the
   // scratch buffers (transpose_matrix is defined elsewhere; assumed to write
   // the transpose of the width x height source into dst). From here on the
   // data is "height" wide, so swap the two dimensions.
   transpose_matrix(c->real,dwn_r,width,height);
   transpose_matrix(c->imag,dwn_i,width,height);
   i=width; width=height; height=i;

   // Make sure the FFT_REAL flag is cleared since it's no longer right
   flags &= ~FFT_REAL;

   // Column pass: each scratch row is one original column of length mh.
   for (j=o=0;j<height;j++,o+=width) {
      FFT(flags,mh,width,dwn_r+o,dwn_i+o);
      }

   // Transpose back into the caller's array.
   transpose_matrix(dwn_r,c->real,width,height);
   transpose_matrix(dwn_i,c->imag,width,height);

   return(TRUE);
}

/*****************************************************************************************
 *
 * Bit Reverse Table
 *
 * A collection of 1D tables (one for each linear dimension we are interested in)
 * that map X[n] -> X[m], where m is the bit-reversed equivalent of n.
 *
 * One table is required for each unique dimension size. Tables are generated for
 * every size 2^0 .. 2^MAX_P2 (BRT[m] covers 2^m points). We assume the largest size
 * we will ever want is 64k (MAX_P2 <= 16), so the entries fit in 16-bit indexes.
 */

// Index type for the bit-reverse tables; 16 bits is enough for up to 64k
// points, i.e. this assumes MAX_P2 <= 16.
typedef unsigned short SHORT;

// BRT[m] points to the 2^m-entry bit-reverse table for transforms of 2^m
// points. The tables are filled by __precalculate_brt(), which sets the flag
// below once it has run.
static SHORT *BRT[MAX_P2+1];
static int __done_precalculate_brt = 0;

/*
 * Bit reverse the bottom "b" bits of an unsigned 16 bit word. Bits of n at
 * position b and above are ignored; b == 0 returns n unchanged.
 *
 *		  4  3  2  1  0       4  3  2  1  0
 *	ie b=5 : [b4,b3,b2,b1,b0] => [b0,b1,b2,b3,b4]
 */

static inline unsigned int
__bit_reverse(SHORT n,int b)
	{
	SHORT m = 0;
	SHORT top;

	if (b==0) return n;

	// Walk "top" down from the highest of the b output bits while
	// consuming n from its lowest bit upwards.
	top = 1 << (b-1);

	while(top) {
	   if (n&1) m |= top;
	   top>>=1; n>>=1;
	   }

	return m;
	}

/*
 * Build the bit-reverse tables BRT[0..MAX_P2] on first use (table m has
 * 2^m entries); later calls return immediately.
 */
static void
__precalculate_brt(void)
	{
	int m;
	int N;

	if (__done_precalculate_brt) return;

	N=1;
	for(m=0; m<=MAX_P2; ++m, N<<=1) {
	   int i;
	   // Entries are SHORTs, so size the block by the element, not a pointer.
	   SHORT *ptr = (SHORT *) VMalloc(N * sizeof *ptr);

	   for(i=0; i<N; ++i) ptr[i] = __bit_reverse(i,m);
	   BRT[m] = ptr;
	   }

	__done_precalculate_brt=1;
	}
	   
// This is the external interface for other routines that
// may want to access the bit-reverse logic.
//
// Permutes data[0 .. 2^m - 1] in place: data[i] and data[rev(i)] are
// exchanged, where rev() reverses the low m bits of the index.
// m must be in 0..MAX_P2 (the range covered by the BRT[] tables); any other
// value is reported on stderr and the program exits with status 1.

void
do_bit_reverse(FLOAT *data, int m)
	{
	int i;
	int n;
	FLOAT tmp;
	SHORT *brt;

	if (m < 0 || m > MAX_P2) {
	   fprintf(stderr,"do_bit_reverse: %d: power of 2 out of range\n",m);
	   exit(1);
	   }

	if (! __done_precalculate_brt) __precalculate_brt();
	brt = BRT[m];

	// n is the loop bound and must stay separate from the swap partner j.
	n = 1<<m;

	// The last index (all ones) is its own reverse, so stop at n-1.
	// Swapping only when i < j exchanges each pair exactly once.
	for(i=0; i<n-1; ++i) {
	   int j = brt[i];
	   if (i < j) { tmp=data[i]; data[i] = data[j]; data[j] = tmp; }
	   }
	}
	
	
/****************************************************************************************/


// Twiddle-factor tables filled once by __precalculate_fft_twiddle().
// U1_* holds the real part and U2_* the imaginary part of each factor.
// The factors for successive FFT stages (l = 2 .. MAX_P2-1) are stored back
// to back; FFT() steps through them stage by stage.
// NOTE(review): only 2^MAX_P2 - 4 entries per table are ever written, so the
// 2<<MAX_P2 size is roughly twice what is needed.
static FLOAT __attribute__ ((aligned (A_FLOAT))) U1_F[2<<MAX_P2];	// (u1,u2) forward (dir=1)
static FLOAT __attribute__ ((aligned (A_FLOAT))) U2_F[2<<MAX_P2];

static FLOAT __attribute__ ((aligned (A_FLOAT))) U1_R[2<<MAX_P2];	// (u1,u2) reverse (dir=-1)
static FLOAT __attribute__ ((aligned (A_FLOAT))) U2_R[2<<MAX_P2];
// Set to 1 once the tables above have been filled.
static int __done_precalculate_fft_twiddle=0;

static void
__precalculate_fft_twiddle()
	{
	long double c1,c2,z;
        int j,l,n,nn,l1,l2,l3;

	// only need to do this once
	if (__done_precalculate_fft_twiddle) return;

	// Do the forward direction
        c1 = -1.0;
        c2 = 0.0;
	l2 = 1;

        for(l=n=0; l<MAX_P2; ++l) {
           long double u1 = 1.0;
           long double u2 = 0.0;

           l1=l2;
	   l2<<=1;
	   l3=l1>>1;
	   nn=n;

	   if (l>=2) {
              for(j=0; j<l3; ++j) {
		U1_F[n]=u1; U2_F[n++]=u2;
                z = u1 * c1 - u2 * c2;
                u2 = u1 * c2 + u2 * c1;
                u1 = z;
                }

	      // The second half is a simple transpose of the first half
              for(j=0; j<l3; ++j) { U1_F[n]=U2_F[nn]; U2_F[n++]=-U1_F[nn++]; }
	      }

           c2 = -sqrt((1.0 - c1) / 2.0);
           c1 = sqrt((1.0 + c1) / 2.0);
           }

	// Do the reverse direction
        c1 = -1.0;
        c2 = 0.0;
	l2 = 1;

        for(l=n=0; l<MAX_P2; ++l) {
           long double u1 = 1.0;
           long double u2 = 0.0;

           l1=l2;
	   l2<<=1;
	   l3=l1>>1;
	   nn=n;

	   if (l>=2) {
              for(j=0; j<l3; ++j) {
		U1_R[n]=u1; U2_R[n++]=u2;
                z = u1 * c1 - u2 * c2;
                u2 = u1 * c2 + u2 * c1;
                u1 = z;
                }

	      // The second half is a simple transpose of the first half
              for(j=0; j<l3; ++j) { U1_R[n]=-U2_R[nn]; U2_R[n++]=U1_R[nn++]; }
	      }

           c2 = sqrt((1.0 - c1) / 2.0);
           c1 = sqrt((1.0 + c1) / 2.0);
	   }

	__done_precalculate_fft_twiddle=1;
	}

/*-------------------------------------------------------------------------
   This computes an in-place complex-to-complex FFT of nn = 2^m points.
   real[] and imag[] are the real and imaginary arrays.
   flags are:
	FFT_FORWARD:   gives forward transform
	FFT_REVERSE:   gives reverse transform
	FFT_NORMALISE: divide every output by nn after transforming
	               (applied in either direction)
	FFT_REAL:      input is purely real. imag[] must be all zero on entry:
	               it is not bit-reversed, and the first two passes never
	               read it.
	FFT_NO_BRT:    skip the bit-reversal pass (the caller presumably
	               supplies data already in bit-reversed order)

     Formula: forward
                  N-1
                  ---
              1   \          - j k 2 pi n / N
      X(n) = ---   >   x(k) e                    = forward transform
              N   /                                n=0..N-1
                  ---
                  k=0

      Formula: reverse
                  N-1
                  ---
                  \          j k 2 pi n / N
      X(n) =       >   x(k) e                    = reverse transform
                  /                                n=0..N-1
                  ---
                  k=0

   The 1/N factor is only applied when FFT_NORMALISE is set.

   Assumptions:

	- The smallest size handled is 4 elements, due to the optimisation that
	  rolls the first 2 passes together; the largest is 2^MAX_P2. Sizes
	  outside that range, or a missing direction flag, print an error to
	  stderr and exit the program with status 1.
	- nn must equal 2^m (not checked).
	- real[] and imag[] are accessed with vector (VFLOAT) operations,
	  presumably requiring the alignment VMalloc() provides.

   Returns TRUE.
*/

static int
FFT(int flags,int m,int nn,FLOAT *real, FLOAT *imag)
{
   long i,l,l1,l2;
   int dir;
   SHORT *brt;
   FLOAT *R,*I,*RI,*RJ,*II,*IJ;
   FLOAT *U1,*U2;

   // Invalid arguments are fatal; exit with a failure status so the
   // error is visible to whoever ran the program.
   if (m>MAX_P2) {
	fprintf(stderr,"FFT: %d: Maximum power of 2 exceeded\n",m);
	exit(1);
	}
   if (m<2) {
	fprintf(stderr,"FFT: Minimum allowed size is 4 elements\n");
	exit(1);
	}

   // Check if we need to init our tables for twiddle values and bit reversal
   if (! __done_precalculate_fft_twiddle) __precalculate_fft_twiddle();
   if (! __done_precalculate_brt) __precalculate_brt();

   if (flags&FFT_FORWARD) {
	// Forward direction
	U1 = U1_F;
	U2 = U2_F;
	dir=1;
	}
   else if (flags&FFT_REVERSE) {
	// Reverse direction
	U1 = U1_R;
	U2 = U2_R;
	dir=-1;
	}
   else {
	fprintf(stderr,"FFT: Must have one of FFT_FORWARD or FFT_REVERSE\n");
	exit(1);
	}

   /************************************************************************
    * Do the bit reversal via lookup table BRT[]
    *
    * reverse the real[] and imag[] arrays separately to improve
    * cache locality. The last index (all ones) maps to itself, so both
    * loops stop at nn-1; swapping only when i < rev exchanges each pair once.
    */

   if (! (flags & FFT_NO_BRT)) {
      // bit-reverse real[] array
      R = real;
      brt = BRT[m];
      for(i=0; i<nn-1; ++i) {
	FLOAT tmp;
	SHORT rev = *(brt++);

	if (i < rev) { tmp=*R; *R = real[rev]; real[rev] = tmp; }
	++R;
	}

      // bit-reverse imag[] array (not needed for real input: it is all zero)
      if (! (flags&FFT_REAL)) {
         I = imag;
         brt = BRT[m];
         for(i=0; i<nn-1; ++i) {
	   FLOAT tmp;
	   SHORT rev = *(brt++);

	   if (i < rev) { tmp=*I; *I = imag[rev]; imag[rev] = tmp; }
	   ++I;
	   }
         }
      }

   /************************************************************************
    * Merge pass 1 and pass 2: radix-2 butterflies over each group of 4
    * consecutive points in a single sweep. The quarter-turn twiddle of
    * pass 2 (-i forward, +i reverse) is applied by swapping/negating the
    * real and imaginary differences, selected by dir.
    */

   RI = real; RJ = RI + 2;
   II = imag; IJ = II + 2;

   if (flags & FFT_REAL) {
      // Real input: only real parts are combined. imag[i] and imag[i+2]
      // are left as they are (zero by contract).
      for(i=0; i<nn; i+=4) {
         register FLOAT e1 = *RI + RI[1]; // (r0+r1)
         register FLOAT e3 = *RJ + RJ[1]; // (r2+r3)

         // left
         *(RI++) = e1 + e3;
         *(RJ++) = e1 - e3;

         e1 -= 2 * *RI;  // (r0-r1)
         e3 -= 2 * *RJ;  // (r2-r3)

         // right
         e3 *= dir;

         *RI = e1;
         *RJ = e1;
         *(++II) = -e3;
         *(++IJ) = e3;

         RI += 3; II += 3;
         RJ += 3; IJ += 3;
         }
      }
   else {
      for(i=0; i<nn; i+=4) {
	FLOAT tr2 = *RJ;		 // r2
	FLOAT ti2 = *IJ;		 // i2
	register FLOAT e1 = *RI + RI[1]; // (r0+r1)
	register FLOAT e3 = tr2 + RJ[1]; // (r2+r3)
	register FLOAT e2 = *II + II[1]; // (i0+i1)
	register FLOAT e4 = ti2 + IJ[1]; // (i2+i3)

	// left
	*RI = e1 + e3;
	*RJ = e1 - e3;
	*II = e2 + e4;
	*IJ = e2 - e4;

	RI++; II++;
	RJ++; IJ++;

	e1 -= 2 * *RI;	// (r0-r1)
	e2 -= 2 * *II;	// (i0-i1)

	if (dir==1) {
	   e3 -= 2 * *RJ;	// (r2-r3)
	   e4 -= 2 * *IJ;	// (i2-i3)
	   }
	else {
	   e3 -= 2 * tr2;	// (r3-r2)
	   e4 -= 2 * ti2;	// (i3-i2)
	   }

	*RI = e1 + e4;
	*RJ = e1 - e4;
	*II = e2 - e3;
	*IJ = e2 + e3;

	RI += 3; II += 3;
	RJ += 3; IJ += 3;
	}
      }

   /* Remaining passes l = 2 .. m-1: butterflies of span l1 = 2^l, using this
    * stage's slice of the twiddle tables, four points per vector op. */
   l2 = 4;

   for (l=2;l<m;l++) {
      register FLOAT *u1,*u2;
      f4vector t1,t2;

      RI = real; II = imag;
      RJ = real + l2; IJ = imag + l2;

      l1 = l2;
      l2 <<= 1;

      i=0; while(i<nn) {
	 u1 = U1;
	 u2 = U2;
	 while(l1) {
	    // t = u * x[J];  x[J] = x[I] - t;  x[I] += t   (complex)
	    t1.v = VFLOAT(u1) * VFLOAT(RJ) - VFLOAT(u2) * VFLOAT(IJ);
            t2.v = VFLOAT(u1) * VFLOAT(IJ) + VFLOAT(u2) * VFLOAT(RJ);
            VFLOAT(RJ) = VFLOAT(RI) - t1.v;
            VFLOAT(IJ) = VFLOAT(II) - t2.v;
            VFLOAT(RI) += t1.v;
            VFLOAT(II) += t2.v;

	    u1 += 4; u2 += 4;
	    RI += 4; RJ += 4;
	    II += 4; IJ += 4;
	    l1 -= 4;
            }

	// Skip over the second half of this block (already written via RJ/IJ).
	i += l2;
	l1 = (l2>>1);
	RI += l1; RJ += l1; II += l1; IJ += l1;
	}

      // Advance to the next stage's twiddle factors.
      U1 += l1;
      U2 += l1;
      }

   /* Normalisation (either direction): divide every output by nn */
   if (flags & FFT_NORMALISE) {
      f4vector nv;

      nv.f[0]=nn; nv.f[1]=nn;
      nv.f[2]=nn; nv.f[3]=nn;

      nn>>=2;
      while(nn-- > 0) {
	VFLOAT(real) /= nv.v;
	VFLOAT(imag) /= nv.v;
	real+=4;
	imag+=4;
	}
   }

   return(TRUE);
}

/* Force data to be 16-byte aligned in case we want to use sse vectors */

COMPLEX *
create_complex(int w, int h)
        {
        COMPLEX *c = (COMPLEX *) VMalloc(sizeof(COMPLEX));

        c->width = w;
        c->height = h;

        c->real = (FLOAT *) VMalloc(S_FLOAT * w * h);
        c->imag = (FLOAT *) VMalloc(S_FLOAT * w * h);

        return c;
        }

/*
 * Release a COMPLEX created by create_complex(): both sample planes and the
 * header itself. Like free(), a NULL argument is accepted and ignored.
 */
void
destroy_complex(COMPLEX *c)
        {
        if (c == NULL) return;

        VFree(c->real);
        VFree(c->imag);
        VFree(c);
        }

/*
 * Allocate a block of memory with malloc() and align it to the natural size of our vector
 * type (f4vector).
 *
 * We do this by allocating sizeof(f4vector) extra bytes and then moving the start pointer
 * forward to the next alignment boundary. The byte at offset -1 from the returned pointer
 * stores how far we moved (always 1..sizeof(f4vector)), so the pointer originally returned
 * by malloc() can be recovered before calling free(). Because that offset is kept in an
 * unsigned char, sizeof(f4vector) must not exceed 255.
 *
 * Note that if malloc() already returned an aligned pointer we waste the most memory,
 * moving the pointer forward by a full sizeof(f4vector) so there is room for that byte.
 *
 * We also allocate an extra 16 bytes over and above the sizeof(f4vector) needed for
 * alignment, in case the caller uses 16-byte vector operations on non-vector data: this
 * leaves room at the end for the last vector op.
 */

// Alignment (in bytes) of every VMalloc() block: the size of the vector type.
static int __v_align = sizeof(f4vector);
static int __done_align_check = 0;

// One-time sanity check that __v_align is usable by VMalloc(). It must be a
// power of two for the alignment mask arithmetic, and at most 255 so the
// alignment offset fits in the single unsigned char stored before each block.
// Prints an error to stderr and exits the program otherwise.
static inline void
__do_align_check(void)
	{
	if (! __done_align_check) {
	   // v & (v-1) clears the lowest set bit, so it is zero only for powers
	   // of two. Unlike shifting a probe bit up to v, this cannot overflow.
	   if (__v_align <= 0 || (__v_align & (__v_align - 1)) != 0) {
	      fprintf(stderr,"VMalloc: Vector size %d (sizeof(f4vector)) non power of two\n",__v_align);
	      exit(1);
	      }
	   if (__v_align > 255) {
	      fprintf(stderr,"VMalloc: Vector size %d (sizeof(f4vector)) too large for offset byte\n",__v_align);
	      exit(1);
	      }
	   __done_align_check = 1;
	   }
	}

/*
 * Allocate n bytes aligned to __v_align (sizeof(f4vector)), with at least
 * 16 bytes of slack after the last requested byte. Release with VFree().
 * Never returns NULL: a negative size or an allocation failure prints an
 * error to stderr and exits the program with status 1.
 */
void *
VMalloc(int n)
	{
	unsigned char *buf;
	unsigned char *ptr;
	size_t offset;

	__do_align_check();

	if (n < 0) {
	   fprintf(stderr,"VMalloc: invalid request for %d bytes\n",n);
	   exit(1);
	   }

	// Do the size arithmetic in size_t; n + __v_align + 16 can overflow an int.
	buf = (unsigned char *) malloc((size_t)n + (size_t)__v_align + 16);
	if (buf == NULL) {
	   fprintf(stderr,"VMalloc: out of memory requesting %d bytes\n",n);
	   exit(1);
	   }

	// Distance to the next alignment boundary strictly after buf, in
	// 1..__v_align. uintptr_t keeps all address bits; the old cast to
	// unsigned int truncated pointers on 64-bit targets.
	offset = (size_t)__v_align
	       - (size_t)((uintptr_t)buf & (uintptr_t)(__v_align - 1));
	ptr = buf + offset;

	// Store the alignment offset in the preceding byte for VFree().
	*(ptr-1) = (unsigned char) offset;

	return ptr;
	}

/*
 * Free memory allocated with VMalloc(). Like free(), a NULL pointer is
 * accepted and ignored.
 *
 * The unsigned byte at location *(ptr-1) gives the offset back to the address
 * provided by malloc.
 */

void
VFree(void *buf)
	{
	unsigned char *ptr = (unsigned char *)buf;

	if (ptr == NULL) return;

	// Step back to the pointer malloc() originally returned.
	ptr -= ptr[-1];

	free(ptr);
	}

// Zero dst one f4vector (V_ELEMENTS FLOATs) at a time.
// Only count / V_ELEMENTS whole vectors are cleared; a trailing partial
// vector (count % V_ELEMENTS elements) is left untouched.
inline void
_vector_zero_F(FLOAT *dst, int count)
        {
        f4vector zero;
        int nvec;

        LOADVC(zero,0);

        for (nvec = count / V_ELEMENTS; nvec != 0; --nvec) {
           VFLOAT(dst) = zero.v;
           dst += V_ELEMENTS;
           }
        }
// Copy count FLOATs from src to dst one f4vector (V_ELEMENTS FLOATs) at a time.
// If count is not a multiple of V_ELEMENTS, the final vector op still moves a
// whole vector, touching up to V_ELEMENTS-1 FLOATs past count in both buffers.
// This assumes the buffers come from VMalloc(), which pads each block for
// exactly this case. A count <= 0 copies nothing.
inline void
_vector_copy_F2F(FLOAT *src, FLOAT *dst, int count)
        {
        // "> 0" rather than "!= 0": when count is not a multiple of
        // V_ELEMENTS the counter steps past zero and must still stop.
        while(count > 0) {
           VFLOAT(dst) = VFLOAT(src);
           count -= V_ELEMENTS;
           src += V_ELEMENTS;
           dst += V_ELEMENTS;
           }
        }

// Scale data by f in place, one f4vector (V_ELEMENTS FLOATs) at a time.
// Only count / V_ELEMENTS whole vectors are scaled; a trailing partial
// vector (count % V_ELEMENTS elements) is left untouched.
inline void
_vector_multiply_Ff(FLOAT *data, float f, int count)
	{
        f4vector scale;
        int nvec;

        LOADVC(scale,f);

        for (nvec = count / V_ELEMENTS; nvec != 0; --nvec) {
           VFLOAT(data) *= scale.v;
           data += V_ELEMENTS;
           }
        }
