/*
 * Matrix multiplication using Strassen's method and the traditional method
 *
 * author: En-ran, Zhou
 * date: 2003/09/19
 */

#include <cassert>
#include <iostream>
#include <math.h>

using namespace std;

// Largest supported matrix dimension (each Matrix stores maxsize x maxsize doubles).
static const int maxsize = 32;
// Output precision applied by operator<< when printing matrix elements.
static const int prec = 3;

class Matrix {
public:
	Matrix();
	Matrix(const int&);

	void setvalue(const int&, const int&, const double&);

	Matrix operator= (const Matrix&);
	friend Matrix operator+ (const Matrix&, const Matrix& );
	friend Matrix operator- (const Matrix&, const Matrix& );
	friend Matrix operator* (const Matrix&, const Matrix& ); // traditional
	friend Matrix operator/ (const Matrix&, const Matrix& ); // strassen
	friend ostream& operator<< (ostream& , const Matrix&);
private:
	int size;
	double m[maxsize][maxsize];
	void partition(Matrix& , Matrix& , Matrix& , Matrix& ) const;
	void merge(const Matrix& ,const Matrix& ,const Matrix& ,const Matrix& );
};

// Construct an empty 0 x 0 matrix. Element storage is not initialized.
Matrix::Matrix() : size(0)
{
}

/*
 * Construct an n x n matrix whose active elements are all zero.
 *
 * n must lie in [0, maxsize]: storage is a fixed maxsize x maxsize array,
 * so a larger n would make the zero-fill loop write past its end.
 */
Matrix::Matrix(const int& n)
{
	assert(n >= 0 && n <= maxsize);
	size = n;
	for(int i = 0; i < size; i++) {
		for(int j = 0; j < size; j++) {
			m[i][j] = 0;
		}
	}
}

void Matrix::setvalue(const int& x, const int& y, const double& v)
{
	m[x][y] = v;
}

/*
 * Take over rhs's dimension and copy its active size x size elements.
 * NOTE(review): this returns a copy of *this by value rather than Matrix&,
 * so every assignment also copies the whole fixed-size array one extra time.
 */
Matrix Matrix::operator= (const Matrix& rhs)
{
	size = rhs.size;
	for(int row = 0; row < size; ++row) {
		const double *src = rhs.m[row];
		double *dst = m[row];
		for(int col = 0; col < size; ++col)
			dst[col] = src[col];
	}
	return *this;
}

/*
 * Print one matrix row per line, each element written as "[value]".
 *
 * Output goes to `out` (previously it always went to cout, regardless of
 * the stream passed in). Elements are printed with `prec` as the stream
 * precision. The stream's previous precision is restored afterwards so the
 * caller's formatting is not permanently altered.
 */
ostream& operator<< (ostream& out, const Matrix &a)
{
	const streamsize old_prec = out.precision(prec);
	for(int i = 0; i < a.size; i++) {
		for(int j = 0; j < a.size; j++) {
			out << "[" << a.m[i][j] << "]";
		}
		out << '\n';
	}
	out.precision(old_prec);
	return out;
}

/*
 * Split this matrix into four quadrants of dimension size/2:
 *     [ a11  a12 ]
 *     [ a21  a22 ]
 * For an odd size, the last row and column are not copied into any quadrant.
 */
void Matrix::partition(Matrix &a11, Matrix &a12, Matrix &a21, Matrix &a22) const
{
	const int half = size / 2;
	a22.size = half;
	a21.size = half;
	a12.size = half;
	a11.size = half;

	for(int r = 0; r < half; ++r) {
		const double *top = m[r];
		const double *bottom = m[r + half];
		for(int c = 0; c < half; ++c) {
			a11.m[r][c] = top[c];
			a12.m[r][c] = top[c + half];
			a21.m[r][c] = bottom[c];
			a22.m[r][c] = bottom[c + half];
		}
	}
}

/*
 * Rebuild this matrix from four quadrants of dimension a11.size:
 *     [ a11  a12 ]
 *     [ a21  a22 ]
 * The resulting size is twice the quadrant size. All quadrants are read
 * over a11's dimension.
 */
void Matrix::merge(const Matrix &a11,const Matrix &a12,const Matrix &a21,const Matrix &a22)
{
	const int half = a11.size;
	size = 2 * half;

	for(int r = 0; r < half; ++r) {
		double *top = m[r];
		double *bottom = m[r + half];
		for(int c = 0; c < half; ++c) {
			top[c] = a11.m[r][c];
			top[c + half] = a12.m[r][c];
			bottom[c] = a21.m[r][c];
			bottom[c + half] = a22.m[r][c];
		}
	}
}

