#include "ninox.h"

// Per-pixel running sums of every stacked frame, one double per pixel.
// Samples are accumulated on a 16-bit scale (stack_frame multiplies 8-bit
// samples by 256). 8/16 bpp stacks use only Rdata; 24 bpp stacks use all three.
static double *Rdata = NULL;  // red (or monochrome) channel sums
static double *Gdata = NULL;  // green channel sums (only filled for 24 bpp)
static double *Bdata = NULL;  // blue channel sums (only filled for 24 bpp)
static int    *Count = NULL;  // per pixel: number of frames that contributed to the sums

static int Frames =0;		// frames stacked since the last stack_init()
static int NPixels=0;		// pixels per frame (Width * Height)
static int Width,Height,Depth;	// geometry of the first stacked frame; later frames must match
static int Bpp=0;		// bytes per pixel of the stacked frames (Depth/8)
static int Dstwidth, Dstheight;	// dst_width/dst_height copied from the first stacked frame

// Reset the stacking state: release the per-pixel accumulators and zero the
// frame/pixel counters, so the next stack_frame() call starts a fresh stack
// (it allocates new buffers when Rdata is NULL).
// Must be called between stacking runs. Always returns 1.
int
stack_init(void)
	{
	// free(NULL) is a no-op, so no guards are needed
	free(Rdata); Rdata = NULL;
	free(Gdata); Gdata = NULL;
	free(Bdata); Bdata = NULL;
	free(Count); Count = NULL;
	Bpp=NPixels=Frames=0;

	// The original fell off the end of an int function (UB if the value is used)
	return 1;
	}

// Add one frame to the running per-pixel sums.
// The first frame after stack_init() fixes the geometry (width, height, depth)
// and allocates the accumulators. Every later frame must match that geometry,
// or the program exits with an error status.
//   8 bpp : sample * 256 added to Rdata
//   16 bpp: sample added to Rdata
//   24 bpp: bytes read as B,G,R; each * 256 added to Bdata/Gdata/Rdata
// For 8 and 24 bpp, a pixel only contributes (and increments its Count entry)
// when it is not exactly black. Every 16 bpp pixel contributes.
// Returns the number of frames stacked so far, or 0 for an unsupported depth.
int
stack_frame(struct Image *img)
	{
	int width = img->width;
	int height = img->height;
	int depth = img->depth;
	unsigned char *data = img->data;
	unsigned short *sdata = (unsigned short *)data;
	int i,b,g,r,o;
	int threshhold = 1;  // anything not exactly black is ok
	double v;

	if (depth != 16 && depth != 8 && depth != 24) {
	   Print("-stack only supported for 8/16/24 bpp data\n");
	   return 0;
	   }

	if (Rdata == NULL) {
	   // First frame, allocate buffer
	   Width = width; Height=height; Depth = depth;
	   Bpp = Depth/8;
	   Dstwidth = img->dst_width;
	   Dstheight = img->dst_height;

	   NPixels = width * height;
	   Rdata = (double *)ZeroMalloc(sizeof(double) * NPixels);
	   Gdata = (double *)ZeroMalloc(sizeof(double) * NPixels);
	   Bdata = (double *)ZeroMalloc(sizeof(double) * NPixels);
	   Count = (int *)   ZeroMalloc(sizeof(int) * NPixels);
	   }
	else if (Depth != depth || Width != width || Height != height) {
	   // Image params changed in the middle of a stack
	   Print("stack_frame: Image data has changed! Now (%dx%dx%d), was (%dx%dx%d).\n",
		width,height,depth,Width,Height,Depth);
	   exit(1);	// fatal error: don't report success to the caller
	   }

	// Stack the frame
	switch(depth) {
	   case 8:
		for(i=0; i<NPixels; ++i) {
		   v = data[i];
		   if (v>=threshhold) {
		      Rdata[i] += v * 256 ; // upgrade 8 -> 16 bpp, use R channel only
		      Count[i]++;
		      }
		   }
		break;
	   case 16:
		// 16bpp data, R channel only. There is no black-level test here:
		// every pixel is accumulated, including pure black ones.
		for(i=0; i<NPixels; ++i) {
		   Rdata[i] += sdata[i];
		   Count[i]++;
		   }
		break;
	   case 24:
		for(i=o=0; i<NPixels; ++i,o+=3) {
		   b = data[o]; g = data[o+1]; r = data[o+2];
		   // luminance test so that only pure black pixels are skipped
		   if (b*0.114 + g*0.587 + r*0.299 >= threshhold) {
		      Bdata[i] += b*256;	// 8bpp data, B channel, upgrade to 16bpp
		      Gdata[i] += g*256;	// 8bpp data, G channel, upgrade to 16bpp
		      Rdata[i] += r*256;	// 8bpp data, R channel, upgrade to 16bpp
		      Count[i]++;
		      }
		   }
		break;
	   default:
		// Unreachable after the depth check above; don't count the frame
		Print("Stack: Depth %d unsupported\n",depth);
		return 0;
	   }
	return ++Frames;
	}

