#include "ninox.h"

// Forward declarations of the file-local helpers defined further down.

// Filter application, forward FFT of an image, and conversion of a complex array back to pixels
static int apply_filter(COMPLEX *c, FLOAT *f);
static COMPLEX * do_fft(struct Image *img);
static int image_from_complex(COMPLEX *c, struct Image *img);

// Complex arithmetic on separate real/imaginary values, and the correlation built on top of it
static inline void Complex_Conjugate(FLOAT *rx, FLOAT *ix, FLOAT ry, FLOAT iy);
static inline void Complex_Multiply(FLOAT *rz, FLOAT *iz, FLOAT rx, FLOAT ix, FLOAT ry, FLOAT iy);
static inline FLOAT Complex_Length(FLOAT r, FLOAT i);
static COMPLEX * correlate(COMPLEX *f1, COMPLEX *f2);

// Image padding and peak location used by the registration code
static struct Image * Clone_PadImage(struct Image *img, int new_width, int new_height);
static int find_offset(COMPLEX *c, int region, int *dx, int *dy);

// Filter construction and visualisation helpers
// (apply_filter is declared a second time here; the repeat is redundant but harmless)
static FLOAT * make_LowPass(int c_width, int c_height, double r1, double r2, double p);
static int apply_filter(COMPLEX *, FLOAT *);
static int create_magnitude_map(COMPLEX *c,unsigned short *buf);
static int create_phase_map(COMPLEX *c, unsigned short *buf);

// Quadrant swapping, for a complex array and for a plain FLOAT array
static void SwapDiagonal(COMPLEX *c);
static FLOAT * SwapDiagonal_R(FLOAT *, int, int);

// Multiply the real and imaginary planes of c, element by element, by the
// filter mask f. The data is walked in groups of V_ELEMENTS values; any
// remainder of (width*height) modulo V_ELEMENTS is left untouched.
// NOTE(review): VFLOAT presumably views V_ELEMENTS consecutive FLOATs as one
// vector value - confirm against its definition.
// Always returns 1.
static int
apply_filter(COMPLEX *c, FLOAT *f)
	{
	FLOAT *re = c->real;
	FLOAT *im = c->imag;
	int groups = (c->width * c->height) / V_ELEMENTS;
	int g;

	for(g=0; g<groups; ++g) {
	   VFLOAT(re) *= VFLOAT(f);
	   VFLOAT(im) *= VFLOAT(f);

	   re += V_ELEMENTS;
	   im += V_ELEMENTS;
	   f += V_ELEMENTS;
	   }

	return 1;
	}

// Build a complex array from an 8 or 16 bit image (pixel values go into the
// real plane, the imaginary plane is zeroed) and run the forward 2D FFT on it.
//
// Returns the newly created COMPLEX on success; the caller owns it and is
// expected to release it with destroy_complex(). Returns NULL if the depth is
// unsupported, the array cannot be created, or FFT2D reports failure.
static COMPLEX *
do_fft(struct Image *img)
	{
	COMPLEX *c;
	int o;
	int npix = img->width * img->height;
	FLOAT *r;

	if (img->depth != 8 && img->depth != 16) {
	   Print("do_fft: FFT not supported on images other than 8 or 16 bpp monochrome\n");
	   return(NULL);
	   }

	c = create_complex(img->width,img->height);
	if (c == NULL) {
	   Print("do_fft: could not create complex array\n");
	   return(NULL);
	   }

	r = c->real;

	if (img->depth == 8) {
	   unsigned char *data = img->data;
	   for(o=0; o<npix; ++o) r[o] = data[o];
	   }
	else {
	   unsigned short *udata = (unsigned short *)img->data;
	   for(o=0; o<npix; ++o) r[o] = udata[o];
	   }

	_vector_zero_F(c->imag,npix);

	// Perform the FFT
	if (!  FFT2D(c,FFT_FORWARD|FFT_REAL|FFT_NO_BRT)) {
	   Print("FFT2D failed!\n");
	   // do not leak the array we just created
	   destroy_complex(c);
	   return(NULL);
	   }

	return(c);
	}

