#!/usr/bin/env python 
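'''Template for an MCMC sampler that jumps between models of a normal distribution.

The state of a chain is the list [mu, sigma, model_index]; the models are
described in the comments below.

Usage: python continuous-mcmc.py <data filename> <# iter> <start mu> <start sigma>
'''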
from __future__ import absolute_import, division, print_function, unicode_literals
import random
from math import log
from copy import deepcopy
import sys
import random

# model 0 = standard normal (no free parameters)
# model 1 = normal with mean mu free and sigma fixed at 1
# model 2 = normal with both mu and sigma free

# Exponential prior on sigma
sigma_prior_hazard_param = 1.0
# Normal prior on mu with mean 0 and standard deviation 10
mu_prior_mean = 0.0
mu_prior_sigma = 10.0
twice_mu_prior_var = 2*mu_prior_sigma*mu_prior_sigma

# Widths of the uniform sliding-window proposals for mu and sigma
mu_window = 1.0
sigma_window = 1.0
###############################
# You WILL NEED to modify the functions in this section...
def get_initial_parameters(command_line_args):
    '''Build the starting parameter list from the trailing command-line arguments.

    The script entry point strips off the data file name and the number of
    iterations before handing the remaining arguments to this function.
    Every argument is converted to a float, and the index of the starting
    model (2) is appended as the last element of the returned list.
    '''
    starting_model = 2
    return list(map(float, command_line_args)) + [starting_model]

def calc_ln_prior(theta):
    '''Return the natural log of the prior density of `theta`.

    `theta` unpacks as (mu, sigma, model_index):
      * model 0 has no free parameters, so its ln prior is 0.0;
      * model 1 contributes the Normal prior on mu;
      * any other model index (model 2) adds the Exponential prior on sigma.

    The normalizing constant of the Normal prior on mu is left out: it is
    identical in models 1 and 2, the only models that propose_a_new_theta
    moves between, so it cancels in every acceptance ratio.
    '''
    mu, sigma, model_index = theta
    if model_index == 0:
        return 0.0
    ln_mu_prior = -((mu - mu_prior_mean)**2)/twice_mu_prior_var
    if model_index == 1:
        return ln_mu_prior
    # The Exponential density is hazard*exp(-hazard*sigma).  The log(hazard)
    # term must be kept: sigma exists only in model 2, so the term does NOT
    # cancel when jumping between models 1 and 2, and the Hastings ratio of
    # that jump uses the normalized density.  (With a hazard of 1.0 the term
    # is exactly 0.0.)
    ln_sigma_prior = log(sigma_prior_hazard_param) - sigma_prior_hazard_param*sigma
    return ln_mu_prior + ln_sigma_prior

def propose_a_new_theta(theta_to_alter):
    '''Propose a new state by modifying `theta_to_alter` in place.

    `theta_to_alter` is a mutable [mu, sigma, model_index] list.  One move is
    picked at random:
      * model 1: with probability 1/3 jump to model 2 (sigma is drawn from
        its Exponential prior), otherwise slide mu;
      * any other non-zero model index (model 2): with probability 1/3 jump
        to model 1 (sigma is reset to 1.0), with probability 1/3 slide mu,
        otherwise slide sigma.
    Sliding adds a uniform deviate from a window centred on the current
    value; a negative proposed sigma is reflected back to a positive value.

    Returns the tuple (theta_to_alter, ln_hastings_ratio); the ratio is 0.0
    for the sliding moves.

    Raises NotImplementedError for model 0, which has no moves defined.
    '''
    model_index = theta_to_alter[2]
    if model_index == 0:
        # NotImplemented is a non-callable constant; the exception class is
        # NotImplementedError.
        raise NotImplementedError('Model 0')
    ln_hastings_ratio = 0.0
    move_choice_u = random.random()
    if model_index == 1:
        if move_choice_u < 0.33333333333333333:
            # Inverse-CDF draw of sigma from the Exponential prior
            u = random.random()
            sigma_star = -log(1 - u)/sigma_prior_hazard_param
            theta_to_alter[1] = sigma_star
            theta_to_alter[2] = 2 # Jump to model 2
            # minus the ln density with which sigma_star was proposed
            ln_hastings_ratio = -log(sigma_prior_hazard_param) + sigma_prior_hazard_param*sigma_star
            return theta_to_alter, ln_hastings_ratio
        prop_mu = True
    else:
        if move_choice_u < 0.33333333333333333:
            sigma = theta_to_alter[1]
            # ln density with which the reverse jump would propose this sigma
            ln_hastings_ratio = log(sigma_prior_hazard_param) - sigma_prior_hazard_param*sigma
            theta_to_alter[1] = 1.0
            theta_to_alter[2] = 1 # Jump to model 1
            return theta_to_alter, ln_hastings_ratio
        prop_mu = move_choice_u < 0.66666666666666666667
    if prop_mu:
        # change mu
        u = random.random() - 0.5
        theta_to_alter[0] += mu_window * u
    else:
        # change sigma
        u = random.random() - 0.5
        theta_to_alter[1] += sigma_window * u
        if theta_to_alter[1] < 0.0:
            # reflect off the boundary at zero
            theta_to_alter[1] = -theta_to_alter[1]
    return theta_to_alter, ln_hastings_ratio

def write_header(out):
    '''Write the tab-separated line of column names to the stream `out`.'''
    header_template = 'iteration\tlnPosterior\tlnLikelihood\tlnPrior\t{}\tModelIndex\n'
    parameter_columns = '\t'.join(('mu', 'sigma'))
    out.write(header_template.format(parameter_columns))

###############################
# You MIGHT need to modify the arguments in this section (if you
#   decide to represent the parameters (theta) as something other than
#   a list of numbers).

def sample_chain(out, step_index, ln_posterior, ln_likelihood, ln_prior, theta):
    '''Report one sample as a tab-separated line.

    The line holds the step index, the ln posterior, the ln likelihood, the
    ln prior and then every element of `theta`.  It is written to standard
    output first and then to the stream `out`.
    '''
    template = '{i}\t{x}\t{l}\t{r}\t{p}\n'
    theta_columns = '\t'.join(map(str, theta))
    record = template.format(i=step_index,
                             x=ln_posterior,
                             l=ln_likelihood,
                             r=ln_prior,
                             p=theta_columns)
    for stream in (sys.stdout, out):
        stream.write(record)

def read_data(fn):
    '''Read the data file `fn` and return its values as a list of floats.

    The file is expected to hold one number per line.  Leading and trailing
    whitespace is ignored and blank lines are skipped.

    Raises IOError if the file cannot be opened and ValueError if a
    non-blank line cannot be converted to a float.
    '''
    # The 'rU' mode was removed in Python 3.11 (open() raises ValueError).
    # io.open reads text with universal newlines on both Python 2 and 3.
    import io
    with io.open(fn, 'r') as inp:
        stripped_lines = (line.strip() for line in inp)
        return [float(ls) for ls in stripped_lines if ls]

def calc_ln_likelihood(theta, data):
    '''Return the ln likelihood of `data` given the parameters `theta`.

    Placeholder implementation: both arguments are ignored and 0.0 is always
    returned, so a ln posterior built from it equals the ln prior.
    '''
    ln_likelihood = 0.0
    return ln_likelihood

def mcmcmc(initial_theta, data, num_chains, out, num_iterations=None):
    '''Run Metropolis-coupled MCMC and write the samples of the first chain.

    initial_theta  -- starting state; every chain begins from its own copy.
    data           -- passed through unchanged to calc_ln_likelihood.
    num_chains     -- number of chains (must be > 0).  The posterior of chain
                      k is raised to a power that starts at 1.0 and is
                      multiplied by 0.9 for each further chain.
    out            -- writable stream that receives the header line and one
                      line for every sampled iteration.
    num_iterations -- number of iterations to run.  When None, the
                      module-level name `num_it` set by the script entry
                      point is used, as in earlier versions.

    Every iteration each chain makes one Metropolis-Hastings update.  When
    there is more than one chain, a swap between two randomly chosen chains
    is then attempted, and swap acceptance counts are written to standard
    error once the run has finished.
    '''
    assert num_chains > 0
    if num_iterations is None:
        num_iterations = num_it
    theta = deepcopy(initial_theta)
    ln_likelihood = calc_ln_likelihood(theta, data)
    ln_prior = calc_ln_prior(theta)
    ln_posterior = ln_likelihood + ln_prior
    write_header(out)
    # Thinning - sample only every 1000th step
    sample_freq = 1000
    # Set up some chains
    chain_indices = list(range(num_chains))
    swap_mat = [[0]*num_chains for _ in range(num_chains)]
    attempted_swap_mat = [[0]*num_chains for _ in range(num_chains)]
    chain_list = []
    posterior_power = 1.0
    for _ in range(num_chains):
        chain_list.append({'posterior_power': posterior_power,
                           'theta': deepcopy(theta),
                           'ln_posterior': ln_posterior,
                           'ln_likelihood': ln_likelihood,
                           'ln_prior': ln_prior,
                          })
        posterior_power *= .9
    for i in range(num_iterations):
        if (i % sample_freq) == 0:
            # Position 0 always holds the chain with power 1.0 (see the swap
            # bookkeeping below), so only that chain is sampled.
            c = chain_list[0]
            sample_chain(out, i, c['ln_posterior'], c['ln_likelihood'], c['ln_prior'], c['theta'])
        # One Metropolis-Hastings update per chain
        for chain in chain_list:
            prev_ln_posterior = chain['ln_posterior']
            proposed_theta, ln_hastings_ratio = propose_a_new_theta(deepcopy(chain['theta']))
            # Calculate the prior and ln_like for the proposal
            ln_prior = calc_ln_prior(proposed_theta)
            ln_likelihood = calc_ln_likelihood(proposed_theta, data)
            ln_posterior = ln_likelihood + ln_prior
            # Take into account the chain's "heating"
            ln_posterior_ratio = (ln_posterior - prev_ln_posterior)*chain['posterior_power']
            ln_acceptance_ratio = ln_posterior_ratio + ln_hastings_ratio
            if log(random.random()) < ln_acceptance_ratio:
                chain['theta'] = proposed_theta
                chain['ln_posterior'] = ln_posterior
                chain['ln_likelihood'] = ln_likelihood
                chain['ln_prior'] = ln_prior
        if num_chains > 1:
            # Chain swapping
            random.shuffle(chain_indices)
            sA, sB = chain_indices[0], chain_indices[1]
            cA, cB = chain_list[sA], chain_list[sB]
            lpA, lpB = cA['ln_posterior'], cB['ln_posterior']
            attempted_swap_mat[sA][sB] += 1
            attempted_swap_mat[sB][sA] += 1
            # After a swap, power A is applied to state B and power B to
            # state A, so the ln of the ratio of the joint heated densities is
            #   pA*lpB + pB*lpA - (pA*lpA + pB*lpB) = (lpB - lpA)*(pA - pB)
            ln_swap_ratio = (lpB - lpA)*(cA['posterior_power'] - cB['posterior_power'])
            if log(random.random()) < ln_swap_ratio:
                # Accept the swap
                swap_mat[sA][sB] += 1
                swap_mat[sB][sA] += 1
                # Exchange both the powers and the list positions, so that
                # each position keeps its power while the states trade places.
                cA['posterior_power'], cB['posterior_power'] = cB['posterior_power'], cA['posterior_power']
                chain_list[sA], chain_list[sB] = chain_list[sB], chain_list[sA]
    if num_chains > 1:
        # write out some diagnostics
        msg_template = 'Accepted {} out of {} ({} %) attempted swaps between chains {} and {}\n'
        for i in range(num_chains):
            for j in range(1 + i, num_chains):
                attempts, swaps = attempted_swap_mat[i][j], swap_mat[i][j]
                # A pair may never have been drawn (few iterations, many
                # chains); avoid dividing by zero in that case.
                pct = 100.0*swaps/attempts if attempts else 0.0
                sys.stderr.write(msg_template.format(swaps, attempts, pct, i, j))

if __name__ == '__main__':
    # Script entry point: parse the command line, read the data, then run a
    # single chain whose samples are written to out.tsv.
    usage = 'Expecting:\npython continuous-mcmc.py <data filename> <# iter> <start mu> <start sigma>\n'
    try:
        datafn = sys.argv[1]
        # num_it must stay a module-level name: mcmcmc reads it.
        num_it = int(sys.argv[2])
        theta = get_initial_parameters(sys.argv[3:])
    except (IndexError, ValueError):
        # Missing arguments or arguments that are not numbers
        sys.exit(usage)
    # calc_ln_prior unpacks theta as (mu, sigma, model_index), so exactly two
    # start values must have been supplied.
    if num_it <= 0 or len(theta) != 3:
        sys.exit(usage)
    try:
        data = read_data(datafn)
    except IOError:
        sys.exit('Data file "{}" does not exist.'.format(datafn))
    except ValueError as err:
        sys.exit('Data file "{}" could not be parsed: {}'.format(datafn, err))
    if not data:
        sys.exit('Data file "{}" contains no data.'.format(datafn))
    sample_filename = 'out.tsv'
    # The with-block makes sure the samples are flushed and the file is closed.
    with open(sample_filename, 'w') as out:
        mcmcmc(theta,
               data,
               num_chains=1,
               out=out)
