#pragma once
#include <iostream>
#include <iomanip>
#include <fstream>
#include <string>
#include <stdio.h>
#include <stdlib.h>
#include <vector>
#include "observation.h"

using namespace std;

class MM
{
private:
#pragma region members
	int N; //Number of states
	int M; //Number of observation symbols
	int T; // length of observation sequences
	string *states; //Array of possible states in HMM
	string *strings; //Array of possible words in HMM
	double **A; //a matrix for the forward procedure
	double **B; //b matrix for the forward procedure
	double *PI; //pi vector for the forward procedure
	vector<Observation> observations;
#pragma endregion members

public:
	MM(){};
	~MM();

#pragma region dynamic Array allocator
	// This code from http://www.codeproject.com/KB/cpp/arrayDinamic.aspx
	// Initializes a multidimensional array
	template <typename T> 
	T **AllocateDynamicArray( int nRows, int nCols)
	{
	      T **dynamicArray;

	      dynamicArray = new T*[nRows];
	      for( int i = 0 ; i < nRows ; i++ )
	      dynamicArray[i] = new T [nCols];

	      return dynamicArray;
	}
#pragma endregion dynamic Array allocator

#pragma region Initializers
	void initStates(int n){states = new string[n];}
	void initStrings(int m){strings = new string[m];}
	void initA(int n){A = AllocateDynamicArray<double>(n,n);}
	void initB(int n, int m){B = AllocateDynamicArray<double>(n,m);}
	void initPI(int n){PI = new double[n];}
#pragma endregion Initializers

#pragma region Set Methods
	void setN(int num){N = num;}
	void setM(int num){M = num;}
	void setT(int num){T = num;}
	void setStateAt(int index, string pos){states[index] = pos;} //pos ~ part of speech
	void setStringAt(int index, string word){strings[index] = word;}
	void setA(int row, int col, double num){A[row][col] = num;}
	void setB(int row, int col, double num){B[row][col] = num;}
	void setPI(int index, double num){PI[index] = num;}
#pragma endregion Set Methods

#pragma region Get Methods
	int getN(){return N;}
	int getM(){return M;}
	int getT(){return T;}
	string getStateAt(int index){return states[index];}
	string getStringAt(int index){return strings[index];}
	double getAindex(int row, int col){return A[row][col];}
	double getBindex(int row, int col){return B[row][col];}
	double getPIindex(int index){return PI[index];}
#pragma endregion Get Methods

#pragma region Print Methods
	void printN(){cout << "N: " << this->getN() << endl << endl;}
	void printM(){cout << "M: " << this->getM() << endl << endl;}
	void printT(){cout << "T: " << this->getT() << endl << endl;}
	void printStates()
	{
		cout << "States: ";
		for(int ii = 0; ii < this->getN(); ii ++)
		{
			cout << this->getStateAt(ii) << " ";
		}
		cout << endl << endl;
	}
	void printStrings()
	{
		cout << "Strings: ";
		for(int ii = 0; ii < this->getM(); ii ++)
		{
			cout << this->getStringAt(ii) << " ";
		}
		cout << endl << endl;
	}
	void printA()
	{
		cout << "A: " << endl;
		for(int row = 0; row < this->getN(); row++)
		{
			for(int col = 0; col < this->getN(); col++)
			{
				cout << right << setprecision(2) << setw(5) << this->getAindex(row, col) << " ";
			}
			cout << endl;
		}
		cout << endl << endl;
	}
	void printB()
	{
		cout << "B: " << endl;
		for(int row = 0; row < this->getN(); row++)
		{
			for(int col = 0; col < this->getM(); col++)
			{
				cout << right << setprecision(2) << setw(5) << this->getBindex(row, col) << " ";
			}
			cout << endl;
		}
		cout << endl << endl;
	}
	void printPI()
	{
		cout << "Pi: ";
		for(int ii = 0; ii < this->getN(); ii ++)
		{
			cout << this->getPIindex(ii) << " ";
		}
		cout << endl << endl;
	}
	void printAll()
	{
		this->printN();
		this->printM();
		this->printT();
		this->printStates();
		this->printStrings();
		this->printA();
		this->printB();
		this->printPI();
	}
	void printObservations()
	{
		int size = observations.size();
		for(int ii=0; ii<size; ii++)
		{
			observations[ii].print();
		}
	}

#pragma endregion Print Methods

#pragma region Read Files and initialize structures
	void readHMM(string filename)
	{
		ifstream infile;
		infile.open(filename.c_str());

		if(!infile){cout << "Could not read the file" << endl; return;}

		// Read N, M and T
		char tempN[3]; char tempM[3]; char tempT[3];

		infile >> tempN >> tempM >> tempT;

		this->setN(atoi(tempN));
		this->setM(atoi(tempM));
		this->setT(atoi(tempT));

		char temp[100]; // set temp variable to read from stream
		
		// Init and Fill in States array
		this->initStates(this->getN());

		for(int ii = 0; ii < this->getN(); ii ++)
		{
			infile >> temp;
			this->setStateAt(ii,temp);
		}

		// Init and Fill Words array
		this->initStrings(this->getM());
		for(int ii = 0; ii < this->getM(); ii++)
		{
			infile >> temp;
			this->setStringAt(ii,temp);
		}


		// Init and Fill A array
		infile >> temp; // skips line with a:
		this->initA(this->getN());
		for(int row = 0; row < this->getN(); row++)
		{
			for(int col = 0; col < this->getN(); col++)
			{
				infile >> temp;
				this->setA(row,col,atof(temp));
			}
		}

		// Init and Fill B array
		infile >> temp; // skips line with b:
		this->initB(this->getN(), this->getM());
		for(int row = 0; row < this->getN(); row++)
		{
			for(int col = 0; col < this->getM(); col++)
			{
				infile >> temp;
				this->setB(row,col,atof(temp));
			}
		}

		// Init and Fill PI array
		infile >> temp; // skips line with pi:
		this->initPI(this->getN());
		for(int ii = 0; ii < this->getN(); ii++)
		{
				infile >> temp;
				this->setPI(ii, atof(temp));
		}

		infile.close();
	}

