One step forward, two steps back

Time limit: 10 s, memory limit: 256 MB

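Implement the forward-backward algorithm for a Hidden Markov Model (HMM). Given the model and an observation sequence, compute the forward table alpha, the backward table beta, and the posteriors gamma and xi defined below.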

An HMM has:

  • M hidden states {0, 1, ..., M-1}
  • N observation symbols {0, 1, ..., N-1}
  • Initial distribution pi[i] = P(S_0 = i)
  • Transition matrix A[i][j] = P(S_{t+1} = j | S_t = i)
  • Emission matrix B[i][k] = P(O_t = k | S_t = i)
  • Observation sequence o[0..T-1]

Forward Pass

```python
def forward(pi, A, B, o):
    # pi: list of length M
    # A:  list of lists, shape (M, M)
    # B:  list of lists, shape (M, N)
    # o:  list of length T
    # returns alpha: list of lists, shape (M, T)
    # alpha[i][t] = P(o_0, ..., o_t, S_t = i)
    ...
```

  • Base case: alpha[i][0] = pi[i] * B[i][o[0]] for every state i
  • Recursion, for t = 1, ..., T-1: alpha[j][t] = sum_i( alpha[i][t-1] * A[i][j] ) * B[j][o[t]]
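
The two rules above translate almost line for line into plain Python. This is a minimal sketch that assumes the nested-list inputs described in the stub. It works with raw probabilities, without per-step scaling or log-space arithmetic. That is fine for short sequences, but the values can underflow to 0.0 when T is large.

```python
def forward(pi, A, B, o):
    """Return alpha, shape (M, T), with alpha[i][t] = P(o_0, ..., o_t, S_t = i)."""
    M, T = len(pi), len(o)
    alpha = [[0.0] * T for _ in range(M)]
    # Base case: start in state i and emit the first observation from it.
    for i in range(M):
        alpha[i][0] = pi[i] * B[i][o[0]]
    # Recursion: sum over all predecessor states i, then emit o[t] from state j.
    for t in range(1, T):
        for j in range(M):
            total = sum(alpha[i][t - 1] * A[i][j] for i in range(M))
            alpha[j][t] = total * B[j][o[t]]
    return alpha
```

Each time step costs O(M^2), so the whole pass is O(M^2 T). On the sample input this reproduces the alpha section of the expected output.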

Backward Pass

```python
def backward(A, B, o):
    # returns beta: list of lists, shape (M, T)
    # beta[i][t] = P(o_{t+1}, ..., o_{T-1} | S_t = i)
    ...
```

  • Base case: beta[i][T-1] = 1 for every state i
  • Recursion, for t = T-2 down to 0: beta[i][t] = sum_j( A[i][j] * B[j][o[t+1]] * beta[j][t+1] )
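
The backward recursion fills the table from right to left. Below is a minimal sketch in the same style as the forward pass. It also uses unscaled probabilities.

```python
def backward(A, B, o):
    """Return beta, shape (M, T), with beta[i][t] = P(o_{t+1}, ..., o_{T-1} | S_t = i)."""
    M, T = len(A), len(o)
    beta = [[0.0] * T for _ in range(M)]
    # Base case: nothing remains to be observed after the last step.
    for i in range(M):
        beta[i][T - 1] = 1.0
    # Recursion, right to left: move i -> j, emit o[t+1] from j, continue from j.
    for t in range(T - 2, -1, -1):
        for i in range(M):
            beta[i][t] = sum(
                A[i][j] * B[j][o[t + 1]] * beta[j][t + 1] for j in range(M)
            )
    return beta
```

Every beta[i][T-1] is 1, which is why the last column of the sample's beta section is 1.000000.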

Gamma and Xi

```python
def compute_gamma_xi(alpha, beta, A, B, o):
    # returns (gamma, xi)
    # gamma: list of lists, shape (M, T)
    # xi:    list of list of lists, shape (T-1, M, M)
    ...
```

  • gamma[i][t] = P(S_t = i | o_0, ..., o_{T-1}) = alpha[i][t] * beta[i][t] / sum_k( alpha[k][t] * beta[k][t] )
  • For t = 0, ..., T-2: xi[t][i][j] = P(S_t = i, S_{t+1} = j | o_0, ..., o_{T-1}), computed as alpha[i][t] * A[i][j] * B[j][o[t+1]] * beta[j][t+1] and normalized so that sum_{i,j} xi[t][i][j] = 1
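
Both posteriors are products of entries from the forward and backward tables, followed by a normalization. The sketch below assumes that alpha and beta come from the two passes. It also assumes the observation sequence has nonzero probability; otherwise the normalizers are 0 and the division fails.

```python
def compute_gamma_xi(alpha, beta, A, B, o):
    """Return (gamma, xi): gamma has shape (M, T), xi has shape (T-1, M, M)."""
    M, T = len(alpha), len(o)
    gamma = [[0.0] * T for _ in range(M)]
    for t in range(T):
        # Denominator: sum_k alpha[k][t] * beta[k][t] (equals P(o_0, ..., o_{T-1})).
        norm = sum(alpha[k][t] * beta[k][t] for k in range(M))
        for i in range(M):
            gamma[i][t] = alpha[i][t] * beta[i][t] / norm
    xi = []
    for t in range(T - 1):
        # Unnormalized weight of being in state i at t and state j at t+1.
        block = [
            [alpha[i][t] * A[i][j] * B[j][o[t + 1]] * beta[j][t + 1] for j in range(M)]
            for i in range(M)
        ]
        total = sum(sum(row) for row in block)
        xi.append([[v / total for v in row] for row in block])
    return gamma, xi
```

For every t, both normalizers equal P(o_0, ..., o_{T-1}). A solution could therefore compute that value once, for example as sum_i alpha[i][T-1]. Normalizing per step, as the statement describes, gives the same result up to rounding.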

Input Format

```text
M N T
pi_0 pi_1 ... pi_{M-1}
A[0][0] ... A[0][M-1]
...
A[M-1][0] ... A[M-1][M-1]
B[0][0] ... B[0][N-1]
...
B[M-1][0] ... B[M-1][N-1]
o_0 o_1 ... o_{T-1}
```
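
Each line is a run of whitespace-separated numbers, so a reader can ignore the line structure and consume tokens in order. The sketch below uses a hypothetical helper, parse_input, which is not part of the required interface.

```python
def parse_input(text):
    """Parse the input format into (pi, A, B, o).

    Tokens are read in order, so line breaks inside a matrix do not matter.
    """
    tokens = text.split()
    pos = 0

    def take(count, cast):
        # Consume the next `count` tokens, converting each with `cast`.
        nonlocal pos
        values = [cast(tok) for tok in tokens[pos:pos + count]]
        pos += count
        return values

    M, N, T = take(3, int)
    pi = take(M, float)
    A = [take(M, float) for _ in range(M)]
    B = [take(N, float) for _ in range(M)]
    o = take(T, int)
    return pi, A, B, o
```

A judge submission would typically call this as parse_input(sys.stdin.read()) after import sys.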
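
Output Format

```text
alpha: M rows of T floats
<blank line>
beta: M rows of T floats
<blank line>
gamma: M rows of T floats
<blank line>
xi: T-1 blocks of M rows of M floats, separated by blank lines
```

The section names are for reference only and are not printed (see the sample output). Row i of alpha, beta, and gamma lists the values for t = 0, ..., T-1. Row i of xi block t lists xi[t][i][0], ..., xi[t][i][M-1]. Every value is printed with 6 decimal places, and values in a row are separated by single spaces.

If T = 1, xi is empty: print nothing after gamma.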
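
The printing logic can be sketched with two hypothetical helpers, format_matrix and format_output, neither of which is part of the statement. Joining all sections with one blank line handles the T = 1 case automatically, because xi is then an empty list.

```python
def format_matrix(rows):
    """One line per row; each value has 6 decimal places, separated by single spaces."""
    return "\n".join(" ".join(f"{x:.6f}" for x in row) for row in rows)


def format_output(alpha, beta, gamma, xi):
    """Join alpha, beta, gamma and each xi block with a single blank line.

    When T = 1, xi is an empty list, so nothing is printed after gamma.
    """
    sections = [alpha, beta, gamma] + list(xi)
    return "\n\n".join(format_matrix(s) for s in sections)
```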

Constraints

  • All rows of A sum to 1. All rows of B sum to 1. pi sums to 1.
  • Observations are 0-indexed integers in [0, N-1].
  • No external libraries required.
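
Given these guarantees, a submission does not need to validate its input. A check like the one below can still be handy when you generate your own test cases. check_model is a hypothetical helper. The tolerance tol is an assumption, chosen to absorb floating-point error in decimal inputs such as 0.1 + 0.2.

```python
def check_model(pi, A, B, o, tol=1e-9):
    """Return True if pi and every row of A and B are probability distributions
    (non-negative, summing to 1 within tol) and every o[t] lies in [0, N-1]."""
    N = len(B[0])

    def is_distribution(row):
        return all(p >= 0 for p in row) and abs(sum(row) - 1.0) <= tol

    return (
        is_distribution(pi)
        and all(is_distribution(row) for row in A)
        and all(is_distribution(row) for row in B)
        and all(0 <= k < N for k in o)
    )
```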

Sample Input

```text
2 2 3
0.6 0.4
0.7 0.3
0.4 0.6
0.5 0.5
0.4 0.6
0 1 0
```

Sample Output

```text
0.300000 0.137000 0.070270
0.160000 0.111600 0.043224

0.243700 0.470000 1.000000
0.252400 0.440000 1.000000

0.644175 0.567343 0.619152
0.355825 0.432657 0.380848

0.434825 0.209350
0.132518 0.223307

0.422489 0.144853
0.196662 0.235995
```
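
Putting the pieces together, here is one possible end-to-end sketch. It uses a hypothetical entry point, solve(text), which takes the whole input as a string and returns the text to print. It follows the same unscaled recurrences as the statement, so it shares their underflow limit for very long sequences.

```python
def solve(text):
    """Parse the input, run forward-backward, and return the formatted output."""
    it = iter(text.split())
    M, N, T = int(next(it)), int(next(it)), int(next(it))
    pi = [float(next(it)) for _ in range(M)]
    A = [[float(next(it)) for _ in range(M)] for _ in range(M)]
    B = [[float(next(it)) for _ in range(N)] for _ in range(M)]
    o = [int(next(it)) for _ in range(T)]
    states = range(M)

    # Forward pass: alpha[i][t] = P(o_0..o_t, S_t = i).
    alpha = [[0.0] * T for _ in states]
    for i in states:
        alpha[i][0] = pi[i] * B[i][o[0]]
    for t in range(1, T):
        for j in states:
            alpha[j][t] = sum(alpha[i][t - 1] * A[i][j] for i in states) * B[j][o[t]]

    # Backward pass: the last column stays 1.0, earlier columns are filled right to left.
    beta = [[1.0] * T for _ in states]
    for t in range(T - 2, -1, -1):
        for i in states:
            beta[i][t] = sum(A[i][j] * B[j][o[t + 1]] * beta[j][t + 1] for j in states)

    # Posteriors; every normalizer equals P(o_0..o_{T-1}) mathematically.
    gamma = [[0.0] * T for _ in states]
    for t in range(T):
        norm = sum(alpha[k][t] * beta[k][t] for k in states)
        for i in states:
            gamma[i][t] = alpha[i][t] * beta[i][t] / norm
    xi = []
    for t in range(T - 1):
        w = [[alpha[i][t] * A[i][j] * B[j][o[t + 1]] * beta[j][t + 1] for j in states]
             for i in states]
        total = sum(map(sum, w))
        xi.append([[v / total for v in row] for row in w])

    # Sections separated by one blank line, 6 decimals per value.
    sections = [alpha, beta, gamma] + xi
    return "\n\n".join(
        "\n".join(" ".join(f"{x:.6f}" for x in row) for row in s) for s in sections
    )
```

A judge submission would add import sys and print(solve(sys.stdin.read())). On the sample input above, this sketch produces exactly the expected output.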