PyMC, моделирующий иерархическую регрессию с неизвестными средними значениями и ковариациями

Модель

У меня есть следующая статистическая модель:

r_i ~ N(r | mu_i, sigma)

mu_i = w . Q_i

w ~ N(w | phi, Sigma)

prior(phi, Sigma) = NormalInvWishart(0, 1, k+1, I_k)

Где sigma известно.

Q_i и r_i (награда) наблюдаются.

В данном случае r_i и mu_i — скаляры, w — 40x1, Q_i — 1x40, phi — 40x1, Sigma — 40x40.

Версия в формате LaTeX: http://mathurl.com/m2utrz4

Код Python

Я пытаюсь создать модель PyMC, которая генерирует несколько образцов, а затем аппроксимирует phi и Sigma.

import pymc as pm
import numpy as np

SAMPLE_SIZE = 100
q_samples = ... # Q created elsewhere
reward_sigma = np.identity(SAMPLE_SIZE) * 0.1
phi_true = (np.random.rand(40)+1) * -2
sigma_true = np.random.rand(40, 40) * 2. - 1.
weights_true = np.random.multivariate_normal(phi_true, sigma_true)
reward_true = np.random.multivariate_normal(np.dot(q_samples,weights_true), reward_sigma)

with pm.Model() as model:
    phi = pm.MvNormal('phi', np.zeros((ndims)), np.identity((ndims)) * 2)
    sigma = pm.InverseWishart('sigma', ndims+1, np.identity(ndims))
    weights = pm.MvNormal('weights', phi, sigma)
    rewards = pm.Normal('rewards', np.dot(weights, q_samples), reward_sigma, observed=reward_true)

with model:
    start = pm.find_MAP()
    step = pm.NUTS()
    trace = pm.sample(3000, step, start)

pm.traceplot(trace)

Однако, когда я запускаю приложение, я получаю следующую ошибку:

Traceback (most recent call last):
  File "test_pymc.py", line 46, in <module>
    phi = pm.MvNormal('phi', np.zeros((ndims)), np.identity((ndims)) * 2)
TypeError: Wrong number of dimensions: expected 0, got 1 with shape (40,).

Я как-то неправильно настроил свою модель?


person Wesley Tansey    schedule 09.11.2013    source источник


Ответы (1)


Я думаю, вам не хватает параметра формы для MvNormal. Я думаю, что MvNormal(..., shape = ndim) должен решить проблему. Вероятно, нам следует придумать способ сделать вывод об этом лучше.

person John Salvatier    schedule 12.11.2013