StructuralModel¶
The deep_lvpm.models.StructuralModel.StructuralModel class is the core of the DLVPM toolbox. It inherits from tf.keras.Model and coordinates multiple measurement models (one per data view) together with a binary path model that specifies how latent factors are connected across views. During training, the model learns sets of orthogonal deep latent variables (DLVs) that maximise correlation between the outputs of the measurement models while respecting the path structure.
Parameters¶
The constructor has the following signature:
StructuralModel(
    Path,
    model_list,
    regularizer_list,
    tot_num,
    ndims,
    orthogonalization="Moore-Penrose",
    momentum=0.95,
    epsilon=1e-4
)
where:
Path (array-like): Binary adjacency matrix defining which latent factors are connected between data views. The shape of Path determines the number of latent factors.
model_list (list of keras.Model): One measurement model per data view. Each model should accept a single input tensor and output a latent representation of arbitrary dimension.
regularizer_list (list): List of regulariser objects applied to the final projection layer of each measurement model. May be None for no regularisation.
tot_num (int): Total number of samples used in training. This is used internally for scaling covariance matrices.
ndims (int): Number of orthogonal latent variables (DLVs) to extract.
orthogonalization (str, optional): Method for orthogonalising latent factors. Either "Moore-Penrose" (default) or "zca". When set to "zca", the ZCALayer is used instead of FactorLayer for orthogonalisation.
momentum (float, optional): Momentum parameter for updating global statistics (default: 0.95).
epsilon (float, optional): Small constant added for numerical stability (default: 1e-4).
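As an illustration of the Path argument, the sketch below builds a binary adjacency matrix for three data views with NumPy. It assumes that Path is square with one row and column per view, symmetric, and zero on the diagonal (no view connected to itself). These conventions are assumptions made for the example, and make_full_path is a hypothetical helper, not part of the toolbox.

```python
import numpy as np


def make_full_path(n_views):
    """Hypothetical helper: connect every view to every other view."""
    # Ones everywhere except the diagonal (assumed: no self-connections).
    return np.ones((n_views, n_views), dtype=int) - np.eye(n_views, dtype=int)


path = make_full_path(3)

# Drop the link between the first and third view to get a chain 0 - 1 - 2.
chain_path = path.copy()
chain_path[0, 2] = chain_path[2, 0] = 0

# Basic sanity checks before handing the matrix to the constructor.
assert set(np.unique(chain_path)) <= {0, 1}   # binary
assert (chain_path == chain_path.T).all()     # symmetric (assumed)
assert not chain_path.diagonal().any()        # zero diagonal (assumed)
```

A matrix such as chain_path would then be passed as the first constructor argument, alongside one measurement model per view in model_list.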
Attributes¶
A StructuralModel instance exposes several public attributes:
Path – the binary adjacency matrix.
model_list – list of measurement models.
regularizer_list – list of regularisers for projection layers.
tot_num – total number of samples.
ndims – number of latent variables.
loss_tracker_total – Keras metric tracking total loss during training.
corr_tracker – Keras metric tracking the average correlation between connected views.
loss_tracker_mse – Keras metric tracking mean squared error.
Common methods¶
Because StructuralModel subclasses tf.keras.Model, it supports the standard Keras interface for compilation, training and evaluation:
compile(optimizer_list) – Configures the model for training. Unlike standard Keras models, this method requires a list of optimisers, one per measurement model. Example:

    import tensorflow as tf

    # One optimiser per measurement model (three views in this example).
    optimizer_list = [
        tf.keras.optimizers.Adam(learning_rate=1e-4),
        tf.keras.optimizers.Adam(learning_rate=1e-3),
        tf.keras.optimizers.Adam(learning_rate=1e-4),
    ]
    # struct_model is an already constructed StructuralModel instance.
    struct_model.compile(optimizer_list)

fit(data, batch_size=None, epochs=1, ...) – Trains the model on a list or generator of data arrays. The input data should be a list of arrays, one per view. Additional arguments (batch_size, epochs, callbacks, etc.) behave as in Keras.
evaluate(data) – Evaluates the model on input data and returns a list [mse, correlation], where mse is the mean squared error and correlation is the mean Pearson correlation between connected views.
predict(data) – Computes the deep latent variables for each view, returning a tensor of shape (n_samples, ndims, n_views). To extract the latent variables for an individual view, use struct_model.model_list[i].predict(data[i]).
calculate_corrmat(DLVs) – Calculates correlation matrices for the latent variables produced by predict. Returns a list of correlation matrices with length ndims.
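To make the shapes concrete, the following NumPy sketch computes one correlation matrix per latent variable from an array shaped like the output of predict, and averages the correlations over connected view pairs. It is not the toolbox's implementation of calculate_corrmat or of the correlation metric: the function names corrmats_per_dlv and mean_connected_corr are hypothetical, the data are synthetic, and the path matrix is assumed to be symmetric with one row and column per view.

```python
import numpy as np


def corrmats_per_dlv(dlvs):
    """Return one (n_views, n_views) correlation matrix per latent variable.

    dlvs: array of shape (n_samples, ndims, n_views).
    """
    dlvs = np.asarray(dlvs)
    n_dims = dlvs.shape[1]
    # Columns of dlvs[:, k, :] are the views, hence rowvar=False.
    return [np.corrcoef(dlvs[:, k, :], rowvar=False) for k in range(n_dims)]


def mean_connected_corr(corrmats, path):
    """Average the correlations over view pairs that are connected in path."""
    # Upper triangle only, so each connected pair is counted once.
    mask = np.triu(np.asarray(path, dtype=bool), k=1)
    return float(np.mean([c[mask].mean() for c in corrmats]))


# Synthetic stand-in for the output of predict: 200 samples, 2 DLVs, 3 views.
rng = np.random.default_rng(0)
shared = rng.normal(size=(200, 2, 1))             # signal common to all views
dlvs = shared + 0.5 * rng.normal(size=(200, 2, 3))

corrmats = corrmats_per_dlv(dlvs)
path = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
avg_corr = mean_connected_corr(corrmats, path)
```

Here corrmats has length ndims (2), and each entry is a 3 × 3 matrix with ones on the diagonal.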
Internal methods¶
The following methods are used internally by the implementation and are generally not called directly:
add_DLVPM_layer – Adds a FactorLayer or ZCALayer to each measurement model, depending on the orthogonalisation method.
call(inputs) – Forwards inputs through each measurement model and applies the latent projection.
train_step(inputs) / test_step(inputs) – Custom training and testing routines.
mse_loss(…) – Calculates the mean squared error loss.
corr_metric(…) – Calculates the average correlation metric.
get_config() and from_config() – Used for model serialisation.
get_compile_config() and compile_from_config() – Used by Keras to serialise optimizer configurations.
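For readers unfamiliar with the "zca" option, the sketch below shows generic ZCA whitening on a single batch with NumPy, including the role of a small epsilon for numerical stability. It makes no claim about how ZCALayer or FactorLayer are implemented. The function name zca_whiten and the synthetic data are hypothetical and serve only to illustrate what orthogonalising a set of latent factors means.

```python
import numpy as np


def zca_whiten(x, epsilon=1e-4):
    """Whiten the columns of x so their covariance is close to identity."""
    x = x - x.mean(axis=0)                     # centre each column
    cov = x.T @ x / x.shape[0]                 # (ndims, ndims) covariance
    eigvals, eigvecs = np.linalg.eigh(cov)     # cov is symmetric
    # epsilon keeps the inverse square root finite for tiny eigenvalues.
    w = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + epsilon)) @ eigvecs.T
    return x @ w


rng = np.random.default_rng(0)
mixing = np.array([[1.0, 0.8, 0.0],
                   [0.0, 1.0, 0.5],
                   [0.0, 0.0, 1.0]])
factors = rng.normal(size=(500, 3)) @ mixing   # correlated columns
whitened = zca_whiten(factors)

# Covariance after whitening: approximately the identity matrix.
cov_after = whitened.T @ whitened / whitened.shape[0]
```

After whitening, the columns are mutually uncorrelated with approximately unit variance, which is the property the orthogonalisation step is meant to enforce on the latent variables.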