/*
 * Element-wise sum a + b over a's size x size region.
 * Assumes b has (at least) a's size; this is not checked.
 */
Matrix operator+ (const Matrix &a, const Matrix &b)
{
	Matrix sum(a);
	for(int row = 0; row < a.size; ++row) {
		double *dst = sum.m[row];
		const double *src = b.m[row];
		for(int col = 0; col < a.size; ++col)
			dst[col] += src[col];
	}
	return sum;
}

/*
 * Element-wise difference a - b over a's size x size region.
 * Assumes b has (at least) a's size; this is not checked.
 */
Matrix operator- (const Matrix &a, const Matrix &b)
{
	Matrix diff(a);
	for(int row = 0; row < a.size; ++row) {
		double *dst = diff.m[row];
		const double *src = b.m[row];
		for(int col = 0; col < a.size; ++col)
			dst[col] -= src[col];
	}
	return diff;
}

/*
 * Strassen multiplication: returns a * b.
 *
 * Both operands are assumed square and of the same size (not checked).
 *
 * An even size is split into quadrants and the result is assembled from
 * seven recursive products m1..m7. partition() can only split an even size
 * without losing the last row and column, and recursing on size 0 would
 * never terminate. So odd sizes (including 1) and empty matrices are
 * multiplied directly instead.
 */
Matrix operator* (const Matrix &a, const Matrix &b)
{
	const int n = a.size;
	Matrix tmp(n);

	if(n <= 0 || n % 2 != 0) {
		// Base case: plain row-by-column product.
		for(int i = 0; i < n; i++) {
			for(int j = 0; j < n; j++) {
				double s = 0;
				for(int k = 0; k < n; k++) {
					s += a.m[i][k] * b.m[k][j];
				}
				tmp.m[i][j] = s;
			}
		}
		return tmp;
	}

	Matrix a11, a12, a21, a22, b11, b12, b21, b22;
	a.partition(a11, a12, a21, a22);
	b.partition(b11, b12, b21, b22);

	const Matrix m1 = (a11 + a22) * (b11 + b22);
	const Matrix m2 = (a21 + a22) * b11;
	const Matrix m3 = a11 * (b12 - b22);
	const Matrix m4 = a22 * (b21 - b11);
	const Matrix m5 = (a11 + a12) * b22;
	const Matrix m6 = (a21 - a11) * (b11 + b12);
	const Matrix m7 = (a12 - a22) * (b21 + b22);

	// C11 = m1+m4-m5+m7, C12 = m3+m5, C21 = m2+m4, C22 = m1+m3-m2+m6
	tmp.merge(m1 + m4 - m5 + m7, m3 + m5, m2 + m4, m1 + m3 - m2 + m6);

	return tmp;
}

/*
 * Traditional triple-loop multiplication, despite the operator's name:
 * each result element is the dot product of a row of a with a column of b.
 * Both operands are assumed square and of a's size (not checked).
 */
Matrix operator/ (const Matrix &a, const Matrix &b)
{
	const int n = a.size;
	Matrix product(n);

	for(int row = 0; row < n; ++row) {
		for(int col = 0; col < n; ++col) {
			double dot = 0;
			for(int k = 0; k < n; ++k)
				dot += a.m[row][k] * b.m[k][col];
			product.m[row][col] = dot;
		}
	}

	return product;
}

/*
 * Demo: multiply two fixed 2 x 2 matrices with both methods and print the
 * results so they can be compared.
 */
int main()
{
	Matrix a(2), b(2), c, d;

	a.setvalue(0, 0, 1.1);
	a.setvalue(0, 1, 1. / 3);
	// acos(-1) is pi; M_PI is not part of standard C++ and is missing on some toolchains.
	a.setvalue(1, 0, acos(-1.));
	a.setvalue(1, 1, sqrt(2.));

	b.setvalue(0, 0, 1. / 7);
	b.setvalue(0, 1, 1. / 8);
	b.setvalue(1, 0, 1. / 9);
	b.setvalue(1, 1, 1. / 10);

	c = a * b;
	d = a / b;

	cout << "Strassen method" << endl;
	cout << c << endl;
	cout << "Traditional method" << endl;
	cout << d << endl;

	return 0;
}