// Fill R (and, for 24 bpp stacks, G and B) with the per-pixel average of the
// stacked frames, scaled to the 32-bit range (16-bit-scale average * 65536).
// Pixels to which no frame contributed are set to 0. G and B are only
// written when Depth is 24, so callers may pass NULL for them otherwise.
// Each buffer must hold NPixels entries.
// Returns 1 on success, or 0 if nothing has been stacked or Depth is unsupported.
static int
fetch_stack_u32(u32 *R, u32 *G, u32 *B)
	{
	int i;

	if (! Frames) {
	   Print("fetch_stack_u32: no frames stacked\n");
	   return 0;
	   }

	Print("Averaging %d frames\n",Frames);

	switch(Depth) {
	   case 24:
		for(i=0; i<NPixels; ++i) {
		   if (Count[i]) {
		      G[i] = Gdata[i]*65536.0 / Count[i];
		      B[i] = Bdata[i]*65536.0 / Count[i];
		      }
		   else G[i] = B[i] = 0;
		   }
		/* fallthrough: 24 bpp stacks need the red channel as well */
	   case 8:
	   case 16:
		for(i=0; i<NPixels; ++i)
		   R[i] = Count[i] ? Rdata[i]*65536.0 / Count[i] : 0;
		break;
	   default:
		Print("fetch_stack_u32: Depth %d unsupported\n",Depth);
		return 0;
	   }

	return 1;
	}

// Fill R (and, for 24 bpp stacks, G and B) with the per-pixel average of the
// stacked frames on a 16-bit scale. Averages above 65535 are clamped, with a
// warning. Pixels to which no frame contributed are set to 0. G and B are only
// written when Depth is 24. The running sums are left untouched, so this may
// be called more than once.
// Returns 1 on success, or 0 if nothing has been stacked or Depth is unsupported.
static int
fetch_stack_u16(unsigned short *R, unsigned short *G, unsigned short *B)
	{
	int i;
	unsigned int v;

	if (! Frames) {
	   Print("fetch_stack_u16: no frames stacked\n");
	   return 0;
	   }

	Print("Averaging %d frames\n",Frames);

	switch(Depth) {
	   case 24:
		for(i=0; i<NPixels; ++i) if (Count[i]) {
		   v = Gdata[i] / Count[i];
		   if (v>65535) { v=65535; Print("Warning: truncated data\n"); }
		   G[i] = v;

		   // Plain division: the former "/=" overwrote the blue running sum
		   v = Bdata[i] / Count[i];
		   if (v>65535) { v=65535; Print("Warning: truncated data\n"); }
		   B[i] = v;
		   }
		else { G[i] = B[i] = 0; }
		/* fallthrough: 24 bpp stacks need the red channel as well */
	   case 8:
	   case 16:
		for(i=0; i<NPixels; ++i) if (Count[i]){
		   v = Rdata[i] / Count[i];
		   if (v>65535) { v=65535; Print("Warning: truncated data\n"); }
		   R[i] = v;
		   }
		else { R[i] = 0; }
		break;
	   default:
		Print("fetch_stack_u16: Depth %d unsupported\n",Depth);
		return 0;
	   }

	return 1;
	}

