#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include "fft.h"

/* Size threshold, in elements per dimension, that transpose_matrix() uses to
 * choose between transposing directly and recursing via fast_transpose(). */
#define BREAKEVEN 16

/*
 * Fast 2D matrix transposition from src[] -> dst[].
 *
 * The recursive fast path assumes both dimensions are powers of 2.
 *
 * Recursively decimate into smaller transpositions to improve
 * cache locality.
 */

/* Here's the simple implementation to use as our benchmark:
 * simple_transpose() copies an n x m block (n columns, m rows) of FLOATs.
 */

/* Forward declarations of the file-local helpers defined below. */
static inline void simple_transpose(FLOAT *src, FLOAT *dst, int n, int m,int width);
static void fast_transpose(int lvl, FLOAT *src, FLOAT *dst, int n, int m, int width);
static void print_matrix(FLOAT *f, int n, int m);

/*
 * Straightforward element-by-element transpose.
 *
 * src[] is read as a block n columns wide and m rows tall; dst[] receives the
 * transposed block, m columns wide and n rows tall.  Both blocks use the same
 * row stride of `width` elements, so for a whole matrix this is exact only
 * when the matrix is square.  fast_transpose() also calls it on sub-blocks
 * of a larger matrix.
 *
 * If n or m is zero or negative, nothing is copied.
 */
static inline void
simple_transpose(FLOAT *src, FLOAT *dst, int n, int m, int width)
	{
	int i, j;

	if (n <= 0 || m <= 0) return;

	/* Walk src column by column, filling dst row by row (sequential writes).
	 * Indexing, rather than stepping a pointer by `width`, never forms an
	 * address past the end of src. */
	for (i = 0; i < n; ++i) {
	   FLOAT *d = dst + (size_t)i * width;		/* dst row i */

	   for (j = 0; j < m; ++j)
	      d[j] = src[(size_t)j * width + i];	/* src row j, column i */
	   }
	}

/*
 * Recursive transpose that splits the block into quadrants for better cache
 * locality.
 *
 * Layout contract is the same as simple_transpose(): src[] is n columns by
 * m rows, dst[] receives the m-column by n-row transpose, and both use row
 * stride `width`.  The columns are split into a left part (n/2) and a right
 * part (n - n/2), and the rows into a top part (m/2) and a bottom part
 * (m - m/2).  Each src quadrant is transposed into the mirrored quadrant of
 * dst: src bottom-left goes to dst top-right (column offset m/2), and src
 * top-right goes to dst bottom-left (row offset n/2).  After `lvl` further
 * levels, the quadrants are handed to simple_transpose().
 *
 * Splitting as n/2 and n - n/2 keeps odd sizes from dropping a row or column.
 * Blocks only one element wide or tall cannot be split, so they are
 * transposed directly.  Zero or negative sizes are a no-op.
 */
static void
fast_transpose(int lvl, FLOAT *src, FLOAT *dst, int n, int m, int width)
	{
	int n2 = n >> 1;		/* src columns in the left half */
	int m2 = m >> 1;		/* src rows in the top half */
	int n3 = n - n2;		/* src columns in the right half */
	int m3 = m - m2;		/* src rows in the bottom half */
	FLOAT *src_lo, *dst_lo;

	if (n <= 0 || m <= 0) return;
	if (n2 == 0 || m2 == 0) {	/* too thin to split: no empty quadrants */
	   simple_transpose(src, dst, n, m, width);
	   return;
	   }

	/* size_t arithmetic avoids the int overflow of width*m for large blocks. */
	src_lo = src + (size_t)m2 * width;	/* src row m2: start of bottom half */
	dst_lo = dst + (size_t)n2 * width;	/* dst row n2: where src's right half lands */

	if (lvl <= 0) {
	   simple_transpose(src,         dst,         n2, m2, width);	/* TL -> TL */
	   simple_transpose(src_lo,      dst + m2,    n2, m3, width);	/* BL -> TR */
	   simple_transpose(src + n2,    dst_lo,      n3, m2, width);	/* TR -> BL */
	   simple_transpose(src_lo + n2, dst_lo + m2, n3, m3, width);	/* BR -> BR */
	   }
	else {
	   fast_transpose(lvl-1, src,         dst,         n2, m2, width);
	   fast_transpose(lvl-1, src_lo,      dst + m2,    n2, m3, width);
	   fast_transpose(lvl-1, src + n2,    dst_lo,      n3, m2, width);
	   fast_transpose(lvl-1, src_lo + n2, dst_lo + m2, n3, m3, width);
	   }
	}
	

/*
 * Debug dump of f[] as m rows of n values each.  The data is read as
 * contiguous row-major storage: row length n, no separate stride.
 * Output is one leading newline, then each row followed by a newline, then
 * two trailing newlines.
 * NOTE(review): "%f" assumes FLOAT is float or double (promoted to double);
 * confirm against fft.h.
 */
static void
print_matrix(FLOAT *f, int n, int m)
	{
	int row, col;
	FLOAT *p = f;

	printf("\n");
	for (row = 0; row < m; ++row) {
	   for (col = 0; col < n; ++col)
	      printf("%f ", *p++);
	   printf("\n");
	   }
	printf("\n\n");
	}

