Changes for More Flexible Data Types
Right now, the ObsModel and DataObj formats only support data that is N x D: N observations, each D-dimensional. Many models use data with other shapes. For example, relational models require data that is N x N x D and sufficient statistics that are K x K x D. The changes below are intended to address these problems:
ObsModel Changes
Most of the code in obsmodels relies on two things: the data having shape N x D, and the shape of LP['resp'] / sufficient statistics / E_log_soft_ev being linear in K. The former is determined by the DataObj's format, while the latter is controlled by the allocation model.
As is, the ObsModels implicitly rely on Data.X being two-dimensional by using np.dot for matrix operations. For the more general case, we can use np.tensordot, which handles tensors of higher rank. In some cases (e.g. calcLogSoftEv), the contraction will always run over a single dimension; in calcSummaryStats, however, it needs to sum out all dimensions that are linear in N. This suggests creating a field Data.numLinearDims that records how many such dimensions the data has.
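A minimal sketch of the two kinds of contraction described above, on random arrays. The helper name sumOutLinearDims and its signature are hypothetical and not part of the existing code; the sketch also assumes the linear dimensions are the leading axes of both arrays.

```python
import numpy as np

def sumOutLinearDims(resp, X, numLinearDims):
    """Contract resp with X over their first numLinearDims axes.

    Hypothetical helper for illustration; assumes the axes that are
    linear in N come first in both arrays.
    """
    axes = list(range(numLinearDims))
    return np.tensordot(resp, X, axes=(axes, axes))

rng = np.random.default_rng(0)
N, D, K = 5, 3, 2

# Standard case: resp is N x K and X is N x D, so one linear dimension.
resp2 = rng.random((N, K))
X2 = rng.random((N, D))
stats2 = sumOutLinearDims(resp2, X2, numLinearDims=1)  # K x D
matchesDot = np.allclose(stats2, np.dot(resp2.T, X2))

# Relational case: resp is N x N x K x K and X is N x N x D,
# so two linear dimensions are summed out.
resp4 = rng.random((N, N, K, K))
X4 = rng.random((N, N, D))
stats4 = sumOutLinearDims(resp4, X4, numLinearDims=2)  # K x K x D

# calcLogSoftEv-style contraction: only the single D axis is summed,
# regardless of how many linear dimensions the data has.
logPhi = rng.random((K, K, D))  # stand-in for per-component parameters
logSoftEv = np.tensordot(X4, logPhi, axes=([-1], [-1]))  # N x N x K x K
```

With numLinearDims=1 the helper reduces to the existing np.dot(resp.T, X) call, so current N x D models would behave as before.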
In many places, the ObsModel requires knowledge of the dimensions of the sufficient statistics. For example, GaussObsModel.calcSummaryStats needs to set dims=('K', 'D', 'D'). The 'K' part may vary depending on the AllocModel; in the case of the MMSB, the dims will need to be ('K', 'K', 'D', 'D'). To allow for this, the AllocModel should contain a method specifySSDimensions that returns a string specifying this part (the parent class in AllocModel.py should return 'K' as a default). This can be called from ObsModel.setupWithAllocModel, which extends the purpose of that function.
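The hand-off between the two models could look roughly like the sketch below. The classes ToyAllocModel, ToyMMSB and ToyObsModel are hypothetical stand-ins, not the real classes, and the sketch returns a tuple of dimension names rather than a bare string. That is an assumption made here so the result can be concatenated directly with the data dimensions.

```python
class ToyAllocModel:
    """Hypothetical stand-in for the parent AllocModel class."""

    def specifySSDimensions(self):
        # Default: statistics carry one leading component axis.
        return ('K',)


class ToyMMSB(ToyAllocModel):
    """Hypothetical stand-in for a non-assortative MMSB."""

    def specifySSDimensions(self):
        # One component axis per endpoint of each edge.
        return ('K', 'K')


class ToyObsModel:
    """Hypothetical stand-in for an ObsModel subclass."""

    def __init__(self, dataDims):
        # Trailing dims of the statistic, e.g. ('D', 'D') for a covariance.
        self.dataDims = tuple(dataDims)

    def setupWithAllocModel(self, allocModel):
        # Prepend the allocation model's component dims to the data dims.
        self.SSDims = tuple(allocModel.specifySSDimensions()) + self.dataDims


gaussLike = ToyObsModel(dataDims=('D', 'D'))
gaussLike.setupWithAllocModel(ToyAllocModel())
defaultDims = gaussLike.SSDims

gaussLike.setupWithAllocModel(ToyMMSB())
mmsbDims = gaussLike.SSDims
```

Here defaultDims is ('K', 'D', 'D') and mmsbDims is ('K', 'K', 'D', 'D'), matching the two cases described above.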
As an example, here's how the code would change in BernObsModel.calcSummaryStats, with comments noting the dimensions for the non-assortative MMSB:
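Resp = LP['resp']  # N x N x K x K
X = Data.X  # N x N x D
# Name the contracted axes explicitly: Resp.T would reverse all four axes,
# pairing Resp[i, j] with X[j, i] and swapping the two K axes.
obsAxes = ([0, 1], [0, 1])
CountON = np.tensordot(Resp, X, axes=obsAxes)  # result is K x K x D
CountOFF = np.tensordot(Resp, 1 - X, axes=obsAxes)  # result is K x K x D
SS.setField('Count1', CountON, dims=self.SSDims)
SS.setField('Count0', CountOFF, dims=self.SSDims)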
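One way to sanity-check the dimensions noted in the comments is to run the contraction on small random arrays and compare it with an explicit np.einsum. This is only a sketch: the array sizes are arbitrary, and the contracted axes are named explicitly because .T on a four-dimensional array reverses every axis rather than swapping two.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, D = 4, 3, 2

Resp = rng.random((N, N, K, K))                   # N x N x K x K
X = (rng.random((N, N, D)) < 0.5).astype(float)   # N x N x D, binary

# Sum out the two observation axes of both arrays.
obsAxes = ([0, 1], [0, 1])
CountON = np.tensordot(Resp, X, axes=obsAxes)       # K x K x D
CountOFF = np.tensordot(Resp, 1 - X, axes=obsAxes)  # K x K x D

# Reference result with every index written out.
expectedON = np.einsum('ijkl,ijd->kld', Resp, X)

# For comparison: contracting Resp.T with axes=2 pairs Resp[i, j] with
# X[j, i] and returns the component axes in swapped order.
viaTranspose = np.tensordot(Resp.T, X, axes=2)
transposeEquivalent = np.einsum('ijkl,jid->lkd', Resp, X)

# ON and OFF counts together account for all responsibility mass.
totalResp = Resp.sum(axis=(0, 1))[:, :, np.newaxis]  # K x K x 1
```

For a symmetric graph with symmetric responsibilities the two variants can coincide, so a check on asymmetric random input is the more informative one.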
DataObj Changes
Existing Classes
Current DataObj classes only need to be changed to contain a Data.numLinearDims field.
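As a rough illustration, the new field could default to 1 so that existing N x D data needs no other change. ToyData below is a hypothetical stand-in for a DataObj subclass, and the obsShape and featShape attributes are illustrative names that do not exist in the current classes.

```python
import numpy as np

class ToyData:
    """Hypothetical stand-in for a DataObj subclass."""

    def __init__(self, X, numLinearDims=1):
        self.X = np.asarray(X, dtype=float)
        # Number of leading axes of X that are linear in N.
        self.numLinearDims = numLinearDims
        # Illustrative convenience fields derived from the split.
        self.obsShape = self.X.shape[:numLinearDims]
        self.featShape = self.X.shape[numLinearDims:]

# Existing N x D data: the default of 1 keeps current behaviour.
flatData = ToyData(np.zeros((6, 3)))

# Relational N x N x D data declares two linear dimensions.
graphData = ToyData(np.zeros((4, 4, 2)), numLinearDims=2)
```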
Graph Data
This is covered in issue #32.
Sparsity
As a side note, it seems like tensordot doesn't work with sparse matrices, so we might need another solution when we move to storing graph data in sparse format.
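One possible direction, sketched under the assumption that the graph is binary with D = 1 and stored as an edge list (row indices, column indices, values) rather than as a dense N x N array: index the responsibilities at the observed edges and contract over the edge axis. The storage layout and variable names here are illustrative, not a settled design.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 5, 3
Resp = rng.random((N, N, K, K))  # N x N x K x K

# Assumed sparse storage: one entry per observed edge.
rows = np.array([0, 1, 3, 4])
cols = np.array([2, 0, 4, 1])
vals = np.ones(len(rows))

# Resp[rows, cols] gathers an nnz x K x K block, one slice per edge;
# contracting with vals over the edge axis gives the ON counts.
CountON = np.tensordot(vals, Resp[rows, cols], axes=1)  # K x K

# For binary data the OFF counts follow from the total, so the dense
# array 1 - X never has to be built.
CountOFF = Resp.sum(axis=(0, 1)) - CountON  # K x K

# Dense cross-check, feasible only because N is tiny here.
Xdense = np.zeros((N, N))
Xdense[rows, cols] = vals
denseON = np.tensordot(Resp, Xdense, axes=([0, 1], [0, 1]))
denseOFF = np.tensordot(Resp, 1 - Xdense, axes=([0, 1], [0, 1]))
```

The cost of the sparse path scales with the number of edges times K squared, although the dense Resp array itself would still be N x N x K x K in this sketch.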
Comments (2)
repo owner Will, can you clean up this description so it accurately describes the changes you've made in PR #29? I think some bits have changed slightly, and I want to make sure it's all documented here.