int
write_stack_file(char *fname)
	{
	FILE *out;
	double bscale = 1.0;
	int bzero = 0;
	int i,x;
	u32 *R,*G,*B,*ptr;
	struct Image *img;
	char new_fname[1024],*p;

	// Allocate the buffer
	R = (u32 *) Malloc(Width * Height * 4);
	G = (u32 *) Malloc(Width * Height * 4);
	B = (u32 *) Malloc(Width * Height * 4);

	// Fetch the image
	fetch_stack_u32(R,G,B);
	
	strcpy(new_fname,fname);
	p = new_fname + strlen(new_fname);
	while(p != new_fname && *p != '.') --p;

	// If we can't find a suffix then set up to append
	if (p == new_fname) p = new_fname + strlen(new_fname);

	switch(Depth) {
	   case 8:
	   case 16:
		// Monochrome (R channel) data only. Write out as a 32BPP FITS floating point
		strcpy(p,".fit");
		img = CreateFitsImage(new_fname,Width,Height,32,(unsigned char *)R,bzero,bscale);
		if (!img) {
	   	   Print("write_stack_file: Error creating output img '%s'\n",fname);
	   	   exit(1);
	   	   }
		img->dst_width = Dstwidth; img->dst_height = Dstheight;

		// Convert from 32bpp unsigned to 32bpp floating point if requested
		if (OutputFileDepth == -32)
		   img = ConvertToType(img, IMG_FIT, -32);

		if (! img) {
		   Print("write_stack_file: error converting from U32 to F32\n");
		   exit(1);
		   }

        	if (! WriteImage(img)) {
           	   Print("Short write on output to %s\n",fname);
           	   exit(1);
           	   }
		Print("created stack file %dx%dx%d\n",img->dst_width,img->dst_height,img->depth);
		DestroyImage(img);
		break;
	   case 24:
		// Create separate R,G,B 16bpp FITS
		strcpy(p,"_R.fit");
		img = CreateFitsImage(new_fname,Width,Height,16,(unsigned char *)R,bzero,bscale);
		if (!img) {
	   	   Print("write_stack_file: Error creating output img '%s'\n",new_fname);
	   	   exit(1);
	   	   }
		img->dst_width = Dstwidth; img->dst_height = Dstheight;
        	if (! WriteImage(img)) {
           	   Print("Short write on output to %s\n",new_fname);
           	   exit(1);
           	   }
		DestroyImage(img);  // also frees R

		strcpy(p,"_G.fit");
		img = CreateFitsImage(new_fname,Width,Height,16,(unsigned char *)G,bzero,bscale);
		if (!img) {
	   	   Print("write_stack_file: Error creating output img '%s'\n",new_fname);
	   	   exit(1);
	   	   }
		img->dst_width = Dstwidth; img->dst_height = Dstheight;
        	if (! WriteImage(img)) {
           	   Print("Short write on output to %s\n",new_fname);
           	   exit(1);
           	   }
		DestroyImage(img); // also frees G

		strcpy(p,"_B.fit");
		img = CreateFitsImage(new_fname,Width,Height,16,(unsigned char *)B,bzero,bscale);
		if (!img) {
	   	   Print("write_stack_file: Error creating output img '%s'\n",new_fname);
	   	   exit(1);
	   	   }
		img->dst_width = Dstwidth; img->dst_height = Dstheight;
        	if (! WriteImage(img)) {
           	   Print("Short write on output to %s\n",new_fname);
           	   exit(1);
           	   }
		DestroyImage(img); // also frees B

		break;
	      }

	return(1);
	}

//====================================================================================

