#include <sys/types.h>
#include <sys/stat.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include "mpi.h"

/* Larger / smaller of two values.  Every argument is parenthesised so that
   arguments containing low-precedence operators (&, |, ?:, ...) expand
   correctly.  Each argument may be evaluated twice: do not pass expressions
   with side effects. */
#define MAX( a, b ) ( (a) > (b) ? (a) : (b) )
#define MIN( a, b ) ( (a) < (b) ? (a) : (b) )

/* Preferred width of one column block of the LCS table (main uses it to cut
   the second input into columns). */
#define BLOCK 1024

/* Two scratch rows (allocated and freed in main) set aside for the block LCS
   computation; a global variable is used so that no malloc is needed on every
   call. */
int *buf[2];

/*
 * Read an entire file into memory.
 *
 *   filename  path of the file to read
 *   s         out: receives a malloc'ed buffer holding the file's bytes (the
 *             buffer is NOT NUL-terminated).  The caller owns it and must
 *             free() it.  Set to NULL whenever the call fails.
 *
 * Returns the number of bytes stored in *s, or -1 on failure (NULL argument,
 * open/fstat/read error, out of memory, or a file whose size does not fit in
 * an int).  I/O and allocation failures are reported on stderr.
 */
int readfile( char *filename, char **s )
{
	int fd;
	struct stat st;
	char *data;
	size_t want, got;
	ssize_t n;

	if( s == NULL ) {
		return -1;
	}
	*s = NULL;
	if( filename == NULL ) {
		return -1;
	}

	/* Unbuffered open/read is used on purpose: the file is slurped in one
	   go, so stdio buffering would only add a copy. */
	fd = open( filename, O_RDONLY );
	if( fd < 0 ) {
		perror( filename );
		return -1;
	}

	if( fstat( fd, &st ) != 0 ) {
		perror( filename );
		close( fd );
		return -1;
	}

	/* The length is returned as an int, so refuse anything larger. */
	if( st.st_size < 0 || st.st_size > INT_MAX ) {
		fprintf( stderr, "%s: file too large\n", filename );
		close( fd );
		return -1;
	}
	want = (size_t)st.st_size;

	/* malloc(0) may legally return NULL; always ask for at least one byte
	   so that NULL unambiguously means failure. */
	data = malloc( want > 0 ? want : 1 );
	if( data == NULL ) {
		perror( "malloc" );
		close( fd );
		return -1;
	}

	/* read() may return fewer bytes than requested; keep going until the
	   whole file has arrived or end-of-file is hit early. */
	got = 0;
	while( got < want ) {
		n = read( fd, data + got, want - got );
		if( n < 0 ) {
			if( errno == EINTR ) {
				continue;
			}
			perror( filename );
			free( data );
			close( fd );
			return -1;
		}
		if( n == 0 ) {
			/* file shrank after fstat(): report what was actually read */
			break;
		}
		got += (size_t)n;
	}

	close( fd );
	*s = data;
	return (int)got;
}

/*
 * Generalised LCS kernel: fills one rectangular block of the LCS table, given
 * the table values that border the block on its left and on its top.
 *
 *   len1, s1   number of rows of the block and the characters labelling them
 *   len2, s2   number of columns of the block and the characters labelling them
 *   left       in:  left[i] is the table value just left of block row i
 *   top        in:  top[j] is the table value just above block column j
 *   right      out: right[i] is the value in the last column of block row i
 *   buttom     out: buttom[j] is the value in the last row of block column j
 *
 * Returns the value of the block's bottom-right cell.
 *
 * The table value diagonally above-left of the block (the "corner") is not
 * passed in and is taken to be 0.  That is exact for blocks touching the first
 * row or first column of the full table; for any other block the caller has to
 * express left[] and top[] relative to the real corner value and add it back
 * to the outputs.
 *
 * An empty dimension is handled as a pass-through: with len1 == 0 the top row
 * is copied to buttom, with len2 == 0 each left[i] is copied to right[i], and
 * the return value is the last boundary element (0 when both are empty).
 *
 * Only one row is kept in memory: buttom doubles as the rolling row, so no
 * scratch storage is needed besides the output arrays.
 */