/*
 * Cache-blocked transpose for every shape the recursive helpers cannot take.
 * src[] is width columns by height rows (row stride width).  dst[] receives
 * the height-column by width-row transpose (row stride height).  The matrix
 * is walked in BREAKEVEN x BREAKEVEN tiles.
 */
static void
tiled_transpose(const FLOAT *src, FLOAT *dst, int width, int height)
	{
	int bx, by, x, y, xend, yend;

	for (by = 0; by < height; by += BREAKEVEN) {
	   yend = (height - by > BREAKEVEN) ? by + BREAKEVEN : height;
	   for (bx = 0; bx < width; bx += BREAKEVEN) {
	      xend = (width - bx > BREAKEVEN) ? bx + BREAKEVEN : width;
	      for (y = by; y < yend; ++y)
	         for (x = bx; x < xend; ++x)
	            dst[(size_t)x * height + y] = src[(size_t)y * width + x];
	      }
	   }
	}

/*
 * Transpose the row-major matrix src[] (height rows of width FLOATs) into
 * dst[], which receives width rows of height FLOATs.
 *
 * dst[] must hold width*height FLOATs.  src and dst must not overlap, because
 * none of the paths below transposes in place.  If width or height is zero or
 * negative, nothing is copied.
 *
 * simple_transpose() and fast_transpose() use one row stride for both src
 * and dst, which is only valid for square matrices.  fast_transpose() splits
 * by halving and is only relied on here for power-of-2 sizes.  So:
 *   - square and at most BREAKEVEN: simple_transpose()
 *   - square power of 2 above BREAKEVEN: fast_transpose()
 *   - anything else: tiled_transpose(), which uses the correct stride for
 *     each side.
 */
void
transpose_matrix(FLOAT *src, FLOAT *dst, int width, int height)
	{
	int min, l;

	if (width <= 0 || height <= 0) return;

	if (width != height) {
	   tiled_transpose(src, dst, width, height);
	   return;
	   }

	/* Square from here on, so min(width, height) == width. */
	if (width <= BREAKEVEN) {
	   simple_transpose(src, dst, width, height, width);
	   return;
	   }

	if (width & (width - 1)) {	/* not a power of 2 */
	   tiled_transpose(src, dst, width, height);
	   return;
	   }

	/* One recursion level per halving needed to reach BREAKEVEN. */
	for (l = 0, min = width; min > BREAKEVEN; min >>= 1) ++l;
	fast_transpose(l, src, dst, width, height, width);
	}

#ifdef MAIN

/*
 * Self-test driver (compile with -DMAIN).  It fills an N x M matrix with
 * 0, 1, 2, ..., prints it, transposes it with simple_transpose(), prints the
 * result and reports the elapsed wall-clock seconds.
 */

#define N 8
#define M 8

/* simple_transpose() writes d with the same row stride (N) it reads s with,
 * and print_matrix() reads d back as contiguous rows of M values.  The
 * driver is therefore only meaningful for a square test matrix. */
#if N != M
#error "this self-test requires N == M"
#endif

int
main(int argc, char *argv[])
	{
	/* NOTE(review): assumes VMalloc (declared outside this file) returns NULL
	 * on failure.  No matching deallocator is visible here, so the buffers
	 * are left for process exit to reclaim. */
	FLOAT *s = (FLOAT *)VMalloc(N*M*sizeof(FLOAT));
	FLOAT *d = (FLOAT *)VMalloc(N*M*sizeof(FLOAT));
	int i;
	time_t st = time(NULL);

	(void)argc; (void)argv;	/* only referenced by the disabled benchmark */

	if (s == NULL || d == NULL) {
	   fprintf(stderr, "out of memory\n");
	   return 1;
	   }

#if 0
	/* Benchmark: time simple_transpose() against fast_transpose().
	 * lvl 2 splits 8 -> 4 -> 2 -> 1, the deepest level at which an 8 x 8
	 * matrix still has something to split. */
	if (argc>1 && ! strcmp(argv[1],"simple"))
	for(i=0; i<5000; ++i) {
	   simple_transpose(s,d,N,M,N);
	   simple_transpose(d,s,M,N,N);
	   simple_transpose(s,d,N,M,N);
	   simple_transpose(d,s,M,N,N);
	   }
	else
	for(i=0; i<5000; ++i) {
	   fast_transpose(2,s,d,N,M,N);
	   fast_transpose(2,d,s,M,N,M);
	   fast_transpose(2,s,d,N,M,N);
	   fast_transpose(2,d,s,M,N,M);
	   }
#endif

	for(i=0; i<N*M; ++i) s[i]=i;
	print_matrix(s,N,M);
	simple_transpose(s,d,N,M,N);
	print_matrix(d,M,N);

	/* difftime() yields the elapsed seconds as a double, whatever time_t is. */
	printf("time %.0f\n", difftime(time(NULL), st));

	return 0;
	}
#endif