// Convert the real plane of a complex array back into image pixels.
// (No transform is performed here; the caller runs the reverse FFT first.)
//
// The image may be smaller than the complex grid, in which case the pixels
// are taken from the centred sub-region of the grid. Values are scaled so that
// the largest value in that region maps to 255 (8 bit) or 32767 (16 bit), and
// negative values become 0.
// NOTE(review): the 16 bit path scales to 32767 but clamps at 65535 - kept as
// found, confirm whether the half-range scaling is deliberate.
//
// Returns 1 on success, 0 if the image is larger than the complex grid.
static int
image_from_complex(COMPLEX *c, struct Image *img)
	{
	int c_width = c->width;
	int c_height = c->height;
	int i_width = img->width;
	int i_height = img->height;
	int i_npix = i_width * i_height;
	unsigned short *udata;
	int x,y,o,x_offset,y_offset;
	FLOAT *off,max;

	// A larger image would give negative offsets and read outside the grid
	if (i_width > c_width || i_height > c_height) {
	   Print("image_from_complex: image %dx%d larger than complex grid %dx%d\n",
		i_width,i_height,c_width,c_height);
	   return(0);
	   }

	// Maybe storing complex grid into non-power-of-two image grid
	// which means we fetch data only from a subregion in the complex grid
	x_offset = (c_width - i_width) / 2;
	y_offset = (c_height - i_height) / 2;

	// prescan region to find maximum value for scaling
	max = 0;
	for(y=0; y<i_height; ++y) {
	   off = c->real + (y+y_offset) * c_width + x_offset;
	   for(x=0; x<i_width; ++x,++off)
	      if (*off>max) max=*off;
	   }

	// Nothing positive in the region: every pixel clamps to 0. Handle this
	// here so that we never divide by zero below.
	if (! (max > 0)) {
	   if (img->depth == 8) {
	      for(o=0; o<i_npix; ++o) img->data[o] = 0;
	      }
	   else if (img->depth == 16) {
	      udata = (unsigned short *) img->data;
	      for(o=0; o<i_npix; ++o) udata[o] = 0;
	      }
	   return(1);
	   }

	if (img->depth == 8) {
	   for(y=o=0; y<i_height; ++y) {
		off = c->real + (y+y_offset) * c_width + x_offset;
		for(x=0; x<i_width; ++x,++o,++off) {
		   // clamp before converting so the conversion is always in range
		   double v = *off * 255.0/max;
		   if (v<0) v=0;
		   if (v>255) v=255;
		   img->data[o] = (int)v;
		   }
		}
	   }
	else if (img->depth == 16) {
	   udata = (unsigned short *) img->data;
	   for(y=o=0; y<i_height; ++y) {
		off = c->real + (y+y_offset) * c_width + x_offset;
		for(x=0; x<i_width; ++x,++o,++off) {
		   double v = *off * 32767.0/max;
		   if (v<0) v=0;
		   if (v>65535) v=65535;
		   udata[o] = (int)v;
		   }
		}
	   }

	return(1);
	}

// Complex conjugate: (*rx + i * *ix) = conj(ry + i * iy).
// The real part is copied, the imaginary part changes sign.
static inline void
Complex_Conjugate(FLOAT *rx, FLOAT *ix, FLOAT ry, FLOAT iy)
	{
	*rx = ry;
	*ix = -iy;
	}

// Complex product: (*rz + i * *iz) = (rx + i*ix) * (ry + i*iy)
static inline void
Complex_Multiply(FLOAT *rz, FLOAT *iz, FLOAT rx, FLOAT ix, FLOAT ry, FLOAT iy)
	{
	FLOAT prod_re = rx * ry - ix * iy;
	FLOAT prod_im = rx * iy + ry * ix;

	*rz = prod_re;
	*iz = prod_im;
	}

// Magnitude of the complex number (r + i*i): sqrt(r^2 + i^2)
static inline FLOAT
Complex_Length(FLOAT r, FLOAT i)
	{
	return sqrt(r*r + i*i);
	}

