#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, unicode_literals
from math import log
from numpy.random import exponential
from random import random
import sys

def main(input_filepath):
	"""Run the full parametric-bootstrap LRT analysis on one data file.

	Reads the observed data, computes the observed LRT and the MLEs under
	both models, simulates the null distribution of the LRT from the
	null-model MLEs, and prints a summary.
	"""
	observed = read_data(input_filepath)
	observed_results = analyze_data(observed)
	observed_lrt, observed_null_mles, observed_alt_mles = observed_results
	n_reps = 100  # number of bootstrap replicates
	lrt_null_dist = simulate_null_dist_of_lrt(observed, observed_null_mles, n_reps)
	summarize_results(observed_results, lrt_null_dist)

def analyze_data(data):
	"""Return (lrt, null_mle, alt_mle) for one data set.

	`data` is the (Wait_times, Starts, Switches) tuple from read_data;
	the LRT is 2 * (lnL_alt - lnL_null).
	"""
	summaries = calc_summaries(data)
	mle_alt = calc_MLE_alt(summaries)
	lnl_alt = calc_ln_likelihood(summaries, mle_alt)
	mle_null = calc_MLE_null(summaries)
	lnl_null = calc_ln_likelihood(summaries, mle_null)
	lrt_stat = 2.0 * (lnl_alt - lnl_null)
	return lrt_stat, mle_null, mle_alt

def calc_MLE_null(data):
	'''WRONG'''
	Wait_times, Starts, Switches, Intervals, MeanWaits = data
	nx = float(Starts["L"]+Starts["M"]+Starts["H"])
	pStart = {"L":float(Starts["L"])/nx,"M":float(Starts["M"])/nx,"H":float(Starts["H"])/nx}
	betas = {"L":1.0/MeanWaits["L"],"M":1.0/MeanWaits["M"],"H":1.0/MeanWaits["H"]}
	pMove = {"LM":float(Switches["LM"])/float(Switches["LM"]+Switches["LH"]),"MH":float(Switches["MH"])/float(Switches["MH"]+Switches["ML"]),"HL":float(Switches["HL"])/float(Switches["HL"]+Switches["HM"])}
	pMove["LH"] = 1.0-pMove["LM"]
	pMove["ML"] = 1.0-pMove["MH"]
	pMove["HM"] = 1.0-pMove["HL"]
	return pStart, betas, pMove

def calc_MLE_alt(data):
	"""Unconstrained (alternative-model) MLEs; prints them as a side effect.

	Returns (pStart, betas, pMove): starting-state probabilities,
	exponential rates (1 / mean wait per state), and conditional
	switch probabilities out of each state.
	"""
	_, starts, switches, _, mean_waits = data
	total_starts = float(starts["L"] + starts["M"] + starts["H"])
	p_start = {s: starts[s] / total_starts for s in ("L", "M", "H")}
	rates = {s: 1.0 / mean_waits[s] for s in ("L", "M", "H")}
	# Each state has two exits; estimate the first and take the complement.
	p_move = {a: switches[a] / float(switches[a] + switches[b])
	          for a, b in (("LM", "LH"), ("MH", "ML"), ("HL", "HM"))}
	p_move["LH"] = 1.0 - p_move["LM"]
	p_move["ML"] = 1.0 - p_move["MH"]
	p_move["HM"] = 1.0 - p_move["HL"]

	print("ML estimates")
	print("Start probs ", p_start)
	print("betas ", rates)
	print("Move probs ", p_move)
	return p_start, rates, p_move

def calc_ln_likelihood(data, parameters):
	"""Log-likelihood of the summarized data under (pStart, betas, pMove).

	Terms: multinomial starts, exponential wait times per state (via the
	per-state interval count and mean wait), and multinomial switches.
	Zero-count terms are skipped so log(0) is never taken for them.
	"""
	_, starts, switches, intervals, mean_waits = data
	p_start, rates, p_move = parameters
	total = 0.0
	for state in starts:  # state is "L", "M" or "H"
		n_started = starts[state]
		if n_started > 0:
			total += n_started * log(p_start[state])
		n_intervals = intervals[state]
		if n_intervals > 0:
			# Exponential log-density summed over this state's intervals.
			total += float(n_intervals) * (log(rates[state]) - mean_waits[state] * rates[state])
	for move in switches:  # move is "LM", "LH", "MH", "ML", "HL" or "HM"
		if p_move[move] > 0.0:
			total += float(switches[move]) * log(p_move[move])
	return total

def calc_summaries(data):
	"""Extend a raw (Wait_times, Starts, Switches) tuple with summary stats.

	Returns (Wait_times, Starts, Switches, Intervals, MeanWaits) where
	Intervals maps each state to its number of observed wait times and
	MeanWaits maps each state to the sample mean wait time.

	Raises ZeroDivisionError if any state has no observed wait times
	(the mean — and hence the rate 1/mean — is undefined in that case).
	"""
	# Fix: removed the leftover debug `print('data', data)`, which dumped the
	# entire data set on every call (including every bootstrap replicate).
	wait_times, starts, switches = data
	intervals = {s: len(wait_times[s]) for s in ("L", "M", "H")}
	mean_waits = {s: sum(wait_times[s]) / float(len(wait_times[s]))
	              for s in ("L", "M", "H")}
	return wait_times, starts, switches, intervals, mean_waits

