#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, unicode_literals
from scipy import optimize
from math import log, exp
import sys

# Default rho for the two-parameter (null) model: the optional second
# command-line argument, falling back to 1.0.  The CI search in this script
# also assigns this global while it profiles rho.
try:
    CURRENT_DEFAULT_RHO = float(sys.argv[2])
except IndexError:
    # No second argument supplied.
    CURRENT_DEFAULT_RHO = 1.0
except ValueError:
    # Keep the historical fallback, but do not hide a mistyped value.
    sys.stderr.write('Could not parse default rho "{}"; using 1.0\n'.format(sys.argv[2]))
    CURRENT_DEFAULT_RHO = 1.0

# ln L returned by ln_likelihood for impossible parameter values.
WORST_LN_L = float('-inf')
# Maximized ln L of the three-parameter model; set by analyze_data.
GLOBAL_MAX_LN_L = None
# NOTE(review): not read anywhere in the visible code.
FIND_LOWER = True
# ln L threshold defining the confidence interval for rho; set by analyze_data.
CI_CUTOFF_LN_L = None
# MLE of rho from the three-parameter model; set by analyze_data.
RHO_MLE = None

def p_same_change(brlen):
    '''Return (p_same, p_change) for a branch of length ``brlen``.

    p_same is the probability that a base is unchanged at the end of the
    branch.  p_change is the probability that it has become one particular
    different base, so p_same + 3*p_change == 1.  These are the Jukes-Cantor
    formulas; brlen is presumably in expected substitutions per site.
    '''
    decay = exp(-4.0 * brlen / 3.0)
    return 0.25 * (1 + 3 * decay), 0.25 * (1 - decay)


def ln_likelihood(data, parameters):
    '''Return the log-likelihood of site-pattern counts on a three-taxon star tree.

    Each site is modelled as descending from one internal node, whose base has
    prior probability 0.25, along three branches (outgroup, ingroup 1,
    ingroup 2) with per-branch probabilities from p_same_change.

    data -- five counts, in this order:
        (n_same, n_1_differs, n_2_differs, n_out_differs, n_all_differ)
    parameters -- (out_branch, in1_branch) or (out_branch, in1_branch, rho).
        The ingroup-2 branch length is rho * in1_branch.  With only two
        parameters, rho is the module-level CURRENT_DEFAULT_RHO.

    Returns WORST_LN_L when any branch length is negative, or when a pattern
    with a positive count has probability zero (log(0) would otherwise raise
    ValueError, e.g. when every branch length is exactly 0).
    '''
    out_branch = parameters[0]
    in1_branch = parameters[1]
    if len(parameters) == 2:
        # Null model: rho is fixed at the module-level default.
        rho = CURRENT_DEFAULT_RHO
    else:
        assert len(parameters) == 3
        rho = parameters[2]
    in2_branch = rho*in1_branch
    if out_branch < 0.0 or in1_branch < 0.0 or in2_branch < 0.0:
        return WORST_LN_L
    # branch specific transition probabilities
    p_same_out, p_change_out = p_same_change(out_branch)
    p_same_in1, p_change_in1 = p_same_change(in1_branch)
    p_same_in2, p_change_in2 = p_same_change(in2_branch)
    # Each pattern probability sums over the 4 possible internal-node bases.
    all_change = p_change_out*p_change_in1*p_change_in2
    # (out, in1, in2) = (x, x, x): node is x, or one of the 3 other bases.
    prob_same = 0.25*(p_same_out*p_same_in1*p_same_in2 + 3*all_change)
    # (x, y, x): node is x, node is y, or node is one of the 2 remaining bases.
    prob_1_differs = 0.25*(p_same_out*p_change_in1*p_same_in2
                           + p_change_out*p_same_in1*p_change_in2
                           + 2*all_change)
    # (x, x, y)
    prob_2_differs = 0.25*(p_same_out*p_same_in1*p_change_in2
                           + p_change_out*p_change_in1*p_same_in2
                           + 2*all_change)
    # (x, y, z): node is x, y, z, or the single remaining base.
    prob_all_differ = 0.25*(p_same_out*p_change_in1*p_change_in2
                            + p_change_out*p_same_in1*p_change_in2
                            + p_change_out*p_change_in1*p_same_in2
                            + all_change)
    # (y, x, x)
    prob_out_differs = 0.25*(p_same_out*p_change_in1*p_change_in2
                           + p_change_out*p_same_in1*p_same_in2
                           + 2*all_change)

    # All 64 patterns: 4 constant, 12 in each one-differs class, 24 all-differ.
    total_prob = 4*prob_same + 12*(prob_1_differs + prob_out_differs + prob_2_differs) + 24*prob_all_differ
    assert abs(total_prob - 1.0) < 1.0e-05

    n_same, n_1_differs, n_2_differs, n_out_differs, n_all_differ = data
    # Accumulated in the original order so the floating-point sum is unchanged.
    count_prob_pairs = ((n_same, prob_same),
                        (n_1_differs, prob_1_differs),
                        (n_2_differs, prob_2_differs),
                        (n_out_differs, prob_out_differs),
                        (n_all_differ, prob_all_differ))
    lnL = 0.0
    for count, prob in count_prob_pairs:
        if count > 0:
            if prob <= 0.0:
                # An observed pattern that is impossible under these parameters.
                return WORST_LN_L
            lnL += count*log(prob)
    return lnL

