'''
The posterior distribution results from running inference for a given query using MCMC.
This module contains
* Data structures to collect posterior samples
* Convergence diagnostics
The convergence diagnostics have to be called by the ProbReM project script, e.g. from outside the framework.
'''
import logging
import numpy as N
import pylab as PL
from network.vertices import ReferenceVertex
from inference import engine
'''The engine module contains the :class:`.GBNgraph` instance
'''
samples = {}
'''
Dicitonary containing all the samples collected during inference. It is common to run more than one chain to montitor convergence, the key/value pairs are stored
{ key = 'chainIdentification' : value = :class:`numpy.array` }
'''
currentChain = None
'''
:class:`numpy.array` that is currently being used by the sampler
'''
posteriorVertices = None
'''
Dictionary of vertices that we are collecting samples from (e.g. the event vertices or all sampling vertices including latent variables), set
'''
currentIndex = {}
'''
Dictionary mapping each event variable `ID` to an index used to access :attr:`.currentChain`
'''
def initChain(chainID,ITER,onlyEvent=False):
    '''
    Initializes a new MCMC run. Note that by default `onlyEvent=False`, so the samples of all `engine.GBN.samplingVertices` are collected. But the posterior is a joint distribution over the event variables (thus the other sampling variables are already marginalized)

    :arg chainID: String identification for new chain
    :arg ITER: Number of samples to be collected
    :arg onlyEvent: Boolean, if `True` only the values of the event vertices are collected (i.e. not latent sampling variables)
    '''
    global samples,currentChain,currentIndex,posteriorVertices
    # choose which vertices the posterior samples are collected from
    if onlyEvent:
        posteriorVertices = engine.GBN.eventVertices
    else:
        posteriorVertices = engine.GBN.samplingVertices
    # each posterior vertex is assigned a fixed column in the sample matrix
    currentIndex = dict((vertexID, i) for i, vertexID in enumerate(posteriorVertices.keys()))
    # allocate one row per sample, one column per tracked vertex
    newChain = N.zeros((ITER, len(posteriorVertices)))
    samples[chainID] = newChain
    currentChain = newChain
def collectSamples(nSample):
    '''
    Extracting the value of a node and storing it in the appropriate `numpy.array`, :attr:`.currentChain`.

    :arg nSample: Int Count of the collected sample, i.e. the row number
    '''
    for gbnID,gbnV in posteriorVertices.items():
        # column of this vertex in the sample matrix (hoisted out of both branches)
        i = currentIndex[gbnID]
        if isinstance(gbnV,ReferenceVertex):
            # If the vertex is a reference vertex, we collect the current reference.
            # At this point we are assuming that k=1: we extract the ID using the
            # 1st entry of the obj list of the attribute, e.g. 2 in 'Professor.fame.2'.
            # BUG FIX: dict.values()[0] fails on Python 3 (dict views are not
            # indexable); next(iter(...)) works on both Python 2 and 3.
            currentChain[nSample,i] = next(iter(gbnV.references.values())).obj[0]
        else:
            # If the vertex is a normal vertex, we collect the sampled value
            currentChain[nSample,i] = gbnV.value
def plotCumulativeMeanAllChains(**kwargs):
    '''
    Plots the cumulative mean of all available chains using :meth:`.cumulativeMean`.
    All chains are drawn on one shared figure.
    If the plots are to display only a specific variable, use the `gbnV` or `varIndex` keyword.

    :arg kwargs: Optional arguments for :meth:`.cumulativeMean`
    '''
    # one shared figure window so the chains can be compared visually
    sharedFig = PL.figure()
    for chainID in samples:
        plotCumulativeMean(fig=sharedFig, chainID=chainID, **kwargs)
def plotCumulativeMean(**kwargs):
    '''
    Convergence diagnostics that plots the cumulative mean of all the sampling variables in the `currentChain`. If a `chainID` is provided the cumulative mean of the associated chain is plotted instead. If `varIndex` or `gbnV` is provided, only the cumulative mean of that variable is plotted.

    :arg chainID: Optional identification of chain to be analyzed
    :arg varIndex: Optional index of event variable to be analyzed
    :arg gbnV: Optional :class:`GBNvertex` to be analyzed
    :arg fig: Optional `matplotlib.figure.Figure` to be used
    '''
    # select the chain: named chain if given, otherwise the active one
    if 'chainID' in kwargs:
        chain = samples[kwargs['chainID']]
    else:
        chain = currentChain
    cumChain = cumulativeMean(chain)
    # reuse the supplied figure window, or open a fresh one
    if 'fig' in kwargs:
        PL.figure(kwargs['fig'].number)
    else:
        PL.figure()
    if 'varIndex' in kwargs:
        varIndex = kwargs['varIndex']
        PL.plot(cumChain[:,varIndex])
        PL.xlabel('Samples for index %s'%varIndex)
        PL.ylabel('Mean')
    elif 'gbnV' in kwargs:
        gbnID = kwargs['gbnV'].ID
        PL.plot(cumChain[:,currentIndex[gbnID]])
        PL.xlabel('Samples for %s'%gbnID)
        PL.ylabel('Mean')
    else:
        # no specific variable requested: one stacked subplot per posterior variable
        nVars = len(posteriorVertices)
        for i,pID in enumerate(posteriorVertices.keys()):
            PL.subplot(nVars,1,i+1)
            PL.plot(cumChain[:,currentIndex[pID]])
            PL.xlabel('Samples for %s'%pID)
            PL.ylabel('Mean')