// For every element compute  f1 * conj(f2) / (|f1| * |f2|)  and store it in a
// newly created complex array with the dimensions of f1. A zero denominator
// is replaced by 0.000001. f2 is indexed over the same range as f1.
// Returns the new array; the caller releases it.
static COMPLEX *
correlate(COMPLEX *f1, COMPLEX *f2)
	{
	int cols = f1->width;
	int rows = f1->height;
	int count = cols * rows;
	COMPLEX *out = create_complex(cols,rows);
	int k;

	for(k=0; k<count; ++k) {
	   FLOAT a_re = f1->real[k];
	   FLOAT a_im = f1->imag[k];
	   FLOAT b_re = f2->real[k];
	   FLOAT b_im = f2->imag[k];
	   FLOAT bc_re,bc_im;		// conj(b)
	   FLOAT p_re,p_im;		// a * conj(b)
	   FLOAT norm;

	   norm = Complex_Length(a_re,a_im) * Complex_Length(b_re,b_im);
	   if (norm == 0) norm = 0.000001;

	   Complex_Conjugate(&bc_re,&bc_im,b_re,b_im);
	   Complex_Multiply(&p_re,&p_im,a_re,a_im,bc_re,bc_im);

	   out->real[k] = p_re / norm;
	   out->imag[k] = p_im / norm;
	   }

	return out;
	}

// Clone img and, if the requested size differs from the image size, replace
// the clone's pixel buffer with a zero-filled new_width x new_height buffer
// holding the original pixels centred inside it.
//
// Only 8 and 16 bit images have their pixels copied; for any other depth the
// enlarged buffer is left zeroed. The requested size is expected to be at
// least as large as the image in both directions.
//
// Returns the clone (caller releases it), or NULL if the clone failed.
static struct Image *
Clone_PadImage(struct Image *img, int new_width, int new_height)
	{
	struct Image *Img = CloneImage(img);
	unsigned char *src,*dst;
	unsigned short *usrc,*udst;
	int padx = (new_width - img->width)/2;
	int pady = (new_height - img->height)/2;
	int bpp = img->depth / 8;
	int x,y,o,O;

	if (Img == NULL) return NULL;

	if (img->width != new_width || img->height != new_height) {
	   free(Img->data);
	   Img->data = ZeroMalloc(new_width * new_height * bpp);
	   Img->width = new_width;
	   Img->height = new_height;

	   switch(img->depth) {
	      case 8:
		src = img->data;
		dst = Img->data;
		for(y=O=0; y<img->height; ++y) {
		   // destination rows are new_width pixels long
		   o = (y+pady) * new_width + padx;
		   for(x=0; x<img->width; ++x,++o,++O) dst[o] = src[O];
		   }
		break;
	      case 16:
		usrc = (unsigned short *)img->data;
		udst = (unsigned short *)Img->data;
		for(y=O=0; y<img->height; ++y) {
		   // destination rows are new_width pixels long
		   o = (y+pady) * new_width + padx;
		   for(x=0; x<img->width; ++x,++o,++O) udst[o] = usrc[O];
		   }
		break;
	      }
	   }

	return Img;
	}

