#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, unicode_literals
from scipy import optimize
from math import log, exp
import sys

# Log-likelihood reported for impossible parameter values (e.g. negative
# branch lengths); negative infinity sits below every attainable lnL.
WORST_LN_L = -float('inf')
    
def _jc_pattern_probs(out_branch, in1_branch, in2_branch):
    '''Return the probabilities of the five site-pattern classes for three
    taxa on a star tree under the Jukes-Cantor (JC69) substitution model.

    Branch lengths are in expected substitutions per site. The returned
    tuple is ordered (all same, only in1 differs, only in2 differs,
    only outgroup differs, all three differ), matching the count list
    built by read_data.

    Each entry is the total probability of its class, summed over the
    concrete nucleotides. Under JC69 every member of a class is equally
    likely, so using class totals instead of per-pattern probabilities only
    shifts lnL by a constant; the MLEs and the LRT are unaffected.
    '''
    def no_change(v):
        # P(end base == start base) after a branch of length v.
        return 0.25 + 0.75 * exp(-4.0 * v / 3.0)

    def to_specific(v):
        # P(end base == one particular different base) after length v.
        return 0.25 - 0.25 * exp(-4.0 * v / 3.0)

    a_o, a_1, a_2 = no_change(out_branch), no_change(in1_branch), no_change(in2_branch)
    b_o, b_1, b_2 = to_specific(out_branch), to_specific(in1_branch), to_specific(in2_branch)
    bbb = b_o * b_1 * b_2
    # Each expression sums over the (uniform) state at the central node.
    prob_same = a_o * a_1 * a_2 + 3.0 * bbb
    prob_1_differs = 3.0 * (a_o * b_1 * a_2 + b_o * a_1 * b_2 + 2.0 * bbb)
    prob_2_differs = 3.0 * (a_o * a_1 * b_2 + b_o * b_1 * a_2 + 2.0 * bbb)
    prob_out_differ = 3.0 * (b_o * a_1 * a_2 + a_o * b_1 * b_2 + 2.0 * bbb)
    prob_all_differ = 6.0 * (a_o * b_1 * b_2 + b_o * a_1 * b_2
                             + b_o * b_1 * a_2 + bbb)
    return (prob_same, prob_1_differs, prob_2_differs,
            prob_out_differ, prob_all_differ)


def ln_likelihood(data, parameters):
    '''Log-likelihood of site-pattern counts given branch lengths (JC69).

    data -- five counts in the order produced by read_data:
            [n_same, n_1_differs, n_2_differs, n_out_differs, n_all_differ]
    parameters -- (out_branch, in1_branch[, in2_branch]). With two values
            the model forces in2_branch == in1_branch (the null model).

    Returns WORST_LN_L for negative branch lengths, or when a pattern class
    with a non-zero count has zero probability.
    Raises ValueError if parameters does not have 2 or 3 entries or data
    does not have 5 entries.
    '''
    if len(parameters) not in (2, 3):
        raise ValueError('expected 2 or 3 parameters, got {}'.format(len(parameters)))
    if len(data) != 5:
        raise ValueError('expected 5 site-pattern counts, got {}'.format(len(data)))
    out_branch = parameters[0]
    in1_branch = parameters[1]
    in2_branch = in1_branch if len(parameters) == 2 else parameters[2]
    if out_branch < 0.0 or in1_branch < 0.0 or in2_branch < 0.0:
        return WORST_LN_L
    probs = _jc_pattern_probs(out_branch, in1_branch, in2_branch)
    lnL = 0.0
    for count, prob in zip(data, probs):
        if count == 0:
            # 0 * log(p) contributes nothing, even when p == 0 (zero branches).
            continue
        if prob <= 0.0:
            return WORST_LN_L
        lnL += count * log(prob)
    return lnL

def main(input_filepath):
    '''Run the likelihood-ratio test on the FASTA alignment at
    *input_filepath* and print the statistic plus both sets of MLEs.'''
    counts = read_data(input_filepath)
    lrt, null_mles, alt_mles = analyze_data(counts, True)
    report = (('lrt = ', lrt),
              ('real_null_mles = ', null_mles),
              ('real_alt_mles  = ', alt_mles))
    for label, value in report:
        print(label, value)

def analyze_data(data, print_details=False):
    '''Fit the null and alternative models to *data* and return the LRT.

    Null model: two parameters (out_branch, in1_branch), with in2_branch
    constrained to equal in1_branch. Alternative: three free branch lengths.

    Returns (lrt, null_mle, alt_mle), where lrt = 2 * (alt lnL - null lnL).
    If print_details is true, both maximized log-likelihoods are printed.
    '''
    null_mle, null_ln_like = estimate_MLE(data, [0.05, 0.05])
    # The null model is nested in the alternative, so start the alternative
    # search at the null optimum (with in2 = in1). Nelder-Mead (fmin) keeps
    # its best vertex, so its result is never worse than its starting point.
    # That guarantees alt_ln_like >= null_ln_like, so the LRT cannot come
    # out negative merely because the two searches converged differently.
    alt_start = [null_mle[0], null_mle[1], null_mle[1]]
    alt_mle, alt_ln_like = estimate_MLE(data, alt_start)
    if print_details:
        print('null_ln_like = ', null_ln_like)
        print('alt_ln_like  = ', alt_ln_like)
    lrt = 2*(alt_ln_like - null_ln_like)
    return lrt, null_mle, alt_mle

