
def foward_backward(obs, states, start_p, trans_p, emit_p):
    """Run the forward-backward algorithm for a discrete hidden Markov model.

    The (misspelled) name is kept so existing callers keep working.

    Args:
        obs: sequence of observed symbols; each must be a key of emit_p[s].
        states: iterable of hidden states (iterated several times).
        start_p: start_p[s] is the probability of starting in state s.
        trans_p: trans_p[s][s2] is the probability of moving from s to s2.
        emit_p: emit_p[s][o] is the probability that state s emits symbol o.

    Returns:
        (alpha, beta): two lists holding one dict per observation, where
          alpha[t][s] = P(obs[0..t], state at t == s)        (forward)
          beta[t][s]  = P(obs[t+1..] | state at t == s)      (backward)
        Both lists are empty when obs is empty.

    Probabilities are not rescaled, so very long sequences can underflow
    to 0.0.
    """
    n = len(obs)
    if n == 0:
        return [], []

    # Forward pass.
    alpha = [{} for _ in range(n)]
    for y in states:
        alpha[0][y] = start_p[y] * emit_p[y][obs[0]]
    for t in range(1, n):
        for y in states:
            total = 0.0
            for y0 in states:  # y0 is the state at t-1
                total += alpha[t-1][y0] * trans_p[y0][y] * emit_p[y][obs[t]]
            alpha[t][y] = total

    # Backward pass.
    beta = [{} for _ in range(n)]
    for y in states:
        beta[n-1][y] = 1.0  # nothing left to observe after the last marker
    for t in range(n-2, -1, -1):
        for y in states:  # y is the state at t
            total = 0.0
            for y1 in states:  # y1 is the state at t+1
                total += beta[t+1][y1] * trans_p[y][y1] * emit_p[y1][obs[t+1]]
            beta[t][y] = total

    return alpha, beta

def viterbi(obs, states, start_p, trans_p, emit_p, verbose=True):
    """Return the most likely hidden-state path for obs (Viterbi decoding).

    Args:
        obs: non-empty sequence of observed symbols.
        states: iterable of hidden states (iterated several times).
        start_p: start_p[s] is the probability of starting in state s.
        trans_p: trans_p[s][s2] is the probability of moving from s to s2.
        emit_p: emit_p[s][o] is the probability that state s emits symbol o.
        verbose: when true (the default, matching the original behaviour),
            print the DP table with print_dptable before returning.

    Returns:
        (prob, path): the probability of the best path jointly with the
        observations, and that path as a list with one state per observation.
        Ties in probability are broken by comparing states, so the larger
        state wins.

    Raises:
        IndexError: if obs is empty.
    """
    # V[t][y]: probability of the best path that ends in state y at step t.
    # back[t][y]: the state at t-1 on that best path (None at t == 0).
    V = [{}]
    back = [{}]
    for y in states:
        V[0][y] = start_p[y] * emit_p[y][obs[0]]
        back[0][y] = None

    for t in range(1, len(obs)):
        V.append({})
        back.append({})
        for y in states:
            prob, prev = max(
                (V[t-1][y0] * trans_p[y0][y] * emit_p[y][obs[t]], y0)
                for y0 in states)
            V[t][y] = prob
            back[t][y] = prev

    if verbose:
        print_dptable(V)

    # Pick the best final state, then follow back-pointers to recover the
    # path (equivalent to the old path copying, without O(T^2) copies).
    prob, state = max((V[-1][y], y) for y in states)
    path = [state]
    for t in range(len(obs) - 1, 0, -1):
        state = back[t][state]
        path.append(state)
    path.reverse()
    return (prob, path)

# Print the Viterbi DP table: one row per state, one column per step.
def print_dptable(V):
    """Print the dynamic-programming table V built by viterbi().

    V is a list with one dict per step, mapping state -> probability. A
    header row numbers the steps. Each following row shows one state: its
    label is truncated to 5 characters, and each cell is the "%f" rendering
    truncated to 7 characters.
    """
    lines = ["    " + " ".join("%7d" % step for step in range(len(V)))]
    for state in V[0]:
        cells = ["%.7s" % ("%f" % column[state]) for column in V]
        lines.append("%.5s: " % state + " ".join(cells))
    print("\n".join(lines) + "\n")

def example():
    """Decode the module-level example observations with viterbi()."""
    model = (states, start_probability, transition_probability,
             emission_probability)
    return viterbi(observations, *model)



def fwdbkwd():
    """Unimplemented placeholder; always returns None."""
    return None
	
# Example model: the hidden states are genotypes, and each observation is the
# symbol ('a' or 'b') seen at one marker.
states = ('AA', 'AB', 'BB')

# Six 'a' markers, then six 'b' markers, then six 'a' markers.
observations = ('a',) * 6 + ('b',) * 6 + ('a',) * 6

# Starting in AB is twice as likely as starting in either AA or BB.
start_probability = dict(AA=0.25, AB=0.5, BB=0.25)

# Each genotype stays put with probability 0.9. AA and BB never move
# directly to each other; they can only move through AB.
transition_probability = {
    'AA': dict(AA=0.9, AB=0.1, BB=0.0),
    'AB': dict(AA=0.05, AB=0.9, BB=0.05),
    'BB': dict(AA=0.0, AB=0.1, BB=0.9),
}

# AA almost always emits 'a' and BB almost always emits 'b' (0.999).
# AB emits either symbol with equal probability.
emission_probability = {
    'AA': dict(a=0.999, b=0.001),
    'AB': dict(a=0.5, b=0.5),
    'BB': dict(a=0.001, b=0.999),
}
   
print(example())

# Posterior decoding of the same example with the forward-backward algorithm.
fprbs, rprbs = foward_backward(observations,
                               states,
                               start_probability,
                               transition_probability,
                               emission_probability)
# Blank separator line. The original used a bare Python 2 `print`
# statement, which is a silent no-op under Python 3.
print("")

# postProb[j][X] is the posterior probability that the state at marker j is X
# given all observations: alpha[j][X] * beta[j][X], normalized over states.
postProb = [{} for j in range(len(observations))]

for j in range(len(fprbs)):
    # In exact arithmetic denom is the same at every j, the probability of
    # the whole observation sequence. foward_backward does not rescale, so
    # for long sequences this could underflow to 0.0.
    denom = 0.0
    for y in states:
        denom += fprbs[j][y] * rprbs[j][y]
    for y in states:
        postProb[j][y] = (fprbs[j][y] * rprbs[j][y]) / denom

    # Same output as the Python 2 statement `print j, postProb[j]`, but valid
    # on both Python 2 and Python 3.
    print("%d %s" % (j, postProb[j]))

   