// Low-pass filter the cutout region (img->cutout.x1..x2, y1..y2) of an 8 or
// 16 bit image in place.
//
// The cutout is copied into the centre of a zeroed, square, power-of-two
// complex array, transformed with FFT2D, multiplied by the mask produced by
// make_LowPass(lpf_r1,lpf_r2,lpf_p), transformed back, rescaled so that the
// largest output value equals the largest input value, clamped to
// [0, largest input value] and written back over the cutout.
//
// The complex work array and the filter mask are kept in static storage
// between calls; both are rebuilt when the required size changes (the mask
// also when any of the three filter parameters changes).
//
// Returns 1. Calls exit(1) on an unsupported depth or an FFT failure.
int
fft_filter(struct Image *img, double lpf_r1, double lpf_r2, double lpf_p)
	{
	int w,h,x,y,o,padx,pady;
	unsigned char *src,*dst;
	unsigned short *usrc,*udst;
	int width,height;
	int c_npixels;
	int x1,y1,x2,y2;
	int maxv;
	FLOAT max;
	static COMPLEX *fft = NULL;
	static int lpf_w=-1,lpf_h=-1;
	static FLOAT *filter=NULL;
	static double lpf_r1_s;
	static double lpf_r2_s;
	static double lpf_p_s;

	x1 = img->cutout.x1; x2 = img->cutout.x2;
	y1 = img->cutout.y1; y2 = img->cutout.y2;
	width = (x2-x1+1);
	height = (y2-y1+1);

	// Find (w,h) as the nearest larger power of 2
	w=1; while(w < width) w<<=1;
	h=1; while(h < height) h<<=1;

	// Make the array square
	while(w<h) w<<=1;
	while(h<w) h<<=1;

	c_npixels = w * h;

	// padding needed
	padx = (w-width)/2;
	pady = (h-height)/2;

	// ***************************************************
	// * create our 2D complex array from the image data
	// ***************************************************

	// The cached array is only usable if it has the size needed now;
	// otherwise the zeroing below would run past the end of it.
	if (fft != NULL && (fft->width != w || fft->height != h)) {
	   destroy_complex(fft);
	   fft = NULL;
	   }

	if (fft==NULL) {
	   fft = create_complex(w,h);
	   }

	// zero both real and imag elements
	_vector_zero_F(fft->real,c_npixels);
	_vector_zero_F(fft->imag,c_npixels);

	maxv=0;

	// Copy data into complex array (real part), find maxv
	switch(img->depth) {
	   case 8:
		src = img->data;
		for(y=y1; y<=y2; ++y) {
		   FLOAT *off = fft->real + (pady+y-y1) * w + padx;
		   o = y*img->width;
		   for(x=x1; x<=x2; ++x,++off) {
		      *off = (FLOAT)src[o+x];
		      if (src[o+x]>maxv) maxv=src[o+x];
		      }
		   }
		break;
	   case 16:
		usrc = (unsigned short *)img->data;
		for(y=y1; y<=y2; ++y) {
		   FLOAT *off = fft->real + (pady+y-y1) * w + padx;
		   o = y*img->width;
		   for(x=x1; x<=x2; ++x,++off) {
		      *off = (FLOAT)usrc[o+x];
		      if (usrc[o+x]>maxv) maxv=usrc[o+x];
		      }
		   }
		break;
	   default:
		Print("fft_filter: depth '%d' not supported, must be 8 or 16\n",img->depth);
		exit(1);
		break;
	   }

	// ***************************************************
	// * create the lowpass filter
	// ***************************************************

	if (filter==NULL)
	   filter = make_LowPass(w,h,lpf_r1,lpf_r2,lpf_p);
	else if (lpf_r1_s != lpf_r1 || lpf_r2_s != lpf_r2 || lpf_p_s != lpf_p ||
		lpf_w != w || lpf_h != h) {
	   // parameters or size changed since the mask was built
	   VFree(filter);
	   filter = make_LowPass(w,h,lpf_r1,lpf_r2,lpf_p);
	   }
	lpf_r1_s = lpf_r1;
	lpf_r2_s = lpf_r2;
	lpf_p_s = lpf_p;
	lpf_w = w;
	lpf_h = h;

	Print("lowpass filter radius %lf/%lf, power %lf\n",lpf_r1,lpf_r2,lpf_p);

	// **********************************************************
	// * do the forward fft, apply the filter, do the reverse fft
	// **********************************************************

	printf("fft %dx%d\n",fft->width,fft->height); fflush(stdout);

	if (!  FFT2D(fft,FFT_FORWARD|FFT_REAL)) {
	   printf("fft_filter: Forward FFT failed\n");
	   exit(1);
	   }

	apply_filter(fft,filter);

	if (! FFT2D(fft,FFT_REVERSE)) {
	   Print("fft_filter: Reverse FFT failed\n");
	   exit(1);
	   }

	// ******************************************************
	// * Regenerate the output image, scaled from the data
	// ******************************************************

	for(y=y1,max=0; y<=y2; ++y) {
	   FLOAT *off = fft->real + (pady+y-y1) * w + padx;
	   for(x=x1; x<=x2; ++x,++off) if (*off > max) max = *off;
	   }

	// Scale the real component. When nothing in the region is positive
	// every pixel clamps to 0 below anyway, so skip the scaling rather
	// than divide by zero.
	if (max > 0)
	   _vector_multiply_Ff(fft->real,(FLOAT)maxv/max,c_npixels);

	// put the data back into the image
	switch(img->depth) {
	   case 8:
		dst = img->data;
		for(y=y1; y<=y2; ++y) {
		   FLOAT *off = fft->real + (pady+y-y1) * w + padx;
		   o = y * img->width;
		   for(x=x1; x<=x2; ++x,++off) {
		      FLOAT v = *off;
		      if (v<0) v=0;
		      if (v>maxv) v=maxv;
		      dst[o+x] = v;
		      }
		   }
		break;
	   case 16:
		udst = (unsigned short *)img->data;
		for(y=y1; y<=y2; ++y) {
		   FLOAT *off = fft->real + (pady+y-y1) * w + padx;
		   o = y * img->width;
		   for(x=x1; x<=x2; ++x,++off) {
		      FLOAT v = *off;
		      if (v<0) v=0;
		      if (v>maxv) v=maxv;
		      udst[o+x] = v;
		      }
		   }
		break;
	   }

	// fft and filter are deliberately kept for the next call
	return 1;
	}