int lcs( int len1, int *left, char *s1, int *right, int len2, int *top, char *s2, int *buttom )
{
	int i, j;
	int diag, west, north, cell;

	/* Start from the row above the block; memmove tolerates top == buttom. */
	if( len2 > 0 ) {
		memmove( buttom, top, (size_t)len2 * sizeof *buttom );
	}

	for( i = 0; i < len1; i++ ) {
		/* Values in the column just left of the block: `diag` sits one
		   row up (the corner, assumed 0, for the first row), `west` on
		   the current row. */
		diag = ( i == 0 ) ? 0 : left[i - 1];
		west = left[i];

		for( j = 0; j < len2; j++ ) {
			/* buttom[j] still holds the previous row's value */
			north = buttom[j];

			if( s1[i] == s2[j] ) {
				cell = diag + 1;
			} else {
				cell = ( west > north ) ? west : north;
			}

			buttom[j] = cell;
			diag = north;	/* becomes the diagonal of column j + 1 */
			west = cell;
		}

		/* Value handed to the block on the right (left[i] if len2 == 0). */
		right[i] = west;
	}

	if( len2 > 0 ) {
		return buttom[len2 - 1];
	}
	if( len1 > 0 ) {
		return left[len1 - 1];
	}
	return 0;
}

/* Integer quotient x / y, bumped by one whenever the division leaves a
   remainder.  For non-negative x and positive y this is x / y rounded up. */
int ceil( int x, int y )
{
	int quotient = x / y;

	return ( x % y != 0 ) ? quotient + 1 : quotient;
}

/* Allocate room for `count` ints (at least one, so that an empty block never
   produces a NULL pointer).  Aborts the whole MPI job when memory runs out. */
static int *alloc_ints( int count )
{
	size_t slots = count > 0 ? (size_t)count : 1;
	int *p = malloc( slots * sizeof *p );

	if( p == NULL ) {
		fprintf( stderr, "out of memory\n" );
		MPI_Abort( MPI_COMM_WORLD, 1 );
	}
	return p;
}

/* Number of elements of a `total`-long sequence that fall into piece number
   `index` when the sequence is cut into pieces of `block` elements: `block`
   for interior pieces, the remainder for the last non-empty piece, and 0 for
   pieces that start at or past the end. */
static int block_extent( int total, int index, int block )
{
	long long remaining = (long long)total - (long long)index * block;

	if( remaining <= 0 ) {
		return 0;
	}
	return remaining < block ? (int)remaining : block;
}

/*
 * Parallel LCS length of two files.
 *
 * Usage: <program> file1 file2
 *
 * Rank 0 reads both files and broadcasts them.  The LCS table is cut into
 * m x n blocks (m = number of ranks); rank r computes block row r from left
 * to right, receiving the row above each block from rank r - 1 and sending
 * the block's last row to rank r + 1.  The blocking receive makes the ranks
 * advance as a diagonal wavefront.  The last rank prints "Answer: <length>".
 */
