TCGA Lung Cancer Tutorial (PyTorch Backend)

This tutorial mirrors the TCGA Lung Cancer Tutorial but runs entirely on the PyTorch backend provided by Keras 3. The code is copied from deep_lvpm.tutorial.tutorial_tcga_torch without modification, so you can follow along step by step. The same five TCGA modalities (histology, RNA‑seq, methylation, miRNA, and SNV) are processed with residual measurement models and connected through a five-factor structural path model.

Prerequisites

Install deep_lvpm with the torch-* extras described on the Installation page and set KERAS_BACKEND=torch before launching the tutorial. The bundled sample datasets under deep_lvpm.data are used, so no external downloads are required.

1. Load the TCGA dataset and initialise the Torch backend

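We begin by selecting the PyTorch backend, importing all dependencies, and loading the multi-omics training arrays. Note that os.environ.setdefault only sets KERAS_BACKEND when it is not already defined, so a pre-existing value (for example tensorflow) takes precedence; the variable must also be set before keras is imported for the first time.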

####### Tutorial 2 #########

import os
os.environ.setdefault("KERAS_BACKEND", "torch")

import numpy as np
import torch
import torch.nn as nn
import keras
from keras import optimizers, regularizers
from importlib import resources

import deep_lvpm as DLVPM
from deep_lvpm.model import StructuralModel


# Load the bundled training cohort shipped inside the deep_lvpm.data package.
with resources.as_file(resources.files("deep_lvpm.data") /
                    "Lung_multiomics_sample_train.npz") as f:
    # np.load on an .npz archive returns a lazily-read NpzFile that keeps a
    # file handle open. Using it as a context manager closes that handle as
    # soon as every array has been read into memory.
    with np.load(f) as arrays:
        rnaseq      = arrays["rnaseq"]
        snv         = arrays["snv"]
        methylation = arrays["methylation"]
        mirna       = arrays["mirna"]
        histo20     = arrays["histo20"]

# One array per modality; later lists in this script are built in this order.
X_arr = [histo20, rnaseq, methylation, mirna, snv]   # preserve this order!

2. Define PyTorch measurement modules

Measurement models are authored in pure torch.nn. TorchModuleLayer adapts each module so it can run inside Keras, handling device placement and regularisation losses automatically. A small residual block mirrors the TensorFlow tutorial.

class TorchModuleLayer(keras.layers.Layer):
    """Wrap a PyTorch ``nn.Module`` so it can be called as a Keras layer.

    On every call the wrapped module is moved to the device of the incoming
    tensor (if it is not already there), switched to train or eval mode
    according to ``training``, and its output is flattened to 2-D whenever it
    has more than two dimensions. If the module exposes a
    ``regularization_loss()`` method, its non-``None`` result is registered
    with ``add_loss``.

    Args:
        torch_module: The PyTorch module to execute.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, torch_module: nn.Module, **kwargs):
        super().__init__(**kwargs)
        self.torch_module = torch_module
        # Width of the flattened module output; filled in by ``build``.
        self._feature_dim: int | None = None
        # String form of the device the module was last moved to.
        self._current_device: str | None = None

    def _flatten(self, tensor: torch.Tensor) -> torch.Tensor:
        """Collapse all non-batch dimensions; 1-D/2-D tensors pass through."""
        return torch.flatten(tensor, start_dim=1) if tensor.ndim > 2 else tensor

    def build(self, input_shape):
        """Probe the module on CPU with a dummy batch to learn its output width.

        Args:
            input_shape: Shape of the layer input; only the last entry (the
                feature dimension) is used.
        """
        device = torch.device("cpu")
        self.torch_module.to(device)
        feature_dim = int(input_shape[-1])
        dummy = torch.zeros((2, feature_dim), dtype=torch.float32, device=device)
        # Run the probe in eval mode so that layers with running statistics
        # (e.g. BatchNorm) are not updated by the all-zero dummy batch, then
        # restore whatever mode the module was in beforehand.
        was_training = self.torch_module.training
        self.torch_module.eval()
        try:
            with torch.no_grad():
                features = self._flatten(self.torch_module(dummy))
        finally:
            self.torch_module.train(was_training)
        self._feature_dim = int(features.shape[-1])
        self._current_device = str(device)
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Run the wrapped module on ``inputs`` and return flattened features.

        Args:
            inputs: Anything ``torch.as_tensor`` accepts; cast to float32.
            training: Truthy to run the module in training mode.

        Returns:
            The module output, flattened to 2-D if it had more dimensions.
        """
        tensor = torch.as_tensor(inputs, dtype=torch.float32)
        device = tensor.device
        if device.type == "meta":
            # Meta tensors carry shape but no data, so the module is skipped
            # and a placeholder of the recorded output width is returned.
            # NOTE(review): presumably hit during Keras symbolic tracing - confirm.
            batch = tensor.shape[0]
            feat_dim = self._feature_dim or 1
            return torch.zeros((batch, feat_dim), dtype=tensor.dtype, device=device)
        # Compare the full device string (type and index) so that a change
        # between e.g. cuda:0 and cuda:1 also relocates the module.
        device_key = str(device)
        if self._current_device != device_key:
            self.torch_module.to(device)
            self._current_device = device_key
        self.torch_module.train(bool(training))
        features = self._flatten(self.torch_module(tensor))
        if hasattr(self.torch_module, "regularization_loss"):
            penalty = self.torch_module.regularization_loss()
            if penalty is not None:
                self.add_loss(penalty)
        return features


class ResidualBlockModule(nn.Module):
    """Dense residual unit: Linear -> BatchNorm -> ReLU -> Linear, plus skip.

    Both linear layers map ``input_dim`` to ``input_dim`` and start as the
    identity (eye weights, zero biases). Dropout is applied after the skip
    connection has been added.

    Args:
        input_dim: Number of input (and output) features.
        kernel_reg_l1: Coefficient of the L1 penalty on both weight matrices.
        kernel_reg_l2: Coefficient of the squared-L2 penalty on both weight
            matrices.
        dropout_rate: Drop probability handed to ``nn.Dropout``.
    """

    def __init__(
        self,
        input_dim: int,
        kernel_reg_l1: float = 0.01,
        kernel_reg_l2: float = 0.01,
        dropout_rate: float = 0.5,
    ) -> None:
        super().__init__()
        self.kernel_reg_l1 = kernel_reg_l1
        self.kernel_reg_l2 = kernel_reg_l2
        # Sub-modules are created in the order they are used by ``forward``.
        self.linear1 = nn.Linear(input_dim, input_dim)
        self.bn = nn.BatchNorm1d(input_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(input_dim, input_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self._init_weights()

    def _init_weights(self) -> None:
        """Reset both linear layers to the identity map with zero bias."""
        for dense in (self.linear1, self.linear2):
            nn.init.eye_(dense.weight)
            nn.init.zeros_(dense.bias)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(linear2(relu(bn(linear1(inputs)))) + inputs)``."""
        hidden = self.relu(self.bn(self.linear1(inputs)))
        skip_sum = self.linear2(hidden) + inputs
        return self.dropout(skip_sum)

    def regularization_loss(self) -> torch.Tensor:
        """Return the scalar L1 + squared-L2 penalty over both weight matrices.

        A coefficient that is falsy (e.g. ``0``) disables its term; with both
        disabled the result is a zero scalar on the weights' dtype and device.
        """
        w1 = self.linear1.weight
        w2 = self.linear2.weight
        total = torch.zeros((), dtype=w1.dtype, device=w1.device)
        if self.kernel_reg_l1:
            l1_term = torch.sum(torch.abs(w1)) + torch.sum(torch.abs(w2))
            total = total + self.kernel_reg_l1 * l1_term
        if self.kernel_reg_l2:
            l2_term = torch.sum(w1 ** 2) + torch.sum(w2 ** 2)
            total = total + self.kernel_reg_l2 * l2_term
        return total


def residual_block(
    input_dim: int,
    kernel_reg_l1: float = 0.01,
    kernel_reg_l2: float = 0.01,
    dropout_rate: float = 0.5,
    name: str = "residual_block"
) -> keras.Model:
    """Build a Keras model that runs one ``ResidualBlockModule``.

    Args:
        input_dim: Feature count of the model input.
        kernel_reg_l1: L1 coefficient passed to the torch module.
        kernel_reg_l2: L2 coefficient passed to the torch module.
        dropout_rate: Dropout probability passed to the torch module.
        name: Model name; also the prefix of the input (``<name>_in``) and
            wrapper-layer (``<name>_torch``) names.

    Returns:
        A functional ``keras.Model`` from the input to the wrapped module's
        output.
    """
    model_input = keras.Input(shape=(input_dim,), name=f"{name}_in")
    torch_core = ResidualBlockModule(
        input_dim=input_dim,
        kernel_reg_l1=kernel_reg_l1,
        kernel_reg_l2=kernel_reg_l2,
        dropout_rate=dropout_rate,
    )
    # Adapt the torch module to the Keras layer interface, then apply it.
    wrapper = TorchModuleLayer(torch_core, name=f"{name}_torch")
    model_output = wrapper(model_input)
    return keras.Model(inputs=model_input, outputs=model_output, name=name)


# (training array, encoder name) for each modality, listed in the same order
# as X_arr. The encoder width is taken from the array's second dimension.
encoder_specs = [
    (histo20, "histo20_enc"),
    (rnaseq, "rnaseq_enc"),
    (methylation, "meth_enc"),
    (mirna, "mirna_enc"),
    (snv, "snv_enc"),
]

# One residual measurement model per modality.
model_list = [
    residual_block(view.shape[1], name=enc_name)
    for view, enc_name in encoder_specs
]

3. Define the structural path and schedules

We keep the same five-factor adjacency matrix and learning-rate schedule used in the TensorFlow tutorial. tot_num records the sample count for internal normalisation.

ndims = 5        # number of latent factors

# Symmetric 5x5 adjacency matrix of the structural model: index 1 is linked to
# every other index, and the remaining indices are linked only to index 1.
# NOTE(review): rows/columns presumably follow the X_arr order (index 1 =
# RNA-seq) - confirm against StructuralModel.
Path = np.array(
    [
        [0, 1, 0, 0, 0],
        [1, 0, 1, 1, 1],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
    ],
    dtype="float32",
)

batch_size  = 256
epochs      = 300
# Round the per-epoch step count up so the final partial batch is counted, and
# never let it drop to zero: a cohort smaller than one batch would otherwise
# give decay_steps=0 and an undefined learning-rate schedule.
steps_per_epoch = max(1, int(np.ceil(rnaseq.shape[0] / batch_size)))
total_steps = steps_per_epoch * epochs

init_lr, final_lr = 1e-4, 1e-5

# Smooth exponential decay that goes from init_lr at step 0 to final_lr at
# step total_steps (decay_rate is the overall ratio between the two).
lr_schedule = optimizers.schedules.ExponentialDecay(
    initial_learning_rate=init_lr,
    decay_steps=total_steps,
    decay_rate=final_lr / init_lr,
    staircase=False
)

tot_num = rnaseq.shape[0] ## This is the total number of samples under analysis and is needed by DLVPM

4. Build and compile the StructuralModel

The StructuralModel receives the path matrix, measurement models, projection regularisers, and the standard optimiser list—one Adam instance per view.

n_views = len(model_list)  # one regulariser and one optimiser per measurement model

## These regularizers are applied to the final "projection" layer of the DLVPM model, used internally
regularizer_list = [
    regularizers.L1L2(l1=0.001, l2=0.001) for _ in range(n_views)
]

DLVPM_Structural_instance = StructuralModel(
    Path,
    model_list,
    regularizer_list,
    tot_num,
    ndims,
    momentum=0.95,
    epsilon=0.001,
    orthogonalization='Moore-Penrose',
    train_DLV=True,
)

# A separate Adam instance per view, all driven by the same decay schedule.
opt_list = [
    keras.optimizers.Adam(learning_rate=lr_schedule) for _ in range(n_views)
]
DLVPM_Structural_instance.compile(optimizer=opt_list)

5. Train and evaluate on the training cohort

Training uses the usual fit call, and evaluate reports the mean correlation metric across connected modalities.

# Train on the full training cohort, then score the same cohort.
DLVPM_Structural_instance.fit(
    X_arr,
    batch_size=batch_size,
    epochs=epochs,
    verbose=True,
)
mean_corr = DLVPM_Structural_instance.evaluate(X_arr)

# Entry 1 of the evaluate() result is reported as the mean correlation.
print(f'The mean correlation between data-types connected by the path model is r={mean_corr[1]!s}')

6. Evaluate and analyse the test cohort

Load the held-out test arrays, compute the evaluation metrics, and inspect the latent correlations via predict.

# Load the bundled held-out test cohort shipped inside deep_lvpm.data.
with resources.as_file(resources.files("deep_lvpm.data") /
                    "Lung_multiomics_sample_test.npz") as f:
    # np.load on an .npz archive returns a lazily-read NpzFile that keeps a
    # file handle open. Using it as a context manager closes that handle as
    # soon as every array has been read into memory.
    with np.load(f) as arrays:
        rnaseq_test      = arrays["rnaseq"]
        snv_test         = arrays["snv"]
        methylation_test = arrays["methylation"]
        mirna_test       = arrays["mirna"]
        histo20_test     = arrays["histo20"]

# Same modality order as the training list X_arr.
X_arr_test = [histo20_test, rnaseq_test, methylation_test, mirna_test, snv_test]   # Here, is the full test dataset list
mean_corr_test = DLVPM_Structural_instance.evaluate(X_arr_test)

print('The mean correlation between data-types connected by the path model is r=' + str(mean_corr_test[1]))

test_DLVs = DLVPM_Structural_instance.predict(X_arr_test) ## Here, we obtain the full set of test_DLVs

# The second axis of test_DLVs selects the DLV; each slice is transposed so
# that np.corrcoef treats the last axis as the variables being correlated.
## Associations between the first set of DLVs are:
print(np.corrcoef(test_DLVs[:,0,:].T))
## Associations between the second set of DLVs are:
print(np.corrcoef(test_DLVs[:,1,:].T))

7. Visualise cross‑modal correlations

DLVPM can summarise cross‑modal relationships per latent factor as chord diagrams. For each DLV, we compute a correlation matrix across modalities and render one panel. Nodes are the data types; edge thickness and opacity scale with |r|; edges below min_corr are hidden. Optionally, label edges with their correlation values.

# One correlation matrix per latent factor (DLV1, DLV2, ...)
corr_mat = DLVPM_Structural_instance.calculate_corrmat(test_DLVs)

from deep_lvpm.plot import plot_correlation_chord_row

# Node labels must follow the modality order of X_arr / X_arr_test
# (histology, RNA-seq, methylation, miRNA, SNV); listing miRNA before
# methylation would attach the wrong label to those two nodes.
data_names = ["Histology", "RNASeq", "Methylation", "miRNASeq", "SNVs"]

fig, ax = plot_correlation_chord_row(
    corr_mat,
    data_names,
    min_corr=0.0,
    node_cmap_name="Pastel1",
    figure_title=(
        "Correlation Plots Between Omics and Imaging Data Types in Lung Cancer"
    ),
    show_edge_labels=True,
    dpi=300,
    show=True,
)

Increase min_corr (e.g. 0.2) to focus on the strongest links.

Row of chord diagrams showing cross‑modal correlations for each DLV.

Example output from plot_correlation_chord_row. Each panel corresponds to a latent factor (DLV). Nodes are modalities; edge thickness/opacity encode |r| between modalities for that DLV. Labels can be toggled with show_edge_labels.

8. Save the PyTorch-backed StructuralModel

Persist the trained StructuralModel in the .keras format, which stores the model's weights and custom layers whether the TensorFlow or the PyTorch backend was used. Replace the placeholder path below with a writable location.

# Placeholder destination; the .keras suffix selects the Keras archive format.
model_path = "/path/to/output_folder/DLVPM_Model.keras"
DLVPM_Structural_instance.save(model_path)

The PyTorch backend delivers parity with the TensorFlow example while letting you integrate PyTorch-native measurement modules into StructuralModel.