#!/usr/bin/env python 
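'''
Metropolis-coupled MCMC sampler for the mean (mu) and standard deviation
(sigma) of a normal model, using flat (improper) priors on both parameters.

Usage:
    python continuous-mcmc.py <# iter> <start mu> <start sigma> <data filename>

The data file must contain one number per line.  Thinned samples from the
unheated chain are written as tab-separated columns to standard output;
acceptance and chain-swap statistics are written to standard error.
'''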
from __future__ import absolute_import, division, print_function, unicode_literals
import random
import math
import sys


# Parse and validate the command line:
#   argv[1] = number of iterations (positive int)
#   argv[2] = starting mu, argv[3] = starting sigma (must be > 0)
#   argv[4] = path of the data file
# Only the errors that bad/missing arguments can produce are caught, so
# KeyboardInterrupt and genuine programming errors are no longer swallowed.
try:
    num_it = int(sys.argv[1])
    theta = [float(sys.argv[2]), float(sys.argv[3])]
    datafn = sys.argv[4]
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # `not x > 0` also rejects NaN.  A non-positive sigma has no log, so the
    # likelihood of the starting point could not be evaluated.
    if num_it <= 0 or not theta[1] > 0.0:
        raise ValueError('iteration count and start sigma must be positive')
except (IndexError, ValueError):
    sys.exit('Expecting:\npython continuous-mcmc.py <# iter> <start mu> <start sigma> <data filename>\n')

# Read the observations: one number per line, blank lines ignored.
try:
    # The context manager closes the file even if a line fails to parse.
    with open(datafn) as data_file:
        stripped_lines = (line.strip() for line in data_file)
        data = [float(token) for token in stripped_lines if token]
except IOError:
    sys.exit('Data file "{}" does not exist.'.format(datafn))
except ValueError:
    sys.exit('Data file "{}" was expected to have one number per line.'.format(datafn))
if not data:
    sys.exit('Data file "{}" was empty.'.format(datafn))

def calc_ln_likelihood(theta, observations=None):
    '''Return the log-likelihood of normally distributed observations.

    The value omits the additive constant -n*ln(sqrt(2*pi)), which cancels
    in every likelihood ratio:

        ln L = -n*ln(sigma) - sum((x - mu)**2) / (2*sigma**2)

    Parameters:
        theta: a pair (mu, sigma).
        observations: the data points; when omitted, the module-level
            `data` list is used.

    Returns:
        The log-likelihood as a float, or -inf when sigma is not strictly
        positive (such a sigma has zero likelihood under the model, and
        math.log would otherwise raise ValueError).
    '''
    if observations is None:
        observations = data
    mu, sigma = theta
    # `not sigma > 0` also catches NaN.
    if not sigma > 0.0:
        return float('-inf')
    sum_sq_dev = sum(((x - mu) ** 2 for x in observations), 0.0)
    return -len(observations) * math.log(sigma) - sum_sq_dev / (2 * sigma * sigma)

# Log-posterior of the starting point.  The priors are flat (improper), so
# the log-prior contributes nothing and the posterior equals the likelihood.
ln_prior = 0.0
ln_likelihood = calc_ln_likelihood(theta)
ln_posterior = ln_likelihood + ln_prior

# Full widths of the uniform proposal windows for mu and sigma.
mu_window, sigma_window = 20.0, 20.0

# Samples go to standard output as tab-separated columns.
out = sys.stdout
out.write("Iter\tlike\tmu\tsigma\n")

# Tallies of proposed and accepted moves for each parameter.
n_prop_mu, n_prop_sigma = 0, 0
n_accept_mu, n_accept_sigma = 0, 0

# Thinning: write roughly 10000 samples in total.  Runs shorter than that
# floor-divide to 0, so clamp the interval to 1 (record every iteration).
sample_freq = max(1, num_it // 10000)

# Set up the coupled chains.
num_chains = 4
# Must be a real list: it is shuffled in place to pick swap partners, and a
# Python 3 `range` object is immutable (random.shuffle raises TypeError).
chain_indices = list(range(num_chains))
# swap_mat[a][b]: accepted swaps between chain slots a and b (symmetric).
swap_mat = [[0] * num_chains for _ in range(num_chains)]
# attempted_swap_mat[a][b]: proposed swaps between chain slots a and b.
attempted_swap_mat = [[0] * num_chains for _ in range(num_chains)]
chain_list = []
# Chain 0 is unheated (power 1.0); each later chain raises the posterior to
# a power 0.99 times that of the previous chain.
posterior_power = 1.0
for i in range(num_chains):
    new_chain = {'posterior_power': posterior_power,
                 'theta': list(theta),  # private copy: chains must not share state
                 'ln_posterior': ln_posterior,
                }
    chain_list.append(new_chain)
    posterior_power *= .99

# This is the Metropolis-Hastings algorithm

for i in range(num_it):
    # Log the unheated chain (slot 0) every `sample_freq` iterations.  The
    # state written is the one held before this iteration's updates.
    if (1 + i) % sample_freq == 0:
        theta, ln_posterior = chain_list[0]['theta'], chain_list[0]['ln_posterior']
        out.write("{}\t{}\t{}\t{}\n".format(1 + i, ln_posterior, theta[0], theta[1]))

    # One Metropolis update per chain: perturb either mu or sigma.
    for chain_index, chain in enumerate(chain_list):
        prev_ln_posterior = chain['ln_posterior']
        # Build the proposal on a COPY.  Mutating chain['theta'] in place would
        # keep a rejected proposal as the chain's state while 'ln_posterior'
        # still described the previous state.
        theta = list(chain['theta'])
        if random.random() < 0.5:
            # Slide mu by a uniform offset of total width mu_window centred on 0.
            theta[0] += mu_window * (random.random() - 0.5)
            if chain_index == 0:
                proposed_mu = True
                n_prop_mu += 1
        else:
            # Slide sigma the same way, reflecting at 0 to keep it non-negative.
            theta[1] += sigma_window * (random.random() - 0.5)
            if theta[1] < 0.0:
                theta[1] = -theta[1]
            if chain_index == 0:
                proposed_mu = False
                n_prop_sigma += 1

        # The prior ratio is 1 under the flat (improper) priors, so ln_prior
        # (0.0) cancels; it is kept in the sum for when real priors are added.
        ln_likelihood = calc_ln_likelihood(theta)
        ln_posterior = ln_likelihood + ln_prior

        # Heating: the chain targets posterior**posterior_power.
        ln_posterior_ratio = (ln_posterior - prev_ln_posterior) * chain['posterior_power']
        # Both proposals are symmetric, so the Hastings ratio is 1 (log 0).
        ln_hastings_ratio = 0.0
        ln_acceptance_ratio = ln_posterior_ratio + ln_hastings_ratio

        # 1 - random() lies in (0, 1], so log() can never receive 0.
        if math.log(1.0 - random.random()) < ln_acceptance_ratio:
            if chain_index == 0:
                if proposed_mu:
                    n_accept_mu += 1
                else:
                    n_accept_sigma += 1
            chain['theta'] = theta
            chain['ln_posterior'] = ln_posterior
        # On rejection nothing is stored: the chain keeps its previous state.

    # Chain swapping: choose two distinct chain slots uniformly at random.
    # random.sample works on a range in both Python 2 and 3, unlike an
    # in-place shuffle of a range object.
    sA, sB = random.sample(range(len(chain_list)), 2)
    cA, cB = chain_list[sA], chain_list[sB]
    lpA, lpB = cA['ln_posterior'], cB['ln_posterior']
    attempted_swap_mat[sA][sB] += 1
    attempted_swap_mat[sB][sA] += 1

    # Metropolis-coupling acceptance ratio for exchanging the two states:
    #   [p(xB)**bA * p(xA)**bB] / [p(xA)**bA * p(xB)**bB]
    # whose log is (bA - bB) * (ln p(xB) - ln p(xA)).
    ln_swap_ratio = (cA['posterior_power'] - cB['posterior_power']) * (lpB - lpA)
    if math.log(1.0 - random.random()) < ln_swap_ratio:
        # Accept the swap
        swap_mat[sA][sB] += 1
        swap_mat[sB][sA] += 1
        # Exchange the powers and then the list slots: each slot keeps its
        # heating power while the states trade places (slot 0 stays unheated).
        cA['posterior_power'], cB['posterior_power'] = cB['posterior_power'], cA['posterior_power']
        chain_list[sA], chain_list[sB] = chain_list[sB], chain_list[sA]



def _safe_ratio(numerator, denominator):
    '''Return numerator / denominator as a float, or NaN if denominator is 0.'''
    if denominator == 0:
        return float('nan')
    return float(numerator) / denominator


# Acceptance rates of the unheated chain.  The values are fractions in
# [0, 1] despite the "%" in the labels.  A short run may never have proposed
# one of the parameters; report NaN then instead of dividing by zero.
sys.stderr.write('mu accept %    = {:6.5f}\n'.format(_safe_ratio(n_accept_mu, n_prop_mu)))
sys.stderr.write('sigma accept % = {:6.5f}\n'.format(_safe_ratio(n_accept_sigma, n_prop_sigma)))

# Swap statistics for every unordered pair of chain slots.
for i in range(num_chains):
    for j in range(1 + i, num_chains):
        attempts, swaps = attempted_swap_mat[i][j], swap_mat[i][j]
        # A pair that was never selected has zero attempts: NaN, not a crash.
        pct = _safe_ratio(100.0 * swaps, attempts)
        msg = 'Accepted {} out of {} ({} %) attempted swaps between chains {} and {}\n'
        msg = msg.format(swaps, attempts, pct, i, j)
        sys.stderr.write(msg)