// Front end to fft_register() for images of arbitrary size.
//
// Both images must have the same dimensions. If those dimensions are not
// already a square power of two, padded clones (see Clone_PadImage) are
// registered instead of the originals and released afterwards.
//
// Returns 0 on a dimension mismatch, otherwise whatever fft_register returns.
int
fft_register_pad(struct Image *ref, struct Image *img, double lpf_r1, double lpf_r2, double lpf_p, int region, int *dx, int *dy)
	{
	struct Image *work_ref = ref;
	struct Image *work_img = img;
	int side_w = 1;
	int side_h = 1;
	int result;

	if (ref->width != img->width || ref->height != img->height) {
	   printf("fft_register: Images must be the same dimensions\n");
	   return 0;
	   }

	// smallest powers of two that cover the image
	while(side_w < ref->width) side_w <<= 1;
	while(side_h < ref->height) side_h <<= 1;

	// Make the array square: both sides take the larger value
	if (side_w < side_h) side_w = side_h;
	else side_h = side_w;

	// image dimensions need padding?
	if (ref->width != side_w || ref->height != side_h) {
	   work_ref = Clone_PadImage(ref,side_w,side_h);
	   work_img = Clone_PadImage(img,side_w,side_h);
	   }

	result = fft_register(work_ref,work_img,lpf_r1,lpf_r2,lpf_p,region,dx,dy);

	// only release what was created here
	if (work_ref != ref) DestroyImage(work_ref);
	if (work_img != img) DestroyImage(work_img);

	return result;
	}