def summarize_results(real_results, null_dist):
	''' Taken from http://phylo.bio.ku.edu/sites/default/files/logistic_regress_with_param_boot.py.txt'''
	real_lrt, real_null_mles, real_alt_mles = real_results
	null_dist.sort()
	num_sims = len(null_dist)
	five_percent_count = num_sims / 20.0
	p = 0.95
	print('Based on a simulated null distribution of for LRT:')
	index_cutoff = five_percent_count
	fmt = ' for P of about {:3.2f} the critical value is {:6.3f} (based on the {} out of {} values)'
	while p > 0.0:
		cutoff = null_dist[int(index_cutoff)]
		print(fmt.format(p, cutoff, index_cutoff, num_sims))
		index_cutoff += five_percent_count
		p -= 0.05
	p_value = len([i for i in null_dist if i > real_lrt])/float(num_sims)
	print('The approximation of the P-value is', p_value)

def read_data(fn):
	"""Read a tab-separated file of (bird, state, wait_time) rows.

	Rows for the same bird are assumed contiguous. The first row of each
	bird counts as a start in its state; every subsequent row counts as a
	switch from the previous state.

	Returns (Wait_times, Starts, Switches):
	  Wait_times : state -> list of float wait times
	  Starts     : state -> number of birds starting in that state
	  Switches   : "XY" -> number of observed X->Y transitions

	Fixes: the file is now closed via `with` (it was left open), and the
	open mode 'rU' — removed in Python 3.11 — is gone (default text mode
	already applies universal newlines). Blank lines are skipped.
	"""
	Wait_times = {"L": [], "M": [], "H": []}
	Starts = {"L": 0, "M": 0, "H": 0}
	Switches = {"LM": 0, "LH": 0, "ML": 0, "MH": 0, "HL": 0, "HM": 0}
	prev_bird = None
	prev_state = None
	with open(fn) as data_file:
		for row in data_file:
			if not row.strip():
				continue  # tolerate blank / trailing lines
			cols = row.rstrip('\n').split('\t')
			bird, state, time = cols[0], cols[1], float(cols[2])
			Wait_times[state].append(time)
			if bird != prev_bird:
				# First observation of this bird: a start, not a switch.
				Starts[state] += 1
			else:
				Switches[prev_state + state] += 1
			prev_bird = bird
			prev_state = state
	return Wait_times, Starts, Switches

def simulate_data_set(real_data, theta):
	"""Simulate one parametric-bootstrap data set under `theta`.

	Fix: the original contained literal ``for ...:`` placeholder loops — a
	SyntaxError that prevented the module from even being imported. This
	fills in the simulation.

	`real_data` (the (Wait_times, Starts, Switches) tuple from read_data)
	is used only for its counts: the simulated set has the same number of
	birds (total starts) and the same total number of switches, spread as
	evenly as possible across birds. Each bird draws a starting state from
	pStart, then a chain of switches from pMove; every visited state
	contributes one exponential wait time with mean 1/beta of that state.

	theta : (pStart, betas, pMove) as returned by calc_MLE_null / calc_MLE_alt.
	Returns a (Wait_times, Starts, Switches) tuple shaped like read_data's.
	"""
	obs_waits, obs_starts, obs_switches = real_data
	p_start, betas, p_move = theta
	n_birds = obs_starts["L"] + obs_starts["M"] + obs_starts["H"]
	total_switches = sum(obs_switches.values())

	Wait_times = {"L": [], "M": [], "H": []}
	Starts = {"L": 0, "M": 0, "H": 0}
	Switches = {"LM": 0, "LH": 0, "ML": 0, "MH": 0, "HL": 0, "HM": 0}

	# From each state the two possible destinations, ordered so that
	# state + first destination is a key of pMove (e.g. "LM", "MH", "HL").
	destinations = {"L": ("M", "H"), "M": ("H", "L"), "H": ("L", "M")}

	# Spread the switches evenly: the first `extra` birds get one more.
	base, extra = divmod(total_switches, n_birds) if n_birds else (0, 0)
	for b in range(n_birds):
		u = random()
		if u < p_start["H"]:
			state = "H"
		elif u < p_start["H"] + p_start["M"]:
			state = "M"
		else:
			state = "L"
		Starts[state] += 1
		# numpy's exponential() takes the scale (= mean) = 1/rate.
		Wait_times[state].append(exponential(1.0 / betas[state]))
		n_sw = base + (1 if b < extra else 0)
		for _ in range(n_sw):
			first, second = destinations[state]
			nxt = first if random() < p_move[state + first] else second
			Switches[state + nxt] += 1
			state = nxt
			Wait_times[state].append(exponential(1.0 / betas[state]))

	return Wait_times, Starts, Switches

def simulate_null_dist_of_lrt(real_data, theta, num_sims):
	"""Parametric bootstrap: return `num_sims` LRT values simulated under
	`theta` (the null-model MLEs), approximating the LRT's null distribution.

	Fix: the original used `xrange`, which does not exist in Python 3
	(a NameError; the `__future__` imports show this file targets
	Python 3 as well as 2 — `range` works on both).
	"""
	null_dist = []
	for _ in range(num_sims):
		sim_data = simulate_data_set(real_data, theta)
		sim_lrt = analyze_data(sim_data)[0]  # (lrt, null_mles, alt_mles)
		null_dist.append(sim_lrt)
	return null_dist

main(sys.argv[1])


"""

# MLE equations

#Log-likelihood at MLE

print ("LnL = ",ll)

print ("In terms of the Q matrix terms")
print ("aLM" , pMove["LM"]*betas["L"])
print ("aLH" , (1-pMove["LM"])*betas["L"])
print ("aMH" , pMove["MH"]*betas["M"])
print ("aML" , (1-pMove["MH"])*betas["M"])
print ("aHL" , pMove["HL"]*betas["H"])
print 	("aHM" , (1-pMove["HL"])*betas["H"])


"""