def estimate_MLE(data, params0):
    '''Maximize the log-likelihood of *data*, starting the search at *params0*.

    Runs scipy's Nelder-Mead simplex (optimize.fmin) on the negated
    log-likelihood. Returns (mle_parameters, maximized_ln_likelihood).
    '''
    start = list(params0)
    mle = optimize.fmin(scipy_ln_likelihood, x0=start, args=(data,),
                        xtol=1e-8, disp=False)
    best_ln_l = -scipy_ln_likelihood(mle, data)
    return mle, best_ln_l

def scipy_ln_likelihood(parameters, *valist):
    '''Objective function for scipy's fmin: the NEGATIVE log-likelihood.

    fmin passes its ``args`` tuple as extra positional arguments, so
    *valist* must contain exactly one item, the data. ln_likelihood takes
    the data first and the parameters second, hence the reordering here.
    The sign flip lets fmin's minimization maximize ln L.
    '''
    assert len(valist) == 1
    return -ln_likelihood(valist[0], parameters)

def read_fasta(fn):
    '''Generator yielding (name, sequence) pairs for each record in a FASTA file.

    A header line starts with '>'; the record name is the rest of that line
    with surrounding whitespace (including any '\\r') removed. A sequence may
    span several lines, and whitespace inside sequence lines is dropped.
    Only upper-case A, C, G, T plus '-' and '?' are accepted.

    An empty file yields nothing.
    Raises ValueError on a sequence line with any other character, or on
    sequence data that appears before the first header.
    '''
    import io
    import re
    valid_seq_pattern = re.compile(r'^[-?ACGT]*$')
    # io.open gives universal-newline text mode on both Python 2 and 3; the
    # 'rU' mode is rejected by open() on Python 3.11+.
    with io.open(fn, 'r') as inp_stream:
        name = None
        seq = []
        for line in inp_stream:
            if line.startswith('>'):
                if name is not None:
                    yield name, ''.join(seq)
                # strip() instead of [1:-1]: the last line may lack '\n'.
                name = line[1:].strip()
                seq = []
            else:
                ls = line.strip()
                # split by whitespace and rejoin to remove any spaces
                spaceless = ''.join(ls.split())
                if not valid_seq_pattern.match(spaceless):
                    msg = 'Illegal sequence\n"{}"\nfound for OTU "{}"'
                    raise ValueError(msg.format(ls, name))
                if spaceless and name is None:
                    raise ValueError('Sequence data found before the first ">" header')
                seq.append(spaceless)
    if name is not None:
        yield name, ''.join(seq)

def rows_to_columns(by_row):
    '''Yield the columns of an alignment given as a sequence of rows.

    Each column is yielded as a list holding the i-th element of every row,
    in row order. Nothing is yielded for an empty input.

    Raises ValueError if the rows do not all have the same length. A ragged
    alignment has no well-defined columns, so it is reported rather than
    truncated or failing partway through.
    '''
    if not by_row:
        return
    width = len(by_row[0])
    if any(len(row) != width for row in by_row):
        lengths = [len(row) for row in by_row]
        raise ValueError('rows differ in length: {}'.format(lengths))
    for column in zip(*by_row):
        yield list(column)

class Blob(object):
    '''Empty namespace object; attributes are attached to instances ad hoc.'''

def read_data(fn):
    '''Read a three-sequence FASTA alignment and tally site-pattern counts.

    The first sequence is the outgroup, the second is ingroup 1 and the
    third is ingroup 2. Alignment columns containing '?' or '-' are skipped.
    The name of each sequence read is reported on stderr.

    Returns a list of five counts, in this order:
    [n_same, n_1_differs, n_2_differs, n_out_differs, n_all_differ]

    Raises ValueError if the file does not contain exactly three sequences.
    '''
    seq_list = []
    for name, sequence in read_fasta(fn):
        sys.stderr.write('Read sequence for "{}"\n'.format(name))
        seq_list.append(sequence)
    if len(seq_list) != 3:
        msg = 'Expected exactly 3 sequences in "{}", found {}'
        raise ValueError(msg.format(fn, len(seq_list)))
    n_same = n_1_differs = n_2_differs = n_out_differs = n_all_differ = 0
    for column in rows_to_columns(seq_list):
        if '?' in column or '-' in column:
            # Missing data or gap: the site is excluded from every count.
            continue
        s_out, s_1, s_2 = column
        if s_out == s_1 == s_2:
            n_same += 1
        elif s_out == s_1:
            n_2_differs += 1    # only ingroup 2 differs
        elif s_out == s_2:
            n_1_differs += 1    # only ingroup 1 differs
        elif s_1 == s_2:
            n_out_differs += 1  # ingroups agree; outgroup differs
        else:
            n_all_differ += 1
    return [n_same, n_1_differs, n_2_differs, n_out_differs, n_all_differ]

if __name__ == '__main__':
    # Guarded so the module can be imported (e.g. for testing) without running
    # the analysis or requiring a command-line argument.
    if len(sys.argv) < 2:
        sys.exit('Usage: {} ALIGNMENT.fasta'.format(sys.argv[0]))
    main(sys.argv[1])

