One step forward, two steps back

Time limit: 10 s. Memory limit: 256 MB.

Implement the forward-backward algorithm for a Hidden Markov Model (HMM).

An HMM has:

  • M hidden states {0, 1, ..., M-1}
  • N observation symbols {0, 1, ..., N-1}
  • Initial distribution pi[i] = P(S_0 = i)
  • Transition matrix A[i][j] = P(S_{t+1} = j | S_t = i)
  • Emission matrix B[i][k] = P(O_t = k | S_t = i)
  • Observation sequence o[0..T-1]
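As a baseline, P(o_0, ..., o_{T-1}) can be computed straight from these definitions by summing over every hidden-state path. The sketch below (the helper name brute_force_likelihood is hypothetical and not part of the required interface) does exactly that. It costs O(M^T * T), which is the blow-up the forward-backward recursions avoid. It uses the parameters from the sample input at the end of this statement.

```python
from itertools import product


def brute_force_likelihood(pi, A, B, o):
    """P(o_0, ..., o_{T-1}) by summing the joint probability of every state path.

    Hypothetical helper for checking small cases only: it visits M**T paths.
    """
    M, T = len(pi), len(o)
    total = 0.0
    for path in product(range(M), repeat=T):
        # Start in state path[0] and emit o[0].
        p = pi[path[0]] * B[path[0]][o[0]]
        # Every later step: transition from path[t-1] to path[t], then emit o[t].
        for t in range(1, T):
            p *= A[path[t - 1]][path[t]] * B[path[t]][o[t]]
        total += p
    return total


# Parameters taken from the sample input.
pi = [0.6, 0.4]
A = [[0.7, 0.3], [0.4, 0.6]]
B = [[0.5, 0.5], [0.4, 0.6]]
o = [0, 1, 0]
print(f"{brute_force_likelihood(pi, A, B, o):.6f}")  # 0.113494
```

This value should equal sum_i alpha[i][T-1] from the forward pass, which makes it a convenient cross-check on small inputs.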

Implement the following three functions:

Forward Pass

def forward(pi, A, B, o):
    # pi: list of length M
    # A:  list of lists, shape (M, M)
    # B:  list of lists, shape (M, N)
    # o:  list of length T
    # returns alpha: list of lists, shape (M, T)
    # alpha[i][t] = P(o_0, ..., o_t, S_t = i)

  • Base case: alpha[i][0] = pi[i] * B[i][o[0]]
  • Recursion, for t = 1, ..., T-1: alpha[j][t] = sum_i( alpha[i][t-1] * A[i][j] ) * B[j][o[t]]
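A direct translation of the base case and recursion might look like the following. This is a sketch in plain Python lists, not a reference solution. Each time step costs O(M^2), so the whole pass is O(M^2 * T). The usage lines reuse the sample input's parameters.

```python
def forward(pi, A, B, o):
    """alpha[i][t] = P(o_0, ..., o_t, S_t = i), returned with shape (M, T)."""
    M, T = len(pi), len(o)
    alpha = [[0.0] * T for _ in range(M)]
    # Base case: start in state i and emit o[0].
    for i in range(M):
        alpha[i][0] = pi[i] * B[i][o[0]]
    # Recursion: sum over the predecessor state i, then emit o[t] from state j.
    for t in range(1, T):
        for j in range(M):
            s = sum(alpha[i][t - 1] * A[i][j] for i in range(M))
            alpha[j][t] = s * B[j][o[t]]
    return alpha


pi = [0.6, 0.4]
A = [[0.7, 0.3], [0.4, 0.6]]
B = [[0.5, 0.5], [0.4, 0.6]]
o = [0, 1, 0]
alpha = forward(pi, A, B, o)
print(" ".join(f"{x:.6f}" for x in alpha[0]))  # 0.300000 0.137000 0.070270
```

Summing the last column, sum_i alpha[i][T-1], gives the sequence likelihood P(o_0, ..., o_{T-1}).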

Backward Pass

def backward(A, B, o):
    # returns beta: list of lists, shape (M, T)
    # beta[i][t] = P(o_{t+1}, ..., o_{T-1} | S_t = i)

  • Base case: beta[i][T-1] = 1
  • Recursion, for t = T-2, ..., 0: beta[i][t] = sum_j( A[i][j] * B[j][o[t+1]] * beta[j][t+1] )
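Here is a sketch of the backward pass under the same list-of-lists conventions. The number of states M is taken from A because the signature has no pi. When T = 1, the recursion loop is empty and beta is all ones.

```python
def backward(A, B, o):
    """beta[i][t] = P(o_{t+1}, ..., o_{T-1} | S_t = i), returned with shape (M, T)."""
    M, T = len(A), len(o)
    beta = [[0.0] * T for _ in range(M)]
    # Base case: nothing remains to be observed after the last step.
    for i in range(M):
        beta[i][T - 1] = 1.0
    # Recursion runs backward in time: t = T-2, ..., 0.
    for t in range(T - 2, -1, -1):
        for i in range(M):
            beta[i][t] = sum(
                A[i][j] * B[j][o[t + 1]] * beta[j][t + 1] for j in range(M)
            )
    return beta


A = [[0.7, 0.3], [0.4, 0.6]]
B = [[0.5, 0.5], [0.4, 0.6]]
beta = backward(A, B, [0, 1, 0])
print(" ".join(f"{x:.6f}" for x in beta[0]))  # 0.243700 0.470000 1.000000
```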

Gamma and Xi

def compute_gamma_xi(alpha, beta, A, B, o):
    # returns (gamma, xi)
    # gamma: list of lists, shape (M, T)
    # xi:    list of list of lists, shape (T-1, M, M)

  • gamma[i][t] = alpha[i][t] * beta[i][t] / sum_k( alpha[k][t] * beta[k][t] )
  • For t = 0, ..., T-2: xi[t][i][j] = alpha[i][t] * A[i][j] * B[j][o[t+1]] * beta[j][t+1], normalized so that sum_{i,j} xi[t][i][j] = 1
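One way to compute both quantities is sketched below. The normalizer sum_k alpha[k][t] * beta[k][t] equals P(o_0, ..., o_{T-1}) at every t. Each xi[t] block is normalized separately, as the spec asks. To keep the example self-contained, the usage lines hard-code the alpha and beta values from the sample's expected output instead of calling forward and backward.

```python
def compute_gamma_xi(alpha, beta, A, B, o):
    """Return (gamma, xi): posterior state and transition probabilities."""
    M, T = len(alpha), len(o)
    gamma = [[0.0] * T for _ in range(M)]
    for t in range(T):
        # Same value, P(o), for every t; recomputed per column for clarity.
        z = sum(alpha[k][t] * beta[k][t] for k in range(M))
        for i in range(M):
            gamma[i][t] = alpha[i][t] * beta[i][t] / z
    xi = []
    for t in range(T - 1):
        # Unnormalized joint weight of S_t = i and S_{t+1} = j.
        block = [
            [alpha[i][t] * A[i][j] * B[j][o[t + 1]] * beta[j][t + 1] for j in range(M)]
            for i in range(M)
        ]
        z = sum(sum(row) for row in block)
        xi.append([[v / z for v in row] for row in block])
    return gamma, xi


A = [[0.7, 0.3], [0.4, 0.6]]
B = [[0.5, 0.5], [0.4, 0.6]]
# alpha and beta for the sample input, copied from its expected output.
alpha = [[0.3, 0.137, 0.07027], [0.16, 0.1116, 0.043224]]
beta = [[0.2437, 0.47, 1.0], [0.2524, 0.44, 1.0]]
gamma, xi = compute_gamma_xi(alpha, beta, A, B, [0, 1, 0])
print(f"{gamma[0][0]:.6f} {xi[0][0][0]:.6f}")  # 0.644175 0.434825
```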

Input Format

M N T
pi_0 pi_1 ... pi_{M-1}
A[0][0] ... A[0][M-1]
...
A[M-1][0] ... A[M-1][M-1]
B[0][0] ... B[0][N-1]
...
B[M-1][0] ... B[M-1][N-1]
o_0 o_1 ... o_{T-1}
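A possible parser for this format is sketched below; the name parse_input is hypothetical. It reads whitespace-separated tokens instead of fixed lines, so it works whether the numbers arrive one row per line or all on a single line. In a submission it would typically be called as parse_input(sys.stdin.read()).

```python
def parse_input(text):
    """Parse the input format into (pi, A, B, o) from whitespace-separated tokens."""
    tokens = text.split()
    pos = 0

    def take(k, cast):
        # Consume the next k tokens and convert each with cast.
        nonlocal pos
        vals = [cast(x) for x in tokens[pos:pos + k]]
        pos += k
        return vals

    M, N, T = take(3, int)
    pi = take(M, float)
    A = [take(M, float) for _ in range(M)]
    B = [take(N, float) for _ in range(M)]
    o = take(T, int)
    return pi, A, B, o


sample = "2 2 3\n0.6 0.4\n0.7 0.3\n0.4 0.6\n0.5 0.5\n0.4 0.6\n0 1 0\n"
pi, A, B, o = parse_input(sample)
print(A, o)  # [[0.7, 0.3], [0.4, 0.6]] [0, 1, 0]
```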

Output Format

alpha: M rows of T floats (row i, column t holds alpha[i][t])
<blank line>
beta: M rows of T floats
<blank line>
gamma: M rows of T floats
<blank line>
xi: T-1 blocks (one per t = 0, ..., T-2) of M rows of M floats (row i, column j holds xi[t][i][j]), separated by blank lines

Print every float with 6 decimal places.

If T = 1, xi is empty: print nothing after gamma.
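The layout can be produced with a small formatter like the one below; format_output is a hypothetical name. The layout is inferred from the sample: values are separated by single spaces, no section labels are printed, and a blank line separates consecutive matrices (gamma and the first xi block included). Whether the judge requires a trailing newline is not stated, so that detail is left to the caller.

```python
def format_output(alpha, beta, gamma, xi):
    """Render alpha, beta, gamma and each xi[t] block, separated by blank lines."""
    def rows(matrix):
        # One line per row, 6 decimal places, single-space separated.
        return "\n".join(" ".join(f"{x:.6f}" for x in row) for row in matrix)

    blocks = [rows(alpha), rows(beta), rows(gamma)]
    # One block per t = 0..T-2; for T = 1 xi is empty, so nothing follows gamma.
    blocks.extend(rows(xi_t) for xi_t in xi)
    return "\n\n".join(blocks)


# T = 1 case built from the sample parameters with o = [0].
text = format_output([[0.3], [0.16]], [[1.0], [1.0]], [[0.3 / 0.46], [0.16 / 0.46]], [])
print(text)
```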


Notes

  • All rows of A sum to 1. All rows of B sum to 1. pi sums to 1.
  • Observations are 0-indexed integers in [0, N-1].
  • No external libraries required.
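The guarantees above can be checked defensively before running the passes. In the sketch below, check_model is a hypothetical helper. The tolerance of 1e-6 is an assumption, chosen because the probabilities arrive as decimal strings whose float sums may differ from 1 by rounding error.

```python
def check_model(pi, A, B, o, tol=1e-6):
    """Raise ValueError if the inputs break the guarantees listed in the Notes."""
    M, N = len(pi), len(B[0])
    if len(A) != M or len(B) != M or any(len(row) != M for row in A):
        raise ValueError("A must be M x M and B must have M rows")
    if abs(sum(pi) - 1.0) > tol:
        raise ValueError("pi does not sum to 1")
    for name, mat in (("A", A), ("B", B)):
        for r, row in enumerate(mat):
            if abs(sum(row) - 1.0) > tol:
                raise ValueError(f"row {r} of {name} does not sum to 1")
    for t, sym in enumerate(o):
        if not 0 <= sym < N:
            raise ValueError(f"o[{t}] = {sym} is outside [0, {N - 1}]")


# The sample input satisfies every guarantee, so this returns silently.
check_model([0.6, 0.4], [[0.7, 0.3], [0.4, 0.6]], [[0.5, 0.5], [0.4, 0.6]], [0, 1, 0])
```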
Input

2 2 3
0.6 0.4
0.7 0.3
0.4 0.6
0.5 0.5
0.4 0.6
0 1 0

Expected Output

0.300000 0.137000 0.070270
0.160000 0.111600 0.043224

0.243700 0.470000 1.000000
0.252400 0.440000 1.000000

0.644175 0.567343 0.619152
0.355825 0.432657 0.380848

0.434825 0.209350
0.132518 0.223307

0.422489 0.144853
0.196662 0.235995
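Printed results can be sanity-checked without a reference answer. Summing xi[t][i][j] over j gives gamma[i][t], and summing over i gives gamma[j][t+1]. Each gamma column should also sum to 1. The sketch below tests those identities on the sample's expected output, laid out per the Output Format (the exact blank-line layout is an assumption). The helper names read_blocks and check_consistency are hypothetical. The tolerance allows for each printed value being rounded to 6 decimals.

```python
SAMPLE_OUTPUT = """\
0.300000 0.137000 0.070270
0.160000 0.111600 0.043224

0.243700 0.470000 1.000000
0.252400 0.440000 1.000000

0.644175 0.567343 0.619152
0.355825 0.432657 0.380848

0.434825 0.209350
0.132518 0.223307

0.422489 0.144853
0.196662 0.235995
"""


def read_blocks(text):
    """Split printed output into matrices at blank lines."""
    return [
        [[float(x) for x in line.split()] for line in chunk.splitlines()]
        for chunk in text.strip().split("\n\n")
    ]


def check_consistency(text, tol=1e-5):
    """True if the printed gamma and xi blocks agree with each other."""
    _, _, gamma, *xi = read_blocks(text)
    M, T = len(gamma), len(gamma[0])
    # Each column of gamma is a distribution over states.
    for t in range(T):
        if abs(sum(gamma[i][t] for i in range(M)) - 1.0) > tol:
            return False
    for t, block in enumerate(xi):
        for i in range(M):
            # Row sums of xi[t] give gamma[:, t] ...
            if abs(sum(block[i]) - gamma[i][t]) > tol:
                return False
            # ... and column sums give gamma[:, t+1].
            if abs(sum(block[k][i] for k in range(M)) - gamma[i][t + 1]) > tol:
                return False
    return True


print(check_consistency(SAMPLE_OUTPUT))  # True
```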