int main( int argc, char *argv[])
{
	int tasks, iam;
	int len1 = 0, len2 = 0, block1, block2, blen1, blen2;
	int result = 0, answer = -1;
	int ii, jj, k, m, n;
	int corner, next_corner, swap_len;
	char *s1 = NULL, *s2 = NULL, *swap_s;
	int *left, *top, *right, *buttom;
	MPI_Status mpist;

	MPI_Init(&argc, &argv);
	MPI_Comm_size(MPI_COMM_WORLD, &tasks);
	MPI_Comm_rank(MPI_COMM_WORLD, &iam);

	if( argc < 3 ) {
		if( iam == 0 ) {
			fprintf( stderr, "usage: %s file1 file2\n", argc > 0 ? argv[0] : "lcs" );
		}
		MPI_Finalize();
		return 1;
	}

	if( iam == 0 ) {
		/* Rank 0 reads the files: lenX is the length, sX the content. */
		len1 = readfile( argv[1], &s1 );
		len2 = readfile( argv[2], &s2 );
		if( len1 < 0 || len2 < 0 ) {
			fprintf( stderr, "cannot read input files\n" );
			MPI_Abort( MPI_COMM_WORLD, 1 );
		}

		/* Make sure s1 is never longer than s2. */
		if( len2 < len1 ) {
			swap_len = len1;
			swap_s = s1;
			len1 = len2;
			s1 = s2;
			len2 = swap_len;
			s2 = swap_s;
		}
	}

	/* Everybody learns the lengths first, then the contents. */
	MPI_Bcast( &len1, 1, MPI_INT, 0, MPI_COMM_WORLD );
	MPI_Bcast( &len2, 1, MPI_INT, 0, MPI_COMM_WORLD );

	if( iam != 0 ) {
		s1 = malloc( len1 > 0 ? (size_t)len1 : 1 );
		s2 = malloc( len2 > 0 ? (size_t)len2 : 1 );
		if( s1 == NULL || s2 == NULL ) {
			fprintf( stderr, "out of memory\n" );
			MPI_Abort( MPI_COMM_WORLD, 1 );
		}
	}

	MPI_Bcast( s1, len1, MPI_CHAR, 0, MPI_COMM_WORLD );
	MPI_Bcast( s2, len2, MPI_CHAR, 0, MPI_COMM_WORLD );

	/* One block row per process. */
	m = tasks;
	block1 = ceil( len1, m);

	/* One block column per BLOCK characters, but never fewer columns than
	   rows. */
	n = ceil( len2, BLOCK);
	if( n < m ) {
		n = m;
		block2 = ceil( len2, n);
	} else {
		block2 = BLOCK;
	}

	left = alloc_ints( block1 );
	right = alloc_ints( block1 );
	top = alloc_ints( block2 );
	buttom = alloc_ints( block2 );

	/* File-scope scratch rows, one block row wide each. */
	buf[0] = alloc_ints( block2 );
	buf[1] = alloc_ints( block2 );

	/* Row ii of the block grid belongs to rank ii.  Rounding block1 up can
	   leave the trailing rows short or even empty. */
	ii = iam;
	blen1 = block_extent( len1, ii, block1 );

	/* Table value diagonally above-left of the current block.  It is 0 in
	   front of the first column and afterwards equals the last element of
	   the previous block's top row. */
	corner = 0;

	for( jj = 0; jj < n; jj++ ) {
		blen2 = block_extent( len2, jj, block2 );

		/* Row above the block: zeros for the first block row, otherwise
		   the last row computed by the rank above.  A zero-length
		   message is still exchanged so sends and receives pair up. */
		if( ii == 0 ) {
			memset( top, 0, (size_t)blen2 * sizeof *top );
		} else {
			MPI_Recv( top, blen2, MPI_INT, ii - 1, 0, MPI_COMM_WORLD, &mpist );
		}
		next_corner = ( blen2 > 0 ) ? top[blen2 - 1] : corner;

		/* Column left of the block: zeros for the first block column,
		   otherwise the last column of the block this rank just did. */
		if( jj == 0 ) {
			memset( left, 0, (size_t)blen1 * sizeof *left );
		} else {
			memcpy( left, right, (size_t)blen1 * sizeof *left );
		}

		if( blen1 > 0 && blen2 > 0 ) {
			/* lcs() takes the corner value to be 0.  The recurrence
			   is unchanged when a constant is added to every boundary
			   value, so shift the boundaries down by the real corner
			   value and shift the outputs back up afterwards. */
			for( k = 0; k < blen2; k++ ) {
				top[k] -= corner;
			}
			for( k = 0; k < blen1; k++ ) {
				left[k] -= corner;
			}

			result = corner + lcs( blen1, left, s1 + ii * block1, right,
					blen2, top, s2 + jj * block2, buttom );

			for( k = 0; k < blen1; k++ ) {
				right[k] += corner;
			}
			for( k = 0; k < blen2; k++ ) {
				buttom[k] += corner;
			}
		} else if( blen1 > 0 ) {
			/* No columns: the left boundary is also the right one. */
			memcpy( right, left, (size_t)blen1 * sizeof *right );
			result = left[blen1 - 1];
		} else if( blen2 > 0 ) {
			/* No rows: the top boundary is also the bottom one. */
			memcpy( buttom, top, (size_t)blen2 * sizeof *buttom );
			result = top[blen2 - 1];
		} else {
			result = corner;
		}

		/* Hand the block's last row to the rank below. */
		if( ii != m - 1 ) {
			MPI_Send( buttom, blen2, MPI_INT, ii + 1, 0, MPI_COMM_WORLD);
		}

		corner = next_corner;
	}

	/* The bottom-right block of the grid holds the final answer. */
	if( ii == m - 1 ) {
		answer = result;
		printf( "Answer: %d\n", answer );
	}

	free(left);
	free(right);
	free(top);
	free(buttom);
	free(buf[0]);
	free(buf[1]);
	free(s1);
	free(s2);

	MPI_Finalize();

	return 0;
}