// Estimate the offset between two images of identical size and depth.
//
// Both images are transformed with do_fft(), optionally low-pass filtered
// (when lpf_r1 is non-zero), combined with correlate(), transformed back, and
// the peak near the origin is located by find_offset(), which stores the
// result in *dx / *dy. Width and height must each be a power of two.
//
// As a side effect two diagnostic images are written on every successful
// call: "fft_delta.fit" (the correlation surface) and "fft_ref_lp.fit" (the
// reference image after filtering).
//
// Returns 1 on success, 0 on any failure (*dx and *dy are then 0).
int
fft_register(struct Image *ref, struct Image *img, double lpf_r1, double lpf_r2, double lpf_p, int region, int *dx, int *dy)
	{
	int width = ref->width;
	int height = ref->height;
	struct Image *new;
	COMPLEX *f_img = NULL;
	COMPLEX *f_ref = NULL;
	COMPLEX *tmp = NULL;
	int i;
	int rval = 0;
	static FLOAT *lp_filter = NULL;
	static int lp_w = -1, lp_h = -1;
	static double lp_r1, lp_r2, lp_p;

	*dx = 0;
	*dy = 0;

	// Images must be the same size
	if (ref->width != img->width || ref->height != img->height) {
	   printf("fft_register: Images must be the same dimensions\n");
	   return 0;
	   }

	if (ref->depth != img->depth) {
	   printf("fft_register: images must be the same depth\n");
	   return 0;
	   }

	i=1; while(i < width) i<<=1;
	if (width != i) {
	   printf("Image width must be a power of 2\n");
	   return 0;
	   }

	// find_offset and SwapDiagonal wrap coordinates with (height-1) as a
	// bit mask, which only works for a power of two
	i=1; while(i < height) i<<=1;
	if (height != i) {
	   printf("Image height must be a power of 2\n");
	   return 0;
	   }

	f_ref = do_fft(ref);
	if (! f_ref) {
	   printf("fft_register: fft failed on reference image\n");
	   return 0;
	   }

	Print("fft_register: reference(%s) fft size %dx%dx%d\n",ref->src_fname,width,height,ref->depth);

	f_img = do_fft(img);
	if (! f_img) {
	   printf("fft_register: fft failed on image\n");
	   goto done;
	   }

	Print("fft_register: image(%s) fft size %dx%dx%d\n",img->src_fname,width,height,img->depth);

	if (lpf_r1 != 0) {
	   // The cached mask is only valid for the size and parameters it was
	   // built with; applying a stale one would use the wrong mask or run
	   // past the end of it.
	   if (lp_filter != NULL &&
		(lp_w != width || lp_h != height ||
		 lp_r1 != lpf_r1 || lp_r2 != lpf_r2 || lp_p != lpf_p)) {
	      VFree(lp_filter);
	      lp_filter = NULL;
	      }

	   if (lp_filter==NULL) {
	      lp_filter = make_LowPass(width,height,lpf_r1,lpf_r2,lpf_p);
	      lp_w = width;   lp_h = height;
	      lp_r1 = lpf_r1; lp_r2 = lpf_r2; lp_p = lpf_p;
	      }

	   Print("lowpass fft radius %lf/%lf, power %lf\n",lpf_r1,lpf_r2,lpf_p);
	   apply_filter(f_ref,lp_filter);
	   apply_filter(f_img,lp_filter);
	   }

	tmp = correlate(f_ref,f_img);

	printf("done correlate\n");

	if (! FFT2D(tmp,FFT_REVERSE)) {
	   Print("Reverse FFT failed\n");
	   goto done;
	   }

	find_offset(tmp,region,dx,dy);

	printf("offset %d,%d\n",*dx,*dy);

	// NOTE(review): 'new' is never released. DestroyImage() is not called
	// here because dst_fname is pointed at string literals below and it is
	// not visible in this file whether DestroyImage() frees dst_fname.
	new = CloneImage(ref);
	SwapDiagonal(tmp);
	image_from_complex(tmp,new);
	new->dst_fname = "fft_delta.fit";
	WriteImage(new);

	// visualise the reference image
	FFT2D(f_ref,FFT_REVERSE);
	image_from_complex(f_ref,new);
	new->dst_fname = "fft_ref_lp.fit";
	WriteImage(new);

	rval = 1;

done:
	// release every complex array created above, on success and failure
	if (tmp) destroy_complex(tmp);
	if (f_img) destroy_complex(f_img);
	if (f_ref) destroy_complex(f_ref);

	return(rval);
	}

// Given a complex array, assume that only the real portion contains data
// (ie a real image) and locate the peak around the origin.
//
// Coordinates from -region to +region in both directions are examined;
// negative coordinates wrap around through a bit mask with (width-1) and
// (height-1). The result is the average position of all samples that exceed
// half of the largest sample found, truncated to integers and stored in
// *dx / *dy (0,0 when no sample qualifies). Always returns 1.
static int
find_offset(COMPLEX *c, int region, int *dx, int *dy)
	{
	const int stride = c->width;
	const int xmask = c->width - 1;
	const int ymask = c->height - 1;
	FLOAT peak = 0;
	double sum_x = 0;
	double sum_y = 0;
	int count = 0;
	int x,y;

	// pass 1: largest sample in the search window
	for(y=-region; y<=region; ++y) {
	   const FLOAT *row = c->real + (y & ymask) * stride;
	   for(x=-region; x<=region; ++x) {
		FLOAT sample = row[x & xmask];
		if (sample > peak) peak = sample;
		}
	   }

	// pass 2: accumulate positions of everything above half the peak
	for(y=-region; y<=region; ++y) {
	   const FLOAT *row = c->real + (y & ymask) * stride;
	   for(x=-region; x<=region; ++x) {
		double rel = row[x & xmask] / peak;
		if (rel > 0.5) {
		   sum_x += x;
		   sum_y += y;
		   ++count;
		   }
		}
	   }

	if (count == 0) count = 1;

	*dx = sum_x / count;
	*dy = sum_y / count;

	return 1;
	}