// Load (and cache) the reference image used by merge_data().
// The most recently loaded image is kept in a static variable. Asking for the
// same file name again returns the cached image. Asking for a different name
// replaces it. Only 8 and 16 bpp references are accepted.
// On success, m.avg is set to the mean of the pixels at or above the background
// level (ThreshHold, or ThreshHold<<8 for 16 bpp). If no pixel qualifies,
// m.avg is 0 and a warning is printed.
// Returns the image, or NULL if it cannot be loaded or has an unsupported depth.
static struct Image *
load_mergefile(char *fname)
	{
	static struct Image *MImg = NULL;
	double avg = 0.0;	// was used uninitialised
	int i,count=0;
	int minval8 = ThreshHold;
	int minval16 = minval8<<8 ; // don't process the background

	// Image already loaded
	if (MImg) {
	   if (!strcmp(MImg->src_fname,fname)) return MImg;
	   Print("Warning: load_mergefile: changing reference image\n");
	   DestroyImage(MImg);
	   MImg = NULL;
	   }

	MImg = LoadImage(fname,"<internal>");

	if (!MImg) {
	   Print("load_mergefile: cannot open '%s' for reading\n",fname);
	   return NULL;
	   }

	if (MImg->depth != 8 && MImg->depth != 16) {
	   Print("load_mergefile: Unsupported depth: %d\n",MImg->depth);
	   DestroyImage(MImg);
	   MImg = NULL;	// otherwise the next call would use (and free) a destroyed image
	   return NULL;
	   }

	i = MImg->width * MImg->height;

	if (MImg->depth == 8) {
	   unsigned char *ptr = MImg->data;
	   while(i--) {
	      if (*ptr >= minval8) {avg += *ptr; ++count;}
	      ++ptr;
	      }
	   }
	else {
	   unsigned short *uptr = (unsigned short *)MImg->data;
	   while(i--) {
	      if (*uptr >= minval16) {avg += *uptr; ++count;}
	      ++uptr;
	      }
	   }

	if (count == 0) {
	   // Avoid a division by zero. With m.avg == 0 the merge scale is 0,
	   // so no frame pixel will pass the background test.
	   Print("Warning: load_mergefile: no pixels above background in '%s'\n",fname);
	   MImg->m.avg = 0.0;
	   }
	else
	   MImg->m.avg = avg / count;

	Print("[LoadMerge%d avg=%lf] ",MImg->depth,MImg->m.avg);

	return MImg;
	}

#define MERGE_EQ(a,b) (((a) + ((a)>>1) + ((b)>>1))>>1)

// Merge a 16 bpp frame toward a 16 bpp reference image, in place.
// First the frame is brightness-matched to the reference. scale is the ratio
// of the reference's r->m.avg to this frame's mean over pixels >= ThreshHold<<8.
// For each pixel whose scaled value s passes the background test and whose
// reference pixel is non-zero, d = (s - ref) / ref is computed. If
// |d| >= MergeThreshHold, the pixel is replaced by MERGE_EQ(s, ref), clamped to
// 65535. All other pixels are left unchanged (unscaled).
// Exits if fewer than MinPixels frame pixels are above the background.
static void
merge_16_16(unsigned short *src, struct Image *r, int npix)
	{
	unsigned short *ref = (unsigned short *)r->data;
	int i,count=0,high_count=0,low_count=0;
	int minval = ThreshHold << 8;
	double avg=0.0,scale;

	// calculate average brightness for scaling purposes
	for(i=0; i<npix; ++i)
	   if (src[i] >= minval) {avg += src[i]; ++count;}

	// count == 0 or avg == 0 would turn the scale below into a division by zero
	if (count < MinPixels || count == 0 || avg <= 0.0) {
	   Print("merge_data: found no significant pixels\n");
	   exit(1);
	   }

	avg /= (double) count;
	scale = r->m.avg/avg;

	for(i=0; i<npix; ++i,++src,++ref) {
	   int s = (double)*src * scale + 0.5;
	   int m;
	   double d;

	   if (s < minval || ! *ref) continue;

	   d = (double)(s - *ref) / *ref;
	   if (d >= MergeThreshHold) ++high_count;
	   else if (d <= -MergeThreshHold) ++low_count;
	   else continue;

	   // s was rescaled and may exceed 16 bits: clamp rather than let the store wrap
	   m = MERGE_EQ(s,*ref);
	   *src = m > 65535 ? 65535 : m;
	   }

	Print("%d/%d pixels] ",high_count,low_count);
	return;
	}