	void readObservationFile(string filename)
	{
		ifstream infile;
		infile.open(filename.c_str());

		if(!infile){cout << "Could not read the file" << endl; return;}

		char temp[100];

		infile >> temp;

		int numObservations = atoi(temp);

		int numWords;
		string word;
		Observation *tempObs;

		for(int ii = 0; ii < numObservations; ii++)
		{
			infile >> temp;
			numWords = atoi(temp);

			tempObs = new Observation(numWords);

			for(int jj = 0; jj < numWords; jj++)
			{
				infile >> temp;
				word = temp;
				tempObs->addWord(word,jj);
			}

			observations.push_back(*tempObs);
		}

		infile.close();
	}
#pragma endregion Read Files and initialize structures

#pragma region finding greek letters
	int searchString(string str)
	{
		for(int ii=0; ii<this->M; ii++)
			if(str == strings[ii]) return ii;
		return -1;
	}

	double **findAlpha(int sentence)
	{
		double **alpha = AllocateDynamicArray<double>(this->N,observations[sentence].getSize());

		for(int xx=0; xx<observations[sentence].getSize(); xx++)for(int yy=0; yy<this->N; yy++)alpha[yy][xx] = 0;  // Clear Alpha

		//Initialization
		for(int ii=0; ii<this->N; ii++)
		{
			string word = observations[sentence].getWord(0);
			int col = this->searchString(word);
			alpha[ii][0] = this->getPIindex(ii)*this->getBindex(ii,col);
		}

		//Induction
		for(int tt=0; tt<observations[sentence].getSize()-1; tt++)
		{
			for(int jj=0; jj<this->N; jj++)
			{
				double sum = 0;
				for(int ii=0; ii<this->N; ii++)
				{
					double a = this->getAindex(ii,jj);
					double al = alpha[ii][tt];
					sum += al * a;
				}
				string word = observations[sentence].getWord(tt+1);
				int col = this->searchString(word);
				double b = this->getBindex(jj,col);
				alpha[jj][tt+1] = sum * b;
			}
		}

		return alpha;
	}

