#!/usr/bin/env python 
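'''Estimate the movement-model parameters with Metropolis-coupled MCMC.

Usage:
    python mth-answer.py <data filename> <# iter> <one starting value per parameter>

The data file is tab-separated with one movement per row. Thinned samples
(iteration, log posterior, parameter values) are written to stdout.
'''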
from __future__ import absolute_import, division, print_function, unicode_literals
import io
import random
import sys
from copy import deepcopy
from math import log
###############################
# You WILL NEED to modify the functions in this section...

# Model parameters, in the order they are stored in theta, read from the
# command line and written as output columns.
param_names = (
    'DecayRateLStart',
    'DecayRateMStart',
    'DecayRateHStart',
    'ProbChangeSectorGivenMoveLStart',
    'ProbChangeSectorGivenMoveMStart',
    'ProbChangeSectorGivenMoveHStart',
    'ProbChangeHeightGivenChangeSector',
    'ProbLToMGivenChangeHeight',
    'ProbMToHGivenChangeHeight',
    'ProbHToLGivenChangeHeight',
)


class ENUM(object):
    '''Empty namespace class; attributes are attached to instances below.'''


# PARAMS.<name> is the index of parameter <name> within theta.
PARAMS = ENUM()
for ind, n in enumerate(param_names):
    setattr(PARAMS, n, ind)

NUM_PARAMS = len(param_names)
# The first NUM_DECAY_PARAMS entries of theta are decay rates; the remaining
# entries are probabilities.
NUM_DECAY_PARAMS = 3

# Width of the sliding-window proposal for every parameter, in theta order.
DECAY_PROPOSAL_WINDOW = .05
PROB_PROPOSAL_WINDOW = 0.1
SLIDING_WINDOW_SIZES = ([DECAY_PROPOSAL_WINDOW] * NUM_DECAY_PARAMS
                        + [PROB_PROPOSAL_WINDOW] * (NUM_PARAMS - NUM_DECAY_PARAMS))

# Hazard (rate) of the exponential prior placed on each decay rate.
DECAY_RATE_PRIOR_HAZARD = 1.0

def get_initial_parameters(command_line_args):
    '''Convert the parameter arguments from the command line into theta.

    The script entry point strips off the data file name and the number of
    iterations, then passes the remaining arguments here. There must be one
    starting value per entry of `param_names`, in that order.

    Returns a list of floats (one per parameter).
    Raises ValueError if the number of arguments differs from
    len(param_names), or if an argument is not a number.
    '''
    # An explicit check rather than `assert`, so that it survives `python -O`.
    if len(command_line_args) != len(param_names):
        raise ValueError('Expected {} parameter values, got {}'.format(
            len(param_names), len(command_line_args)))
    return [float(arg) for arg in command_line_args]

# Since the operations are so similar for each start state, we can write 
#   generic ln_likelihood code if we know the indices for the parameters that
#   depend on each start state.
# This dict maps a start height to a tuple of:
#       [0] decay rate
#       [1] prob of change of sector
#       [2] the "alt" height used to interpret the probability of the next
#               parameter that controls the height.
#       [3] prob of the move to the alt height given the fact that the bird
#               is changing height
PARAM_INDICES = {
    'L': (PARAMS.DecayRateLStart,
          PARAMS.ProbChangeSectorGivenMoveLStart,
          'M',
          PARAMS.ProbLToMGivenChangeHeight),
    'M': (PARAMS.DecayRateMStart,
          PARAMS.ProbChangeSectorGivenMoveMStart,
          'H',
          PARAMS.ProbMToHGivenChangeHeight),
    'H': (PARAMS.DecayRateHStart,
          PARAMS.ProbChangeSectorGivenMoveHStart,
          'L',
          PARAMS.ProbHToLGivenChangeHeight),
}

def calc_ln_move_prob(theta,
                      prev_sector, prev_height,
                      curr_sector, curr_height,
                      next_sector, next_height,
                      waiting_time):
    '''Return the log-probability of moving from the current to the next position.

    The likelihood of one move is the product of:
      * the exponential density of `waiting_time`, using the decay rate that
        belongs to `curr_height`;
      * the probability of the sector event (stay in or leave the sector);
      * the probability of the resulting height event.
    prev_sector and prev_height are accepted but not used by this model.

    A "move" that leaves the position unchanged contributes 0.0.
    If any factor is zero or negative (a parameter at or beyond the edge of
    its legal range), the move is impossible under `theta`. In that case
    float('-inf') is returned instead of letting math.log raise ValueError.
    '''
    if curr_height == next_height and curr_sector == next_sector:
        return 0.0  # not a real move; covers for sloppy filtering of inputs
    # Look up the parameter indices that apply to this start height.
    decay_ind, p_sect_ind, alt_height, p_alt_ind = PARAM_INDICES[curr_height]
    decay_rate = theta[decay_ind]
    p_sect_move = theta[p_sect_ind]
    p_alt_height = theta[p_alt_ind]
    if curr_sector == next_sector:
        # Staying in the sector: (given the early return) the height changed.
        p_sector_event = 1 - p_sect_move
        if next_height == alt_height:
            p_height_event = p_alt_height
        else:
            p_height_event = 1 - p_alt_height
    else:
        p_sector_event = p_sect_move
        p_ch = theta[PARAMS.ProbChangeHeightGivenChangeSector]
        if next_height == curr_height:
            p_height_event = 1 - p_ch
        elif next_height == alt_height:
            p_height_event = p_ch*p_alt_height
        else:
            p_height_event = p_ch*(1 - p_alt_height)
    if decay_rate <= 0.0 or p_sector_event <= 0.0 or p_height_event <= 0.0:
        # Zero likelihood; log would raise "math domain error" here.
        return float('-inf')
    return (log(decay_rate) - decay_rate*waiting_time
            + log(p_sector_event) + log(p_height_event))

# log(1/3), computed once at import time.
LN_ONE_THIRD = log(1.0 / 3.0)


def calc_ln_prob_start_pos(theta, first_sector, first_height, second_sector, second_height):
    '''Return the log-probability of a series' opening pair of positions.

    Under the current model this is the constant log(1/3) for every series;
    none of the arguments (theta included) affect the result.
    NOTE(review): presumably a uniform choice among the three heights
    (L, M, H) -- confirm.
    '''
    ln_prob = LN_ONE_THIRD
    return ln_prob

def calc_ln_prior(theta):
    '''Return the log prior density of `theta`, up to an additive constant.

    The prior is Uniform(0, 1) on each probability parameter. Each of the
    three decay rates gets an exponential distribution with hazard
    DECAY_RATE_PRIOR_HAZARD. The normalizing constants are dropped because
    the sampler only uses differences of log posteriors.

    Returns float('-inf') when theta lies outside the prior's support, i.e.
    a negative decay rate or a probability outside [0, 1].
    '''
    decay_indices = (PARAMS.DecayRateLStart,
                     PARAMS.DecayRateMStart,
                     PARAMS.DecayRateHStart)
    for ind, value in enumerate(theta):
        if ind in decay_indices:
            in_support = value >= 0.0
        else:
            in_support = 0.0 <= value <= 1.0
        if not in_support:
            # Zero prior density. Without this check a negative decay rate
            # would receive a *higher* prior value than any legal one.
            return float('-inf')
    ln_prior = 0
    for ind in decay_indices:
        ln_prior += -DECAY_RATE_PRIOR_HAZARD*(theta[ind])
    return ln_prior

def propose_a_new_theta(theta_to_alter):
    '''Perturb one randomly chosen entry of `theta_to_alter` in place.

    This is a sliding-window move. The chosen value is shifted by a uniform
    amount in [-w/2, w/2), where w is that parameter's entry in
    SLIDING_WINDOW_SIZES. A value that leaves its legal range is reflected
    back in: at 0 for every parameter, and also at 1 for the probability
    parameters (indices >= NUM_DECAY_PARAMS).

    Returns (theta_to_alter, ln_hastings_ratio). The list is both modified
    in place and returned. Reflection keeps the move symmetric, so the log
    Hastings ratio is always 0.0.
    '''
    which = random.randrange(0, NUM_PARAMS)
    width = SLIDING_WINDOW_SIZES[which]
    proposed = theta_to_alter[which] + (random.random() - 0.5)*width
    if proposed < 0.0:
        proposed = -proposed  # reflect off the lower bound
    is_probability = which >= NUM_DECAY_PARAMS
    if is_probability and proposed > 1.0:
        proposed = 2.0 - proposed  # reflect off the upper bound
    theta_to_alter[which] = proposed
    return theta_to_alter, 0.0

def write_header(out):
    '''Write the tab-separated column header for the sample output to `out`.'''
    joined_names = '\t'.join(param_names)
    header = 'iteration\tlnPosterior\t{}\n'.format(joined_names)
    out.write(header)

###############################
# you MIGHT need to modify the function(s) in this section if you
#   decide to represent the parameters (theta) as something other than
#   a list of numbers.

def sample_chain(out, step_index, ln_posterior, theta):
    '''Write one tab-separated sample row: step, log posterior, theta values.'''
    param_text = '\t'.join(map(str, theta))
    out.write('{i}\t{x}\t{p}\n'.format(i=step_index, x=ln_posterior, p=param_text))

###############################
# you probably will NOT need to modify the code below here
################################

# Data reading/representation
# The data will just be a list of MovementSeries objects
class Movement(object):
    '''One movement record: start position, end position and waiting time.'''
    def __init__(self, start_sector, start_height, end_sector, end_height, waiting_time):
        (self.start_sector, self.start_height,
         self.end_sector, self.end_height,
         self.waiting_time) = (start_sector, start_height,
                               end_sector, end_height,
                               waiting_time)

class MovementSeries(object):
    '''An ordered list of movements that belong to one series.'''
    def __init__(self, box_id, series_id, first_movement):
        self.box_id, self.series_id = box_id, series_id
        self.move_list = list((first_movement,))

    def add_movement(self, m):
        '''Append movement `m` to the end of this series.'''
        self.move_list.extend((m,))
        
def read_data(fn):
    '''Read the tab-separated movement file `fn` into a list of MovementSeries.

    The first line must be the expected column header (its line ending is
    ignored). Every following non-blank line is one movement:
        box_id, series_id, start_sector, start_height,
        end_sector, end_height, waiting_time
    Rows with the same series_id are appended, in file order, to a single
    MovementSeries. The series are returned in order of first appearance.

    Raises ValueError for an unexpected header, a row with the wrong number
    of columns, or a non-numeric waiting_time. Line numbers in the messages
    are 1-based. Errors opening the file (IOError/OSError) propagate.
    '''
    expected_header = 'box_id\tseries_id\tstart_sector\tstart_height\tend_sector\tend_height\twaiting_time\n'
    d = []
    by_series_id = {}
    # io.open reads text with universal newlines on both Python 2 and 3;
    # the old 'rU' mode was removed from open() in Python 3.11.
    with io.open(fn, 'r') as inp:
        for line_number, row in enumerate(inp, 1):
            if line_number == 1:
                # Compare without the newline so that an unterminated
                # header line is still accepted.
                if row.rstrip('\n') != expected_header.rstrip('\n'):
                    raise ValueError('File did not have the expected header:\n{}\n'.format(expected_header))
                continue
            if not row.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            try:
                box_id, series_id, s_s, s_h, e_s, e_h, w_t = row.strip().split('\t')
            except ValueError:
                raise ValueError('Incorrect number of columns in line {}\n'.format(line_number))
            try:
                w_t = float(w_t)
            except ValueError:
                raise ValueError('Expecting a number as a waiting_time found {}\n'.format(w_t))
            movement = Movement(s_s, s_h, e_s, e_h, w_t)
            ms = by_series_id.get(series_id)
            if ms is None:
                ms = MovementSeries(box_id, series_id, movement)
                d.append(ms)
                by_series_id[series_id] = ms
            else:
                ms.add_movement(movement)
    return d
# END of data reading/representation

def calc_ln_likelihood(theta, data):
    '''Return the total log-likelihood of `data` under the parameters `theta`.

    For each MovementSeries, the first record supplies the opening pair of
    positions (its start and end), scored by calc_ln_prob_start_pos. Each
    later record supplies only the next position (its end), scored by
    calc_ln_move_prob with that record's waiting_time. The position a later
    record moves from is taken to be the previous record's end.
    '''
    total = 0.0
    for series in data:
        opening, later_moves = series.move_list[0], series.move_list[1:]
        p_sec, p_ht = opening.start_sector, opening.start_height
        c_sec, c_ht = opening.end_sector, opening.end_height
        total += calc_ln_prob_start_pos(theta, p_sec, p_ht, c_sec, c_ht)
        for rec in later_moves:
            n_sec, n_ht = rec.end_sector, rec.end_height
            total += calc_ln_move_prob(theta,
                                       p_sec, p_ht,
                                       c_sec, c_ht,
                                       n_sec, n_ht,
                                       rec.waiting_time)
            # Slide the (previous, current) window one move forward.
            (p_sec, p_ht), (c_sec, c_ht) = (c_sec, c_ht), (n_sec, n_ht)
    return total

def _init_chains(theta, ln_posterior, num_chains):
    '''Return num_chains chain dicts at `theta`, with powers 1, 0.9, 0.81, ...'''
    chain_list = []
    posterior_power = 1.0
    for _ in range(num_chains):
        chain_list.append({'posterior_power': posterior_power,
                           'theta': deepcopy(theta),
                           'ln_posterior': ln_posterior,
                           })
        posterior_power *= .9
    return chain_list


def _metropolis_update(chain, data):
    '''Propose a new theta for `chain` and accept or reject it, in place.'''
    proposed_theta, ln_hastings_ratio = propose_a_new_theta(deepcopy(chain['theta']))
    ln_prior = calc_ln_prior(proposed_theta)
    if ln_prior == float('-inf'):
        # Zero prior density: the move is always rejected, so skip the
        # expensive (and, for out-of-range values, possibly failing)
        # likelihood calculation.
        return
    ln_posterior = calc_ln_likelihood(proposed_theta, data) + ln_prior
    # Heating: this chain targets posterior**posterior_power.
    ln_posterior_ratio = (ln_posterior - chain['ln_posterior'])*chain['posterior_power']
    if log(random.random()) < ln_posterior_ratio + ln_hastings_ratio:
        chain['theta'] = proposed_theta
        chain['ln_posterior'] = ln_posterior


def _attempt_swap(chain_list, sA, sB, attempted_swap_mat, swap_mat):
    '''Propose exchanging the states held at chain positions sA and sB.'''
    cA, cB = chain_list[sA], chain_list[sB]
    attempted_swap_mat[sA][sB] += 1
    attempted_swap_mat[sB][sA] += 1
    # The joint target is prod_k p(theta_k)**b_k. Exchanging the states held
    # at powers bA and bB changes its log by (bA - bB)*(lpB - lpA).
    ln_swap_ratio = ((cA['posterior_power'] - cB['posterior_power'])
                     * (cB['ln_posterior'] - cA['ln_posterior']))
    if log(random.random()) < ln_swap_ratio:
        swap_mat[sA][sB] += 1
        swap_mat[sB][sA] += 1
        # Hand over the powers and swap list positions, so that position k
        # keeps the k-th power (position 0 stays the cold chain).
        cA['posterior_power'], cB['posterior_power'] = cB['posterior_power'], cA['posterior_power']
        chain_list[sA], chain_list[sB] = cB, cA


def _report_swaps(attempted_swap_mat, swap_mat):
    '''Write the swap acceptance rate for every pair of chain positions to stderr.'''
    n_chains = len(attempted_swap_mat)
    for i in range(n_chains):
        for j in range(1 + i, n_chains):
            attempts, swaps = attempted_swap_mat[i][j], swap_mat[i][j]
            # A pair may never have been proposed (short run or many chains).
            pct = 100.0*swaps/attempts if attempts else 0.0
            msg = 'Accepted {} out of {} ({} %) attempted swaps between chains {} and {}\n'
            sys.stderr.write(msg.format(swaps, attempts, pct, i, j))


def mcmcmc(initial_theta, data, num_chains, out, num_it=None, sample_freq=100):
    '''Run Metropolis-coupled MCMC (MC^3) and write thinned samples to `out`.

    All chains start at `initial_theta`. The chain at position k of the
    chain list targets the posterior raised to the power 0.9**k. Only
    position 0, the cold chain with power 1.0, is sampled. In each iteration
    every chain makes one Metropolis-Hastings update. Then, if there are
    several chains, one random pair of positions proposes to swap states.

    Parameters:
      initial_theta -- list of starting parameter values (copied, not mutated)
      data          -- list of MovementSeries passed to calc_ln_likelihood
      num_chains    -- number of coupled chains (>= 1)
      out           -- writable text stream for the header and the samples
      num_it        -- number of iterations. If None, the module-level
                       `num_it` set by the script entry point is used.
      sample_freq   -- positive thinning interval. The cold chain is written
                       at every iteration i with i % sample_freq == 0.

    When num_chains > 1, swap acceptance statistics are written to stderr.
    Raises ValueError if num_chains < 1 or no iteration count is available.
    '''
    if num_chains < 1:
        raise ValueError('num_chains must be at least 1')
    if num_it is None:
        # Backward compatibility: the script entry point sets a module-level
        # num_it instead of passing it in.
        num_it = globals().get('num_it')
        if num_it is None:
            raise ValueError('mcmcmc needs num_it, the number of iterations')
    theta = deepcopy(initial_theta)
    ln_posterior = calc_ln_likelihood(theta, data) + calc_ln_prior(theta)
    write_header(out)
    chain_list = _init_chains(theta, ln_posterior, num_chains)
    # Swap statistics, indexed by chain position (= heat level).
    attempted_swap_mat = [[0]*num_chains for _ in range(num_chains)]
    swap_mat = [[0]*num_chains for _ in range(num_chains)]
    chain_indices = list(range(num_chains))
    for i in range(num_it):
        if (i % sample_freq) == 0:
            cold_chain = chain_list[0]
            sample_chain(out, i, cold_chain['ln_posterior'], cold_chain['theta'])
        for chain in chain_list:
            _metropolis_update(chain, data)
        if num_chains > 1:
            random.shuffle(chain_indices)
            _attempt_swap(chain_list, chain_indices[0], chain_indices[1],
                          attempted_swap_mat, swap_mat)
    if num_chains > 1:
        _report_swaps(attempted_swap_mat, swap_mat)

if __name__ == '__main__':
    # Command line: <data filename> <# iter> <one starting value per parameter>.
    # num_it is bound at module level here because mcmcmc is not passed the
    # iteration count and relies on this global.
    try:
        datafn = sys.argv[1]
        num_it = int(sys.argv[2])
        if num_it <= 0:
            raise ValueError('the number of iterations must be positive')
        theta = get_initial_parameters(sys.argv[3:])
    except (IndexError, ValueError, AssertionError):
        # Missing or malformed arguments. get_initial_parameters may report
        # a wrong parameter count with either AssertionError or ValueError.
        mfmt='Expecting:\npython mth-answer.py <data filename> <# iter> then the starting values for the {} parameters:\n  {}\n'
        sys.exit(mfmt.format(len(param_names), '\n  '.join(param_names)))
    try:
        data = read_data(datafn)
    except IOError:
        sys.exit('Data file "{}" does not exist.'.format(datafn))
    except ValueError as err:
        sys.exit('Could not parse data file "{}": {}'.format(datafn, str(err).rstrip()))
    if not data:
        sys.exit('Data file "{}" does not contain any movements.'.format(datafn))
    mcmcmc(theta,
           data,
           num_chains=1,
           out=sys.stdout)