// Fill buf (one unsigned short per element of c) with a picture of the
// log-magnitude of c: log(sqrt(re^2 + im^2)), scaled so that the largest
// value maps to 32700. Values at or below zero (magnitude <= 1, including a
// zero magnitude whose logarithm is -infinity) are stored as 0.
// Always returns 1.
static int
create_magnitude_map(COMPLEX *c, unsigned short *buf)
	{
	int npix = c->width * c->height;
	int o;
	double max = 0;
	double scale;
	FLOAT *real = c->real;
	FLOAT *imag = c->imag;

	// first pass to get the maximum

	for(o=0; o<npix; ++o) {
	   double d = log(sqrt(real[o]*real[o] + imag[o]*imag[o]));
	   if (d>max) max=d;
	   }

	// nothing positive: everything maps to 0, and we must not divide by 0
	scale = (max > 0) ? 32700 / max : 0;

	// second pass: the same quantity as the first pass, scaled
	for(o=0; o<npix; ++o) {
	   double d = log(sqrt(real[o]*real[o] + imag[o]*imag[o])) * scale;

	   // !(d>0) also catches NaN and -infinity, which must not be
	   // converted to an integer type
	   if (! (d > 0)) d = 0;
	   else if (d > 32700) d = 32700;

	   buf[o] = (unsigned short)d;
	   }

	return 1;
	}

// Fill buf (one unsigned short per element of c) with a picture derived from
// the phase of c: log(atan(im / re)), scaled so that the largest value maps
// to 32700. Elements with a zero real part, and elements for which the
// expression is not a positive number, are stored as 0.
// Always returns 1.
static int
create_phase_map(COMPLEX *c, unsigned short *buf)
	{
	int o;
	int npix = c->width * c->height;
	double max = 0;
	double scale;
	FLOAT *real = c->real;
	FLOAT *imag = c->imag;

	// first pass is to get scaling values

	for(o=0; o<npix; ++o) {
	   double d = 0;
	   if (real[o]) d = log(atan(imag[o] / real[o]));

	   // a NaN (log of a negative number) never compares greater
	   if (d>max) max=d;
	   }

	// scale to have brightest close to max; nothing positive means
	// everything maps to 0 and we must not divide by zero
	scale = (max > 0) ? 32700 / max : 0;

	for(o=0; o<npix; ++o) {
	   double d = 0;
	   if (real[o]) d = log(atan(imag[o] / real[o]));

	   d *= scale;

	   // !(d>0) also catches NaN and -infinity, which must not be
	   // converted to an integer type
	   if (! (d > 0)) d = 0;
	   else if (d > 32700) d = 32700;

	   buf[o] = (unsigned short)d;
	   }

	return 1;
	}

/*
 * Swap diagonally opposite quadrants (this is not a matrix transpose), turning
 *			2  1
 *
 *			3  4
 *
 * into
 *
 *			4  3
 *
 *			1  2
 *
 * to put the low frequency vertex at the origin
 */

// Swap diagonally opposite quadrants of a FLOAT array in place.
// Each element (x,y) of the upper half is exchanged with the element at
// ((x + width/2) mod width, (y + height/2) mod height); the wrap is done with
// a bit mask, so both dimensions are expected to be powers of two.
// Returns r.
static FLOAT *
SwapDiagonal_R(FLOAT *r, int r_width, int r_height)
	{
	const int half_w = r_width/2;
	const int half_h = r_height/2;
	int row,col;

	for(row=0; row<half_h; ++row) {
	   FLOAT *upper = r + row * r_width;
	   FLOAT *lower = r + ((row+half_h) & (r_height-1)) * r_width;

	   for(col=0; col<r_width; ++col) {
		int partner = (col+half_w) & (r_width-1);
		FLOAT held = upper[col];

		upper[col] = lower[partner];
		lower[partner] = held;
		}
	   }

	return r;
	}

// Swap diagonally opposite quadrants of a complex array in place, real and
// imaginary planes alike. Each element (x,y) of the upper half is exchanged
// with the element at ((x + width/2) mod width, (y + height/2) mod height);
// the wrap is done with a bit mask, so both dimensions are expected to be
// powers of two.
static void
SwapDiagonal(COMPLEX *c)
	{
	const int cols = c->width;
	const int rows = c->height;
	const int half_w = cols/2;
	const int half_h = rows/2;
	FLOAT *re = c->real;
	FLOAT *im = c->imag;
	int row,col;

	for(row=0; row<half_h; ++row) {
	   const int other_row = (row+half_h) & (rows-1);

	   for(col=0; col<cols; ++col) {
		const int here = row * cols + col;
		const int there = other_row * cols + ((col+half_w) & (cols-1));
		FLOAT held;

		held = re[here];
		re[here] = re[there];
		re[there] = held;

		held = im[here];
		im[here] = im[there];
		im[there] = held;
		}
	   }
	}
	
