One step forward, two steps back
Time limit: 10 s, memory limit: 256 MB
Implement the forward-backward algorithm for a Hidden Markov Model (HMM).
An HMM has:
- M hidden states {0, 1, ..., M-1}
- N observation symbols {0, 1, ..., N-1}
- Initial distribution pi[i] = P(S_0 = i)
- Transition matrix A[i][j] = P(S_{t+1} = j | S_t = i)
- Emission matrix B[i][k] = P(O_t = k | S_t = i)
- Observation sequence o[0..T-1]
Implement the following three functions:
Forward Pass
def forward(pi, A, B, o):
    # pi: list of length M
    # A: list of lists, shape (M, M)
    # B: list of lists, shape (M, N)
    # o: list of length T
    # returns alpha: list of lists, shape (M, T)
    # alpha[i][t] = P(o_0, ..., o_t, S_t = i)
- Base case: alpha[i][0] = pi[i] * B[i][o[0]]
- Recursion, for t = 1, ..., T-1: alpha[j][t] = sum_i( alpha[i][t-1] * A[i][j] ) * B[j][o[t]]
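A minimal plain-Python sketch of forward, transcribing the base case and recursion above one-to-one. It assumes the inputs are already parsed into nested lists as the signature describes, and it deliberately does not rescale alpha, since the output format asks for the raw probabilities. The sample model is the one from the example input.

```python
def forward(pi, A, B, o):
    """Return alpha with shape (M, T), alpha[i][t] = P(o_0, ..., o_t, S_t = i)."""
    M, T = len(pi), len(o)
    alpha = [[0.0] * T for _ in range(M)]
    # Base case: start in state i and emit o[0] from it.
    for i in range(M):
        alpha[i][0] = pi[i] * B[i][o[0]]
    # Recursion: sum over every predecessor state i, then emit o[t] from state j.
    for t in range(1, T):
        for j in range(M):
            total = sum(alpha[i][t - 1] * A[i][j] for i in range(M))
            alpha[j][t] = total * B[j][o[t]]
    return alpha


# Sample model taken from the example input.
pi = [0.6, 0.4]
A = [[0.7, 0.3], [0.4, 0.6]]
B = [[0.5, 0.5], [0.4, 0.6]]
o = [0, 1, 0]
alpha = forward(pi, A, B, o)
for row in alpha:
    print(" ".join(f"{x:.6f}" for x in row))
```

On the sample this should print the two alpha rows of the expected output. The work is O(M^2 * T), which is small for inputs of this shape.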
Backward Pass
def backward(A, B, o):
    # returns beta: list of lists, shape (M, T)
    # beta[i][t] = P(o_{t+1}, ..., o_{T-1} | S_t = i)
- Base case: beta[i][T-1] = 1
- Recursion, for t = T-2, ..., 0: beta[i][t] = sum_j( A[i][j] * B[j][o[t+1]] * beta[j][t+1] )
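A matching sketch of backward, again a direct transcription of the base case and recursion above, filling beta from the last column toward the first. The sample values come from the example input; M is taken from the number of rows of A, which is an assumption since the signature does not receive M explicitly.

```python
def backward(A, B, o):
    """Return beta with shape (M, T), beta[i][t] = P(o_{t+1}, ..., o_{T-1} | S_t = i)."""
    M, T = len(A), len(o)
    beta = [[0.0] * T for _ in range(M)]
    # Base case: nothing left to observe after the last step.
    for i in range(M):
        beta[i][T - 1] = 1.0
    # Recursion runs backward in time: move i -> j, emit o[t+1] from j, continue from j.
    for t in range(T - 2, -1, -1):
        for i in range(M):
            beta[i][t] = sum(A[i][j] * B[j][o[t + 1]] * beta[j][t + 1] for j in range(M))
    return beta


# Sample model taken from the example input (pi is not needed here).
A = [[0.7, 0.3], [0.4, 0.6]]
B = [[0.5, 0.5], [0.4, 0.6]]
o = [0, 1, 0]
beta = backward(A, B, o)
for row in beta:
    print(" ".join(f"{x:.6f}" for x in row))
```

For T = 1 the loop body never runs and beta is a single column of ones, which is consistent with the base case.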
Gamma and Xi
def compute_gamma_xi(alpha, beta, A, B, o):
    # returns (gamma, xi)
    # gamma: list of lists, shape (M, T)
    # xi: list of list of lists, shape (T-1, M, M)
- gamma[i][t] = alpha[i][t] * beta[i][t] / sum_k( alpha[k][t] * beta[k][t] )
- xi[t][i][j] = alpha[i][t] * A[i][j] * B[j][o[t+1]] * beta[j][t+1], normalized so that sum_{i,j} xi[t][i][j] = 1 for each t = 0, ..., T-2
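A sketch of compute_gamma_xi that follows the two formulas above literally: gamma normalizes each column of alpha * beta, and each xi[t] block is normalized by the sum of its own M * M unnormalized entries. To keep the block self-contained it takes alpha and beta copied from the sample expected output instead of recomputing them.

```python
def compute_gamma_xi(alpha, beta, A, B, o):
    """Return (gamma, xi): gamma has shape (M, T), xi has shape (T-1, M, M)."""
    M, T = len(alpha), len(o)
    gamma = [[0.0] * T for _ in range(M)]
    for t in range(T):
        # Normalizer sum_k alpha[k][t] * beta[k][t]; in exact arithmetic it equals P(o) for every t.
        norm = sum(alpha[k][t] * beta[k][t] for k in range(M))
        for i in range(M):
            gamma[i][t] = alpha[i][t] * beta[i][t] / norm
    xi = []  # stays empty when T == 1
    for t in range(T - 1):
        # Unnormalized weight of being in state i at t and state j at t+1.
        block = [[alpha[i][t] * A[i][j] * B[j][o[t + 1]] * beta[j][t + 1]
                  for j in range(M)] for i in range(M)]
        total = sum(sum(row) for row in block)
        xi.append([[w / total for w in row] for row in block])
    return gamma, xi


# alpha and beta copied from the sample expected output; A, B, o from the sample input.
alpha = [[0.300000, 0.137000, 0.070270], [0.160000, 0.111600, 0.043224]]
beta = [[0.243700, 0.470000, 1.000000], [0.252400, 0.440000, 1.000000]]
A = [[0.7, 0.3], [0.4, 0.6]]
B = [[0.5, 0.5], [0.4, 0.6]]
o = [0, 1, 0]
gamma, xi = compute_gamma_xi(alpha, beta, A, B, o)
for row in gamma:
    print(" ".join(f"{x:.6f}" for x in row))
```

A useful self-check: for t < T-1, summing xi[t][i][j] over j gives back gamma[i][t], because marginalizing the next state out of the pairwise posterior leaves the single-state posterior.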
Input Format
M N T
pi_0 pi_1 ... pi_{M-1}
A[0][0] ... A[0][M-1]
...
A[M-1][0] ... A[M-1][M-1]
B[0][0] ... B[0][N-1]
...
B[M-1][0] ... B[M-1][N-1]
o_0 o_1 ... o_{T-1}
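One way to read this format is to split the whole input on whitespace and consume tokens in order, which does not depend on how the values are broken across lines. The function name parse_input is hypothetical; in a submission it would typically be called as parse_input(sys.stdin.read()).

```python
def parse_input(text):
    """Hypothetical parser: returns (M, N, T, pi, A, B, o) from the input format above."""
    it = iter(text.split())
    M, N, T = int(next(it)), int(next(it)), int(next(it))
    pi = [float(next(it)) for _ in range(M)]
    # A is M x M and B is M x N, both given row by row.
    A = [[float(next(it)) for _ in range(M)] for _ in range(M)]
    B = [[float(next(it)) for _ in range(N)] for _ in range(M)]
    o = [int(next(it)) for _ in range(T)]
    return M, N, T, pi, A, B, o


sample = """2 2 3
0.6 0.4
0.7 0.3
0.4 0.6
0.5 0.5
0.4 0.6
0 1 0"""
M, N, T, pi, A, B, o = parse_input(sample)
print(M, N, T, o)
```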
Output Format
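alpha: M rows of T floats
<blank line>
beta: M rows of T floats
<blank line>
gamma: M rows of T floats
<blank line>
xi: T-1 blocks of M rows of M floats, separated by blank lines; row i of block t holds xi[t][i][0], ..., xi[t][i][M-1]
Every value is printed with 6 decimal places, with values on a row separated by single spaces.
If T = 1, xi is empty: print nothing after gamma.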
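A sketch of the output step, assuming one blank line between consecutive blocks and no trailing blank line; the helper names format_rows and render_output are hypothetical. Joining all blocks with a blank line handles the T = 1 case for free, since an empty xi adds no blocks after gamma.

```python
def format_rows(rows):
    """Format a list of rows as lines of space-separated values with 6 decimals."""
    return "\n".join(" ".join(f"{x:.6f}" for x in row) for row in rows)


def render_output(alpha, beta, gamma, xi):
    """Hypothetical printer: alpha, beta, gamma are one block each; xi adds T-1 blocks."""
    blocks = [format_rows(alpha), format_rows(beta), format_rows(gamma)]
    blocks += [format_rows(block) for block in xi]
    # Consecutive blocks are separated by exactly one blank line.
    return "\n\n".join(blocks)


# T = 1 example: one observation, so xi is empty and nothing follows gamma.
alpha = [[0.3], [0.16]]
beta = [[1.0], [1.0]]
gamma = [[0.3 / 0.46], [0.16 / 0.46]]
out = render_output(alpha, beta, gamma, [])
print(out)
```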
Notes
- All rows of A sum to 1. All rows of B sum to 1. pi sums to 1.
- Observations are 0-indexed integers in [0, N-1].
- No external libraries are required.
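The guarantees in these notes can be checked with a few lines, which is handy when hand-building extra test cases. validate_model is a hypothetical helper, and the tolerance tol is an assumed value chosen to absorb floating-point rounding in row sums such as 0.7 + 0.3.

```python
def validate_model(pi, A, B, o, tol=1e-9):
    """Hypothetical check of the Notes: stochastic pi, A, B rows and in-range observations."""
    def is_distribution(row):
        return all(p >= 0 for p in row) and abs(sum(row) - 1.0) <= tol

    N = len(B[0])
    return (is_distribution(pi)
            and all(is_distribution(row) for row in A)
            and all(is_distribution(row) for row in B)
            and all(0 <= k < N for k in o))


pi = [0.6, 0.4]
A = [[0.7, 0.3], [0.4, 0.6]]
B = [[0.5, 0.5], [0.4, 0.6]]
print(validate_model(pi, A, B, [0, 1, 0]))
```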
Input
2 2 3
0.6 0.4
0.7 0.3
0.4 0.6
0.5 0.5
0.4 0.6
0 1 0
Expected Output
0.300000 0.137000 0.070270
0.160000 0.111600 0.043224

0.243700 0.470000 1.000000
0.252400 0.440000 1.000000

0.644175 0.567343 0.619152
0.355825 0.432657 0.380848

0.434825 0.209350
0.132518 0.223307

0.422489 0.144853
0.196662 0.235995
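A quick hand check of the example: the gamma normalizer sum_k alpha[k][t] * beta[k][t] should come out the same for every t, because in exact arithmetic it equals the sequence likelihood P(o_0, ..., o_{T-1}), which is also the sum of the last alpha column. The sketch below uses the alpha and beta rows from the expected output; likelihood_per_step is a hypothetical helper name.

```python
# alpha and beta copied from the expected output (rows = states, columns = t).
alpha = [[0.300000, 0.137000, 0.070270],
         [0.160000, 0.111600, 0.043224]]
beta = [[0.243700, 0.470000, 1.000000],
        [0.252400, 0.440000, 1.000000]]


def likelihood_per_step(alpha, beta):
    """For each t, return sum_i alpha[i][t] * beta[i][t]."""
    M, T = len(alpha), len(alpha[0])
    return [sum(alpha[i][t] * beta[i][t] for i in range(M)) for t in range(T)]


print([f"{p:.6f}" for p in likelihood_per_step(alpha, beta)])
# → ['0.113494', '0.113494', '0.113494']
```

For instance, gamma[0][0] = 0.3 * 0.2437 / 0.113494, which is about 0.644175, the first gamma value above.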