	double **findBeta(int sentence)
	{
		double **beta = AllocateDynamicArray<double>(this->N,observations[sentence].getSize());

		for(int xx=0; xx<observations[sentence].getSize(); xx++)for(int yy=0; yy<this->N; yy++)beta[yy][xx] = 0;  // Clear Beta

		//Initialization
		for(int ii=0; ii<this->N; ii++)
			beta[ii][observations[sentence].getSize()-1] = 1;

		//Inductions
		for(int tt=observations[sentence].getSize()-2; tt>=0; tt--)
		{
			for(int ii=0; ii<this->N; ii++)
			{
				double sum = 0;
				for(int jj=0; jj<this->N; jj++)
				{
					string word = observations[sentence].getWord(tt+1);
					int col = this->searchString(word);
					double b = this->getBindex(jj,col);
					sum += this->getAindex(ii,jj) * b * beta[jj][tt+1];
				}
				beta[ii][tt] = sum;
			}
		}

		return beta;
	}

	double ***findEta(int sentence, double probability)
	{
		// Create eta
		double ***eta = new double**[observations[sentence].getSize()-1];

		for(int ii=0; ii<this->observations[sentence].getSize()-1; ii++)
		{
			eta[ii] = new double*[this->N];
			for(int jj=0; jj<this->N; jj++)
			{
				eta[ii][jj] = new double[this->N];
				for(int kk=0; kk<this->N; kk++)
					eta[ii][jj][kk] = 0;
			}
		}
		double **alpha = this->findAlpha(sentence);
		double **beta = this->findBeta(sentence);

		for(int tt=0; tt<this->observations[sentence].getSize()-1; tt++)
		{
			for(int ii=0; ii<this->N; ii++)
			{
				for(int jj=0; jj<this->N; jj++)
				{
					double top = 0;
					string word = observations[sentence].getWord(tt+1);
					int col = this->searchString(word);
					double b = this->getBindex(jj,col);

					top = alpha[ii][tt] * this->getAindex(ii,jj) * b * beta[jj][tt+1];
					eta[tt][ii][jj] = top / probability;
				}
			}
		}

		return eta;
	}

	double **findGamma(int sentence, double probability)
	{
		double **gamma = AllocateDynamicArray<double>(this->N,this->T);

		for(int xx=0; xx<this->T; xx++)for(int yy=0; yy<this->N; yy++)gamma[yy][xx] = 0;  // Clear Gamma

		double **alpha = this->findAlpha(sentence);
		double **beta = this->findBeta(sentence);

		for(int tt=0; tt<this->observations[sentence].getSize(); tt++)
			for(int ii=0; ii<this->N; ii++)
			{
				gamma[ii][tt] = alpha[ii][tt] * beta[ii][tt] / probability;
			}

		return gamma;
	}
#pragma endregion finding greek letters

	void recognize(string hmm, string obs)
	{
		this->observations.empty();
		this->readHMM(hmm);
		this->readObservationFile(obs);

		double probability;

		for(int sentence=0; sentence<observations.size(); sentence++)
		{
			double **alpha = this->findAlpha(sentence);

			probability = 0;

			for(int ii=0; ii<this->N; ii++)
			{
				probability += alpha[ii][observations[sentence].getSize()-1];
			}
			cout << probability << endl;
		}
	}

	void statepath(string hmm, string obs)
	{
		this->observations.empty();
		this->readHMM(hmm);
		this->readObservationFile(obs);

		double probability;
		for(int sentence=0; sentence<observations.size(); sentence++)
		{
			double **alpha = this->findAlpha(sentence);

			probability = 0;

			for(int ii=0; ii<this->N; ii++)
				probability += alpha[ii][observations.size()-1];
			
			cout << probability << " ";

			if(probability > 0)
				for(int jj=0; jj<observations.size(); jj++)
				{
					double max = 0;
					int index = -1;
					for(int ii=0; ii<this->N; ii++)
					{
						double num = alpha[ii][jj];
						if(num >= max)
						{
							max = num;
							index = ii;
						}
					}
					string state = this->getStateAt(index);
					cout << state << " ";
				}

			cout << endl;
		}
	}