def main(input_filepath):
    '''Read the alignment at input_filepath, print its pattern counts, run the LRT, and print the statistic.'''
    pattern_counts = read_data(input_filepath)
    print(pattern_counts)
    lrt, null_mles, alt_mles = analyze_data(pattern_counts, True)
    print('lrt = ', lrt)

def analyze_data(data, print_details=False):
    '''Fit the null and alternative models, print a 95% CI for rho, and return the LRT.

    The alternative model estimates (nu_out, nu_1, rho).  The null model
    estimates (nu_out, nu_1) with rho fixed at CURRENT_DEFAULT_RHO.

    Side effects: sets the module globals GLOBAL_MAX_LN_L, CI_CUTOFF_LN_L and
    RHO_MLE.  CURRENT_DEFAULT_RHO is temporarily overwritten by the CI search
    and restored before returning.

    Returns (lrt, null_mle, alt_mle).
    '''
    global GLOBAL_MAX_LN_L, CI_CUTOFF_LN_L, RHO_MLE, CURRENT_DEFAULT_RHO
    # Every parameter of both models starts its search at 0.05.
    alt_mle, alt_ln_like = estimate_MLE(data, [0.05, 0.05, 0.05])
    null_mle, null_ln_like = estimate_MLE(data, [0.05, 0.05])
    if print_details:
        print('null_ln_like = ', null_ln_like)
        print('alt_ln_like  = ', alt_ln_like)
        print('null_mles: nu_out={}, nu_1 = {} rho = '.format(*null_mle) + str(CURRENT_DEFAULT_RHO))
        print('alt_mles:  nu_out={}, nu_1 = {} rho = {}'.format(*alt_mle))
    lrt = 2*(alt_ln_like - null_ln_like)
    GLOBAL_MAX_LN_L = alt_ln_like
    # 3.84 ~ 95th percentile of chi-square with 1 df (one profiled parameter, rho).
    CI_CUTOFF_LN_L = alt_ln_like - (3.84/2.0)
    RHO_MLE = alt_mle[2]
    # The profile search works by assigning CURRENT_DEFAULT_RHO.  Restore it
    # so that later null-model fits (e.g. repeated calls) use the user's rho,
    # not a CI bound.
    null_rho = CURRENT_DEFAULT_RHO
    try:
        lower_MLE, lower_ln_like = line_search_for_rho(data, alt_mle, 0, RHO_MLE*0.99, RHO_MLE)
        upper_MLE, upper_ln_like = line_search_for_rho(data, alt_mle, RHO_MLE, 1.01*RHO_MLE, 10000)
    finally:
        CURRENT_DEFAULT_RHO = null_rho
    print('95% confident {} <= rho <= {}'.format(lower_MLE[2], upper_MLE[2]))
    return lrt, null_mle, alt_mle

def line_search_for_rho(data, params, lower_rho, ini_rho, upper_rho):
    '''Find the rho at which the profile ln L crosses CI_CUTOFF_LN_L.

    params -- starting point; only params[0] and params[1] (the branch-length
        nuisance parameters) are used.
    lower_rho, ini_rho, upper_rho -- three-point bracket for
        scipy.optimize.brent.  brent expects the objective at ini_rho to be
        lower than at both ends.

    Returns ([nu_out, nu_1, rho], lnL), where the nuisance parameters are
    re-optimized at the rho that was found.  Leaves CURRENT_DEFAULT_RHO set to
    that rho.
    '''
    # Needed: the two-parameter fits below read rho from this module global.
    global CURRENT_DEFAULT_RHO
    assert CI_CUTOFF_LN_L is not None
    nuisance_params = [params[0], params[1]]
    param_opt = optimize.brent(diff_from_cutoff_rho,
                               args=(data, nuisance_params),
                               brack=(lower_rho, ini_rho, upper_rho))
    # brent's last objective evaluation need not be at param_opt, so set rho
    # explicitly before the final nuisance-parameter fit.
    CURRENT_DEFAULT_RHO = param_opt
    nuisance_param_opt, lnL = estimate_MLE(data, nuisance_params)
    full_param = [nuisance_param_opt[0], nuisance_param_opt[1], CURRENT_DEFAULT_RHO]
    return full_param, lnL

def diff_from_cutoff_rho(rho, data, nuisance_params):
    '''Objective for the CI search: |profile ln L at rho - CI_CUTOFF_LN_L|.

    Sets the module global CURRENT_DEFAULT_RHO to rho so that the
    two-parameter fit uses it.  Re-optimizes the branch lengths starting from
    nuisance_params, and returns the absolute distance of the maximized ln L
    from the cutoff.
    '''
    global CURRENT_DEFAULT_RHO
    assert CI_CUTOFF_LN_L is not None
    CURRENT_DEFAULT_RHO = rho
    # estimate_MLE already returns the maximized ln L; re-evaluating it inside
    # brent's loop only wasted a likelihood computation per step.
    _, lnL = estimate_MLE(data, nuisance_params)
    return abs(lnL - CI_CUTOFF_LN_L)