def cumulativeMean(chain):
    '''
    Returns the cumulative mean of `chain` along axis 0, i.e. row `i` of the
    result holds the mean of rows `0..i` of `chain`. The input is not modified.

    :arg chain: `numpy.array` of shape (nSamples, nVariables)
    :returns: `numpy.array` of the same shape containing the running means
    '''
    nSamples = chain.shape[0]
    # Vectorized: divide each cumulative sum row by the number of samples seen
    # so far (broadcast over columns). Replaces the original O(nSamples) Python
    # loop and, as a side effect, avoids silent truncation for integer-dtype
    # input (the old in-place division kept the integer dtype).
    return chain.cumsum(axis=0) / N.arange(1., nSamples + 1.).reshape(-1, 1)
def mean(chainID=None, gbnV=None, sVarInd=None, combined=False):
    '''
    Returns the posterior mean of all the sampling variables in the currentChain.
    If a `chainID` is provided the mean of the associated chain is returned instead. If `sVarInd` or `gbnV` is provided, only the mean of this variable is returned.

    :arg chainID: Optional identification of chain to be analyzed
    :arg sVarInd: Optional index of event variable to be analyzed
    :arg gbnV: Optional :class:`GBNvertex` event variable to be analyzed
    :arg combined: Optional (default `False`), if `True` the mean of the mean of all event variables is returned (single value).
    :returns: Posterior mean as :class:`numpy.array`. If `sVarInd` or `gbnV` are specified, this is a single value
    '''
    chain = currentChain
    if chainID is not None:
        chain = samples[chainID]
    # computing mean for one or all variables?
    if combined:
        # grand mean over all samples and all variables
        return N.mean(chain)
    elif sVarInd is not None:
        return N.mean(chain[:,sVarInd])
    elif gbnV is not None:
        # BUG FIX: currentIndex is keyed by vertex ID strings (see initChain),
        # not by the vertex object itself — consistent with plotCumulativeMean
        # and histogramm, which both index with gbnV.ID.
        ind = currentIndex[gbnV.ID]
        return N.mean(chain[:,ind])
    else:
        # column-wise mean: one value per posterior variable
        return N.mean(chain,0)
def histogramm(**kwargs):
    '''
    Convergence diagnostics that plots the posterior density function (using the matplotlib histogram) of the sampling variables in the currentChain.
    If a `chainID` is provided the histogram of the associated chain is plotted instead. If `varIndex` or `gbnV` is provided, only the histogram of this variable is plotted.

    :arg chainID: Optional identification of chain to be analyzed
    :arg varIndex: Optional index of event variable to be analyzed
    :arg gbnV: Optional :class:`GBNvertex` event variable to be analyzed
    :arg fig: Optional `matplotlib.figure.Figure` to be used
    '''
    chain = currentChain
    if 'chainID' in kwargs:
        chain = samples[kwargs['chainID']]
    # either plot on a new or specific figure window
    if 'fig' in kwargs:
        PL.figure(kwargs['fig'].number)
    else:
        PL.figure()
    if 'varIndex' in kwargs:
        PL.hist(chain[:,kwargs['varIndex']])
    elif 'gbnV' in kwargs:
        PL.hist(chain[:,currentIndex[kwargs['gbnV'].ID]])
    else:
        # BUG FIX: this is a warning, so log it at WARNING level instead of
        # logging.info with a hand-written "WARNING:" prefix
        logging.warning("no posterior variable passed as argument for histogramm")
        PL.hist(chain)