// Merge an 8 bpp frame toward an 8 bpp reference image, in place.
// First the frame is brightness-matched to the reference. scale is the ratio
// of the reference's r->m.avg to this frame's mean over pixels >= ThreshHold.
// For each pixel whose scaled value s passes the background test and whose
// reference pixel is non-zero, d = (s - ref) / ref is computed. If
// |d| >= MergeThreshHold, the pixel is replaced by MERGE_EQ(s, ref), clamped to
// 255. All other pixels are left unchanged (unscaled).
// Exits if fewer than MinPixels frame pixels are above the background.
static void
merge_8_8(unsigned char *src, struct Image *r, int npix)
	{
	unsigned char *ref = r->data;
	int i,count=0,high_count=0,low_count=0;
	int minval = ThreshHold;
	double avg=0.0,scale;

	// calculate average brightness for scaling purposes
	for(i=0; i<npix; ++i)
	   if (src[i] >= minval) {avg += src[i]; ++count;}

	// count == 0 or avg == 0 would turn the scale below into a division by zero
	if (count < MinPixels || count == 0 || avg <= 0.0) {
	   Print("merge_data: found no significant pixels\n");
	   exit(1);
	   }

	avg /= (double) count;
	scale = r->m.avg/avg;

	for(i=0; i<npix; ++i,++src,++ref) {
	   int s = (double)*src * scale + 0.5;
	   int m;
	   double d;

	   if (s < minval || ! *ref) continue;

	   d = (double)(s - *ref) / *ref;
	   if (d >= MergeThreshHold) ++high_count;
	   else if (d <= -MergeThreshHold) ++low_count;
	   else continue;

	   // s was rescaled and may exceed 8 bits: clamp rather than let the store wrap
	   m = MERGE_EQ(s,*ref);
	   *src = m > 255 ? 255 : m;
	   }

	Print("%d/%d pixels] ",high_count,low_count);
	return;
	}

// Merge a 16 bpp frame toward an 8 bpp reference image, in place.
// The frame mean is taken on the 8-bit scale (src >> 8) so that it is
// comparable with the reference's r->m.avg. The resulting scale is applied to
// the 16-bit samples, and reference pixels are promoted to 16 bits (<< 8)
// before comparing. For each pixel whose scaled value s passes the background
// test (ThreshHold<<8) and whose reference pixel is non-zero,
// d = (s - ref16) / ref16 is computed. If |d| >= MergeThreshHold, the pixel is
// replaced by MERGE_EQ(s, ref16), clamped to 65535. All other pixels are left
// unchanged (unscaled).
// Exits if fewer than MinPixels frame pixels are above the background.
static void
merge_16_8(unsigned short *src, struct Image *r, int npix)
	{
	unsigned char *ref = r->data;
	int i,count=0,high_count=0,low_count=0;
	int minval = ThreshHold << 8;
	double avg=0.0,scale;

	// calculate average brightness (on the 8-bit scale) for scaling purposes
	for(i=0; i<npix; ++i)
	   if (src[i] >= minval) {avg += src[i]>>8; ++count;}

	// count == 0 or avg == 0 would turn the scale below into a division by zero
	if (count < MinPixels || count == 0 || avg <= 0.0) {
	   Print("merge_data: found no significant pixels\n");
	   exit(1);
	   }

	avg /= (double) count;
	scale = r->m.avg/avg;

	for(i=0; i<npix; ++i,++src,++ref) {
	   int s = (double)*src * scale + 0.5;
	   int rv = *ref << 8;	// reference pixel on the 16-bit scale
	   int m;
	   double d;

	   if (s < minval || ! rv) continue;

	   // Floating-point division: integer division would truncate |d| < 1 to 0
	   d = (double)(s - rv) / rv;
	   if (d >= MergeThreshHold) ++high_count;
	   else if (d <= -MergeThreshHold) ++low_count;
	   else continue;

	   // s was rescaled and may exceed 16 bits: clamp rather than let the store wrap
	   m = MERGE_EQ(s,rv);
	   *src = m > 65535 ? 65535 : m;
	   }

	Print("avg=%lf scale=%lf %d/%d pixels] ",avg,scale,high_count,low_count);
	return;
	}

