import pystan
# Eight-schools hierarchical model, written as a Stan program string.
# Non-centered parameterization: each school effect theta[j] is built as
# mu + tau * eta[j] with eta ~ normal(0, 1), rather than being sampled directly.
# mu and tau have no sampling statement, so Stan gives them an implicit flat
# prior over their declared support (tau is constrained to be >= 0).
# The likelihood treats each y[j] as normal around theta[j] with known s.e. sigma[j].
# Assignment uses `=`; the older `<-` operator is deprecated and was removed in
# Stan 2.33.
# NOTE(review): the `real y[J]` array declarations are the pre-2.26 syntax that
# the Stan bundled with PyStan 2 expects; Stan 2.33+ accepts only
# `array[J] real y`. Update every declaration together if the backend changes.
schools_code = """
data {
int<lower=0> J; // number of schools
real y[J]; // estimated treatment effects
real<lower=0> sigma[J]; // s.e. of effect estimates
}
parameters {
real mu;
real<lower=0> tau;
real eta[J];
}
transformed parameters {
real theta[J];
for (j in 1:J)
theta[j] = mu + tau * eta[j];
}
model {
eta ~ normal(0, 1);
y ~ normal(theta, sigma);
}
"""
# Observed data for the eight-schools model: per-school estimated treatment
# effects (y) and their standard errors (sigma). The school count J is
# derived from y, so it cannot drift out of sync with the lists Stan
# sizes against it.
schools_dat = {'y': [28, 8, -3, 7, -1, 1, 18, 12],
               'sigma': [15, 10, 16, 11, 9, 11, 10, 18]}