def gelman_rubin():
    '''
    Plots the Gelman Rubin convergence diagnostic, according to `Probabilistic Graphical Models` (p. 523).
    One subplot per posterior variable; values close to 1 indicate agreement between the chains.

    :returns: `numpy.array` of the cumulative R_hat statistic, same shape as :attr:`.currentChain`
    '''
    # number of samples per chain
    M = engine.inferenceAlgo.ITER
    # number of discarded samples (burn-in); not used in the statistic below
    T = engine.inferenceAlgo.BURNIN
    # number of chains
    K = engine.inferenceAlgo.CHAINS
    # cumulative mean for each chain: f^{bar}_{k} in the Koller book
    f_bar_K = {}
    for ID,chain in samples.items():
        f_bar_K[ID] = cumulativeMean(chain)
    # cumulative mean across all chains: f^{bar} in the Koller book
    f_bar = N.zeros(currentChain.shape,dtype=float)
    for f_bar_k in f_bar_K.values():
        f_bar += f_bar_k
    f_bar /= K
    # between-chain variance: B in the Koller book
    B = N.zeros(currentChain.shape,dtype=float)
    for f_bar_k in f_bar_K.values():
        B += (f_bar_k - f_bar)**2
    # BUG FIX: force float division — under Python 2, M/(K-1) is integer
    # division and truncates the scale factor whenever K > 2
    B *= M/float(K-1)
    # within-chain variance: W in the Koller book
    W = N.zeros(currentChain.shape,dtype=float)
    # With just one sample the variance is undefined (would be inf);
    # use 1 so R_hat stays finite in the first row
    W[0,:] = 1
    for ID,chain in samples.items():
        for j in range(1,M):
            # unbiased sample variance of the first j+1 samples of this chain
            W[j,:] += 1./(j)*((chain[0:(j+1),:]-f_bar_K[ID][j,:])**2).sum(axis=0)
    W *= (1./K)
    # An estimator that can be shown to overestimate the variance: V in the Koller book.
    # For M to infinity, both V and W converge to the true variance of the estimate.
    V = (M-1.)/M * W + 1./M * B
    # Measure of disagreement between the chains: R_hat in the Koller book
    R_hat = N.sqrt(V/W)
    PL.figure()
    for i,pID in enumerate(posteriorVertices.keys()):
        PL.subplot(len(posteriorVertices),1,(i+1))
        # skip row 0 (degenerate single-sample variance)
        PL.plot(R_hat[1:,currentIndex[pID]])
        PL.xlabel(pID)
        PL.ylabel('Gelman-Rubin')
    return R_hat
def autocorrelation(max_l = 50, **kwargs):
    '''
    Plots the autocorrelation. Computed according to `Probabilistic Graphical Models` (p. 521).
    One bar-plot subplot per posterior variable, showing lags 0..`max_l`-1.

    :arg max_l: The autocorrelation will be calculated up to lag `max_l` (default=50)
    :arg chainID: Optional identification of chain to be analyzed
    '''
    chain = currentChain
    if 'chainID' in kwargs:
        chain = samples[kwargs['chainID']]
    # extract the number of samples M
    M = engine.inferenceAlgo.ITER
    # The cumulative mean
    E_bar = cumulativeMean(chain)
    # The `cumulative` variance (equation 12.27), row j uses the first j+1 samples
    V = N.zeros(chain.shape,dtype=float)
    # Variance of a single sample is undefined; mark it as infinite
    V[0,:] = N.inf
    for j in range(1,M):
        V[j,:] = 1./(j)*((chain[0:(j+1),:]-E_bar[j,:])**2).sum(axis=0)
    # Note that so far we have calculated the cumulative values of the mean and variance, i.e. E_bar and V are M x |Y| matrices.
    # However, for the covariance and the autocorrelation we only use the mean/variance over all M samples, i.e. V[-1,:] and E_bar[-1,:]
    # The covariance (equation 12.28):
    # a max_l x |Y| matrix, row `l` is the covariance with lag `l`
    C = N.zeros((max_l,chain.shape[1]),dtype=float)
    # For l=0, the covariance is equal to the variance
    C[0,:] = V[-1,:]
    for l in range(1,max_l):
        # covariance between the chain and its lag-l shifted copy
        C[l,:] = 1./(M-l) * ((chain[0:(M-l),:]-E_bar[-1])*(chain[l:,:]-E_bar[-1])).sum(axis=0)
    # autocorrelation: covariance normalized by the full-chain variance
    A = C / V[-1,:]
    PL.figure()
    for i,pID in enumerate(posteriorVertices.keys()):
        PL.subplot(len(posteriorVertices),1,(i+1))
        PL.bar(N.arange(max_l), A[:,currentIndex[pID]])
        PL.xlabel(pID)
        PL.ylabel('Autocorrelation')
def __marginalize(margVertices):
    '''
    TODO: Marginalize `margVertices` (which must be a subset of `posteriorVertices`)
    from the posterior. Not implemented yet; calling it is a no-op.
    '''
def __repr__():
    # NOTE(review): a module-level __repr__ is never invoked by repr(module);
    # this only returns a descriptive string when called explicitly.
    return 'Posterior distribution module'