def estimate_MLE(data, params0):
    '''Return (parameter MLEs, maximized ln L) for the site-pattern counts in ``data``.

    Minimizes the negative ln L with scipy.optimize.fmin (Nelder-Mead),
    starting from params0.  Two starting values fit the model with rho fixed
    at CURRENT_DEFAULT_RHO; three values also estimate rho.
    '''
    start = list(params0)
    best_params = optimize.fmin(scipy_ln_likelihood,
                                x0=start,
                                args=(data,),
                                xtol=1e-8,
                                disp=False)
    best_ln_l = -scipy_ln_likelihood(best_params, data)
    return best_params, best_ln_l

def scipy_ln_likelihood(parameters, *valist):
    '''Objective for scipy minimizers: the negative log-likelihood.

    scipy passes the extra ``args`` tuple positionally after the parameter
    vector, so ``valist`` must hold exactly one element: the site-pattern
    counts.  ln_likelihood takes (data, parameters), so the arguments are
    swapped here.  The sign is flipped so that minimizing this function
    maximizes the log-likelihood.
    '''
    assert len(valist) == 1
    counts = valist[0]
    return -ln_likelihood(counts, parameters)

def read_fasta(fn):
    '''Generator yielding (name, sequence) pairs for each DNA sequence in a FASTA file.

    The name is the header line after '>', minus its line terminator.
    Whitespace inside sequence lines is removed.  Sequence characters must be
    uppercase A, C, G, T or the symbols '-' and '?'; anything else raises
    ValueError.  Sequence lines before the first header are discarded, so a
    file with no header yields nothing.
    '''
    import re
    valid_seq_pattern = re.compile(r'^[-?ACGT]*$')
    # Mode 'U' was removed in Python 3.11; plain 'r' already translates
    # newlines in Python 3.
    with open(fn, 'r') as inp_stream:
        name = None
        seq = []
        for line in inp_stream:
            if line.startswith('>'):
                if name is not None:
                    yield name, ''.join(seq)
                # Strip only the terminator; a final header may lack a newline.
                name = line[1:].rstrip('\r\n')
                seq = []
            else:
                ls = line.strip()
                # split by whitespace and rejoin to remove any spaces
                spaceless = ''.join(ls.split())
                if not valid_seq_pattern.match(spaceless):
                    msg = 'Illegal sequence\n"{}"\nfound for OTU "{}"'
                    raise ValueError(msg.format(ls, name))
                seq.append(spaceless)
    if name is not None:
        yield name, ''.join(seq)

def rows_to_columns(by_row):
    '''Generator that transposes equal-length rows, yielding each column as a list.

    Yields nothing for an empty (or None) input.  Raises ValueError if the
    rows differ in length.  Without this check, columns past the first row's
    length would be silently dropped, or an IndexError raised part-way through.
    '''
    if not by_row:
        return
    n_col = len(by_row[0])
    if any(len(row) != n_col for row in by_row):
        raise ValueError('All rows must have the same length to be transposed')
    for column in zip(*by_row):
        yield list(column)

class Blob(object):
    '''Empty placeholder class: a bare container for arbitrary attributes.'''

def read_data(fn):
    '''Count site patterns in a FASTA alignment of exactly three sequences.

    Sequences are taken in file order as (outgroup, ingroup 1, ingroup 2).
    Columns containing a gap '-' or missing-data '?' symbol are skipped.

    Returns [n_same, n_1_differs, n_2_differs, n_out_differs, n_all_differ].

    Raises ValueError if the file does not contain exactly three sequences,
    or if they are not all the same length.
    '''
    seq_list = []
    for name, sequence in read_fasta(fn):
        sys.stderr.write('Read sequence for "{}"\n'.format(name))
        seq_list.append(sequence)
    # Explicit errors rather than assert: asserts vanish under `python -O`.
    if len(seq_list) != 3:
        raise ValueError('Expected 3 sequences in "{}", found {}'.format(fn, len(seq_list)))
    if len(set(len(seq) for seq in seq_list)) != 1:
        raise ValueError('The sequences in "{}" are not all the same length'.format(fn))
    n_same = 0
    n_1_differs = 0
    n_2_differs = 0
    n_out_differs = 0
    n_all_differ = 0
    for column in rows_to_columns(seq_list):
        if ('?' in column) or ('-' in column):
            continue
        s_out, s_1, s_2 = column
        if s_out == s_1:
            if s_out == s_2:
                n_same += 1
            else:
                n_2_differs += 1
        elif s_out == s_2:
            n_1_differs += 1
        elif s_1 == s_2:
            n_out_differs += 1
        else:
            n_all_differ += 1
    return [n_same, n_1_differs, n_2_differs, n_out_differs, n_all_differ]

# Script entry point.  It is guarded so that the module's functions can be
# imported (e.g. for simulations) without reading argv or running the analysis.
if __name__ == '__main__':
    try:
        fn = sys.argv[1]
    except IndexError:
        sys.exit('Expecting 1 argument: a FASTA filepath.')
    main(fn)