schools_dat['J'] = len(schools_dat['y'])
# Compile the Stan program, then sample 4 chains with iter=1000 per chain.
# The one-shot pystan.stan() helper is deprecated in PyStan 2.17+. Keeping
# compilation and sampling as separate steps also lets the compiled model be
# reused for further sampling runs without recompiling.
schools_model = pystan.StanModel(model_code=schools_code)
fit = schools_model.sampling(data=schools_dat, iter=1000, chains=4)
# Output: No handlers could be found for logger "pystan"
# Extract the posterior draws as a dictionary of arrays keyed by parameter
# name (mu, tau, eta, theta, lp__ in the printed output).
# NOTE(review): with permuted=True the draws are presumably merged across
# chains and shuffled; confirm against the PyStan extract() docs.
la = fit.extract(permuted=True)
# A bare `la` only echoes in an interactive session; print so a script shows it too.
print(la)
# Output:
# OrderedDict([(u'mu', array([ 5.59201329, 3.08293794, 12.65459799, ..., 7.68515371,
# 8.59399361, 14.47784153])), (u'tau', array([ 6.88494701, 7.61015808, 14.59393773, ..., 12.40867374,
# 3.35666338, 11.76037016])), (u'eta', array([[ 0.94089472, 0.13742766, -1.04342986, ..., 0.41533048,
# 1.29960003, -0.07172436],
# [ 1.19378693, 0.32648325, -0.33514159, ..., -1.12628397,
# 1.18576523, 0.37595963],
# [-0.46136505, -0.24074896, -0.29887854, ..., -0.04012823,
# 0.44861805, 0.77169001],
# ...,
# [-0.54190956, 0.54506108, -0.44790532, ..., -0.33965885,
# 1.07906142, 0.10106834],
# [ 0.99142511, -1.33269486, -1.29537437, ..., -0.32046492,
# -0.15912437, 0.38890699],
# [-0.59123533, -0.77392162, -0.75164679, ..., -0.10952102,
# 0.09420464, 0.00843435]])), (u'theta', array([[ 12.07002359, 6.53819544, -1.59194598, ..., 8.45154164,
# 14.53969066, 5.09819485],
# [ 12.16784522, 5.5675271 , 0.53245747, ..., -5.48826113,
# 12.1067988 , 5.94405015],
# [ 5.9214652 , 9.14112265, 8.29278314, ..., 12.06896917,
# 19.20170188, 23.9165939 ],
# ...,
# [ 0.96077475, 14.44863878, 2.12724272, ..., 3.47043785,
# 21.07487483, 8.93927776],
# [ 11.92187397, 4.12058559, 4.24585789, ..., 7.51830075,
# 8.05986668, 9.89942347],
# [ 7.52469517, 5.37623681, 5.638197 , ..., 13.18983377,
# 15.58572302, 14.57703258]])), (u'lp__', array([-2.00196315, -1.47962994, -0.40126404, ..., -0.93404769,
# -4.01935861, -1.7305003 ]))])
# Posterior draws of the population mean mu, taken from the extracted dictionary.
mu = la['mu']
# A bare `mu` only echoes in an interactive session; print so a script shows it too.
print(mu)
# Output:
# array([ 5.59201329, 3.08293794, 12.65459799, ..., 7.68515371,
# 8.59399361, 14.47784153])
# Extract the draws unpermuted, as a single array.
# NOTE(review): the printed output groups 4 rows under each leading index,
# which matches chains=4. That suggests a (draws, chains, parameters) layout;
# confirm against the PyStan extract() docs.
# The original bare expression only echoed in an interactive session; print it
# so a script shows it too.
print(fit.extract(permuted=False))
# Output:
# array([[[ 7.64153978, 22.22394329, 1.22547384, ..., 15.8277953 ,
# -4.66356829, 0.23601896],
# [ 9.69206299, 4.51897046, 0.66467932, ..., 11.73638593,
# 7.16792315, -3.06426551],
# [ 0.87311424, 3.47002334, 0.8279629 , ..., 0.75643056,
# -0.65091325, -8.20712019],
# [ 7.60612766, 6.08568668, -0.47210517, ..., 14.54647964,
# 7.9102236 , -2.3997052 ]],
# [[ 7.1309547 , 13.22644053, 0.947816 , ..., 12.45636118,
# 4.64533869, -0.78013228],
# [ 5.24790718, 3.86378093, 0.23809979, ..., 7.99318893,
# 4.87367233, -2.79103881],
# [ 5.26958099, 0.53395251, 0.53201011, ..., 5.46670294,
# 5.45398181, -4.18992926],
# [ 2.00709596, 7.25825218, -0.27890103, ..., 8.46003823,
# -2.76369558, -2.64511885]],
# [[ 3.45046664, 12.63457236, 1.59298591, ..., 9.40716089,
# -8.90763649, -3.00141685],
# [ 3.21451455, 0.64483227, 0.66650154, ..., 3.44470382,
# 2.78355408, -5.9641561 ],
# [ 10.32681882, 1.20833251, 0.57854296, ..., 11.11219035,
# 11.40421097, -3.9523716 ],
# [ 5.71235948, 5.45514844, 1.12256322, ..., 7.04985609,
# 12.67242769, -3.53900179]],
# ...,
# [[ 3.98155107, 1.13343685, -0.08562872, ..., 3.95352988,
# 4.10183286, -4.81207076],
# [ 13.12535449, 5.70772091, 0.4030741 , ..., 10.95159301,
# 6.1646765 , -2.7487087 ],
# [ 1.55174187, 5.92715919, 0.97098247, ..., 8.27351275,
# -8.49507556, -4.79388894],
# [ -1.07560713, 13.12929512, 0.52583857, ..., 8.12608869,
# 5.72010266, -2.28703102]],
# [[ 8.28559892, 2.10766142, -0.64881074, ..., 7.36932785,
# 6.73054133, -3.35423052],
# [ 14.40647332, 13.79326056, 0.78960029, ..., 6.44071296,
# -0.19217228, -3.39933739],
# [ 5.9038614 , 3.17080736, -0.51326664, ..., 11.02356889,
# 3.54807196, -5.65863866],
# [ 7.98216011, 7.96351349, 2.1485957 , ..., 14.33779353,
# 6.13692258, -2.11563483]],
# [[ 12.58510847, 8.86150948, 1.41560914, ..., 19.40481756,
# 19.50804561, -2.64456137],
# [ 10.28052645, 5.15719855, -0.69619472, ..., 11.66938527,
# 13.43038767, -3.64024232],
# [ 8.8819304 , 7.55656516, 2.07610963, ..., 24.25059819,
# 14.40756251, -8.04721041],
# [ 5.32330057, 5.24879104, 0.43790708, ..., 11.40543276,
# 10.52205485, -2.75427205]]])
# Plot the sampled fit using the fit object's built-in plot() method.
# NOTE(review): when this runs as a non-interactive script, the figure may need
# an explicit matplotlib show()/savefig() call to become visible; confirm for
# the plotting backend in use.
fit.plot()