// Make an ideal (hard edged) highpass filter mask of width x height values:
// 0 for every point whose distance from the centre is <= radius, 1.0
// elsewhere. The mask is built centred and then passed through
// SwapDiagonal_R(), the same layout make_LowPass() produces, so that it can
// be handed to apply_filter() in the same way.
//
// Only square masks are supported; any other shape ends the program.
// Returns the newly allocated mask (release with VFree).
static FLOAT *
make_IdealHighPass(int width, int height, double radius)
	{
	int x,y;
	int cx,cy,o;
	FLOAT *r;

	// check before allocating anything
	if (width != height) {
	   Print("Width/Height (%d,%d) not supported - FFT routines require square arrays!\n",width,height);
	   exit(1);
	   }

	r = VMalloc(width * height * S_FLOAT);

	cx = width/2;
	cy = height/2;

	// visit every element of the mask, with (x,y) relative to the centre
	for(o=0,y=-cy; y<cy; ++y)
	   for(x=-cx; x<cx; ++x,++o) {
		FLOAT d = sqrt(x*x + y*y);
		if (d<=radius) r[o] = 0;
		else r[o] = 1.0;
		}

	return SwapDiagonal_R(r,width,height);
	}

/*
 * Make a lowpass filter mask from three params (r1,r2,p) where
 * r1 = the inner radius, mask value is 1.0 in this region
 * r2 = outer radius, mask value is 0 outside this radius
 * p = power for transition region r1 -> r2
 *   p = 1 means linear transition
 *   p = 2 means squared transition, etc
 */

// Build the lowpass mask described above: width x height values, 1.0 within
// distance r1 of the centre, 0 at distance r2 and beyond, and in between a
// ramp (1 - (d-r1)/(r2-r1)) raised to the power p when p > 1. The centred
// mask is then passed through SwapDiagonal_R() before being returned.
//
// Only square masks are supported; any other shape ends the program.
// Returns the newly allocated mask (release with VFree).
static FLOAT *
make_LowPass(int width, int height, double r1, double r2, double p)
	{
	int x,y,o;
	int cx,cy;
	int npix = width * height;
	double rr = r2-r1;
	FLOAT *map;

	// check before allocating anything
	if (width != height) {
	   Print("Width/Height (%d,%d) not supported - FFT routines require square arrays!\n",width,height);
	   exit(1);
	   }

	map = VMalloc(npix * S_FLOAT);

	cx = width/2;
	cy = height/2;

	for(o=0,y=-cy; y<cy; ++y)
	   for(x=-cx; x<cx; ++x,++o) {
		FLOAT d = sqrt(x*x + y*y);
		if (d<=r1) map[o] = 1.0;
		else if (d>=r2) map[o] = 0;
		else {
		   // only reached when r1 < d < r2, so rr is positive here
		   FLOAT v = 1.0 - (d-r1)/rr;
		   if (v<0) v=0;
		   if (v>1) v=1.0;
		   if (p>1 && v<1) v = pow(v,p);
		   map[o] = v;
		   }
		}

#if 0
	/****************************************************************************
	 * bit-reverse the filter so we don't have to do this in the fft routine    */
	{
	int m;
	// scratch buffer, only needed by this disabled step
	FLOAT *tmp = VMalloc(npix * S_FLOAT);

	// work out m as  width = 2^n
	m=0; while( (1<<m) < width) ++m;

	for(o=0; o<npix; o+=width) do_bit_reverse(map+o,m);

	transpose_matrix(map,tmp,width,height);
	for(o=0; o<npix; o+=width) do_bit_reverse(tmp+o,m);
	transpose_matrix(tmp,map,width,height);

	VFree(tmp);
	}
	/****************************************************************************/
#endif

	return SwapDiagonal_R(map,width,height);
	}