	void optimize(string hmm, string obs, string writeFile)
	{
		this->observations.empty();
		this->readHMM(hmm);
		this->readObservationFile(obs);
		double **alpha = this->findAlpha(0);

		double **a = AllocateDynamicArray<double>(this->N,this->N);
		for(int xx=0; xx<observations[0].getSize(); xx++)for(int yy=0; yy<this->N; yy++)a[yy][xx] = 0;  // Clear a

		double **b = AllocateDynamicArray<double>(this->N,this->M);
		for(int xx=0; xx<this->M; xx++)for(int yy=0; yy<this->N; yy++)b[yy][xx] = 0;  // Clear b

		double *pi = new double[this->N];

		double probability = 0;

		// Before:
		for(int ii=0; ii<this->N; ii++)
		{
			probability += alpha[ii][observations[0].getSize()-1];
		}
		cout << setprecision(5) << probability;
		// End Before

		double **beta = this->findBeta(0);
		double ***eta = this->findEta(0,probability);
		double **gamma = this->findGamma(0,probability);

		////////////////////////////////////// pi
		for(int ii=0; ii<this->N; ii++)
			pi[ii] =  gamma[ii][0];
		////////////////////////////////////// a
		for(int ii=0; ii<this->N; ii++)
		{
			for(int jj=0; jj<this->N; jj++)
			{
				double top = 0;
				double bottom = 0;
				for(int tt=0; tt<observations[0].getSize()-1; tt++)
				{
					top += eta[tt][ii][jj];
					bottom += gamma[ii][tt];
				}
				
				if(bottom != 0)
					a[ii][jj] = top/bottom;
			}
		}
		for(int ii=0; ii<this->N; ii++)
		{
			double sum = 0;
			for(int jj=0; jj<this->N; jj++)
				sum += a[ii][jj];

			if(sum == 0)
				for(int jj=0; jj<this->N;jj++)
					a[ii][jj] = this->getAindex(ii,jj);
		}
		////////////////////////////// b
		for(int jj=0; jj<this->N; jj++)
		{
			for(int kk=0; kk<this->M; kk++)
			{
				double top = 0;
				double bottom = 0;
				string word = this->getStringAt(kk);

				for(int ff=0; ff<this->observations[0].getSize(); ff++)
				{
					string word2 = observations[0].getWord(ff);
					if(word2 == word) top++; 
				}
				for(int tt=0; tt<this->observations[0].getSize(); tt++)
					bottom += gamma[jj][tt];

				if(this->getBindex(jj,kk) != 0)
					if(bottom>0) b[jj][kk] = top/bottom;
			}
		}
		for(int ii=0; ii<this->N; ii++)
		{
			double sum = 0;
			for(int jj=0; jj<this->M; jj++)
				sum += b[ii][jj];

			if(sum == 0)
				for(int jj=0; jj<this->M;jj++)
					b[ii][jj] = this->getBindex(ii,jj);
		}
		////////////////////////////// end matrices
		////////////////////////////// write file
		ofstream outfile;
		outfile.open(writeFile.c_str());
		outfile << this->N << " " << this->M << " " << this->T << endl;
		for(int ii=0; ii<this->N; ii++) outfile << this->getStateAt(ii) << " ";
		outfile << endl;
		for(int ii=0; ii<this->M; ii++) outfile << this->getStringAt(ii) << " ";
		outfile << endl;
		outfile << "a:" << endl;
		for(int ii=0; ii<this->N; ii++)
		{
			for(int jj=0; jj<this->N; jj++)
				outfile << a[ii][jj] << " ";
			outfile << endl;
		}
		outfile << "b:" << endl;
		for(int ii=0; ii<this->N; ii++)
		{
			for(int jj=0; jj<this->M; jj++)
				outfile << b[ii][jj] << " ";
			outfile << endl;
		}
		outfile << "pi:" << endl;
		for(int ii=0; ii<this->N; ii++) outfile << pi[ii] << " ";
	}

};