// Merge an 8 bpp frame toward a 16 bpp reference image, in place.
// The frame mean is taken on the 16-bit scale (src << 8) so that it is
// comparable with the reference's r->m.avg. The resulting scale is applied to
// the 8-bit samples, and reference pixels are reduced to 8 bits (>> 8) before
// comparing. For each pixel whose scaled value s passes the background test
// (ThreshHold) and whose reduced reference pixel is non-zero,
// d = (s - ref8) / ref8 is computed. If |d| >= MergeThreshHold, the pixel is
// replaced by MERGE_EQ(s, ref8), clamped to 255. All other pixels are left
// unchanged (unscaled).
// Exits if fewer than MinPixels frame pixels are above the background.
static void
merge_8_16(unsigned char *src, struct Image *r, int npix)
	{
	unsigned short *ref = (unsigned short *)r->data;
	int i,count=0,high_count=0,low_count=0;
	int minval = ThreshHold;
	double avg=0.0,scale;

	// calculate average brightness (on the 16-bit scale) for scaling purposes
	for(i=0; i<npix; ++i)
	   if (src[i] >= minval) {avg += src[i]<<8; ++count;}

	// count == 0 or avg == 0 would turn the scale below into a division by zero
	if (count < MinPixels || count == 0 || avg <= 0.0) {
	   Print("merge_data: found no significant pixels\n");
	   exit(1);
	   }

	avg /= (double) count;
	scale = r->m.avg/avg;

	for(i=0; i<npix; ++i,++src,++ref) {
	   int s = (double)*src * scale + 0.5;
	   int rv = *ref >> 8;	// reference pixel on the 8-bit scale
	   int m;
	   double d;

	   if (s < minval || ! rv) continue;

	   // Floating-point division: integer division would truncate |d| < 1 to 0
	   d = (double)(s - rv) / rv;
	   if (d >= MergeThreshHold) ++high_count;
	   else if (d <= -MergeThreshHold) ++low_count;
	   else continue;

	   // s was rescaled and may exceed 8 bits: clamp rather than let the store wrap
	   m = MERGE_EQ(s,rv);
	   *src = m > 255 ? 255 : m;
	   }

	Print("%d/%d pixels] ",high_count,low_count);
	return;
	}

// Merge img (8 or 16 bpp) in place toward the reference image in mfile,
// choosing the merge_* routine that matches the two depths. The reference is
// loaded once and cached by load_mergefile.
// Returns 1 when the frame was merged. Returns 0 if the reference cannot be
// used, its size differs from the frame, or the frame depth is not 8/16 bpp.
int
merge_data(char *mfile, struct Image *img)
	{
	int width = img->width;
	int height = img->height;
	int depth = img->depth;
	unsigned char *data = img->data;
	struct Image *ref;
	int npix;

	ref = load_mergefile(mfile);

	// Report a load failure as such, not as a size mismatch
	if (!ref) {
	   Print("merge_data: cannot use mergefile '%s'\n",mfile);
	   return 0;
	   }

	if (width != ref->width || height != ref->height) {
	   Print("merge_data: Frame size does not match mergefile\n");
	   return 0;
	   }

	// Previously other depths fell through and reported success without merging
	if (depth != 8 && depth != 16) {
	   Print("merge_data: unsupported frame depth %d\n",depth);
	   return 0;
	   }

	npix = width * height;
	Print("[merge %d+%d, ",depth,ref->depth);

	// load_mergefile only returns 8 or 16 bpp references
	if (depth==16) {
	   if (ref->depth==16)     merge_16_16((unsigned short *)data, ref,npix);
	   else if (ref->depth==8) merge_16_8((unsigned short *)data, ref,npix);
	   }
	else {
	   if (ref->depth==16)     merge_8_16(data, ref,npix);
	   else if (ref->depth==8) merge_8_8(data, ref,npix);
	   }

	return(1);
	}
