Simple Neural Network

Flexible Input System

The baseline neural network models support flexible inputs. Rather than hardcoding which omics data type a model uses, you configure cell_line_views and drug_views directly in the hyperparameters.yaml file.

This replaces the former ChemBERTaNeuralNetwork, which differed from the SimpleNeuralNetwork only in using ChemBERTa embeddings instead of fingerprints as input.

Configuring the input views

The default SimpleNeuralNetwork configuration uses gene expression and fingerprints:

SimpleNeuralNetwork:
    cell_line_views:
        - gene_expression
    drug_views:
        - fingerprints
    dropout_prob:
        - 0.3
    units_per_layer:
        - - 32
          - 16
          - 8
          - 4
    ...

To train the same SimpleNeuralNetwork with ChemBERTa embeddings instead, change drug_views:

SimpleNeuralNetwork:
    cell_line_views:
        - gene_expression
    drug_views:
        - drug_chemberta_embeddings
    dropout_prob:
        - 0.3
    units_per_layer:
        - - 32
          - 16
          - 8
          - 4
    ...
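
The nested list under units_per_layer suggests that each hyperparameter key holds a list of candidate values. A minimal sketch of expanding such lists into concrete settings, assuming a plain grid over dropout_prob and units_per_layer (the view keys and the entries hidden behind ... are left out). expand_grid is a hypothetical helper, not a drevalpy function, and the second dropout value is added only for illustration:

```python
from itertools import product

# Candidate values per hyperparameter, mirroring the YAML layout above.
search_space = {
    "dropout_prob": [0.3, 0.5],  # 0.5 is a hypothetical extra candidate
    "units_per_layer": [[32, 16, 8, 4]],
}


def expand_grid(space):
    """Return one dict per combination of candidate values."""
    keys = list(space)
    combos = product(*(space[key] for key in keys))
    return [dict(zip(keys, combo)) for combo in combos]


settings = expand_grid(search_space)
for setting in settings:
    print(setting)
```

With one extra dropout candidate, the grid yields two settings that share the same layer sizes.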

For more details, see the Flexible Input System section in the documentation of the sklearn models.

Simple Neural Network Model

Contains the SimpleNeuralNetwork model.

class drevalpy.models.SimpleNeuralNetwork.simple_neural_network.SimpleNeuralNetwork

Bases: DRPModel

Simple feedforward neural network model with dropout. By default it uses gene expression as the cell line view and fingerprints as the drug view; both are configurable.

build_model(hyperparameters)

Builds the model from hyperparameters.

Parameters:

hyperparameters (dict) – includes units_per_layer and dropout_prob.
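
To illustrate what units_per_layer controls, the sketch below chains the layer sizes of a fully connected network. It assumes a single output unit for regression and counts only weights and biases; layer_shapes and the input dimension of 100 are hypothetical and not taken from drevalpy:

```python
def layer_shapes(input_dim, units_per_layer, output_dim=1):
    """Return (in_features, out_features) for each fully connected layer."""
    dims = [input_dim, *units_per_layer, output_dim]
    return list(zip(dims[:-1], dims[1:]))


# Hypothetical input dimension; layer sizes from the default configuration.
shapes = layer_shapes(input_dim=100, units_per_layer=[32, 16, 8, 4])

# Each layer has in*out weights plus out biases.
n_params = sum(n_in * n_out + n_out for n_in, n_out in shapes)
print(shapes, n_params)
```

Each entry of units_per_layer adds one hidden layer, so the list length sets the depth and its values set the width.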

cell_line_views = []

drug_views = []

early_stopping = True

classmethod get_model_name()

Returns the model name.

Return type:

str

Returns:

SimpleNeuralNetwork

classmethod load(directory)

Load a trained SimpleNeuralNetwork instance from disk.

This includes:

  • model.pt: PyTorch state_dict of the trained model

  • hyperparameters.json: Dictionary with model hyperparameters

  • scaler.pkl: Fitted StandardScaler for gene expression features

Parameters:

directory (str) – Directory containing the saved model files

Return type:

SimpleNeuralNetwork

Returns:

An instance of SimpleNeuralNetwork with restored state

Raises:

FileNotFoundError – if any required file is missing
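
Because load raises FileNotFoundError when a file is missing, it can help to check a directory up front. A minimal sketch using only the three file names listed above; missing_artifacts is a hypothetical helper and not part of drevalpy:

```python
import tempfile
from pathlib import Path

# File names as listed in the load documentation above.
REQUIRED_FILES = ("model.pt", "hyperparameters.json", "scaler.pkl")


def missing_artifacts(directory):
    """Return the required file names that are absent from directory."""
    root = Path(directory)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


# Demo: a directory that only contains model.pt.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "model.pt").touch()
    missing = missing_artifacts(tmp)

print(missing)
```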

load_cell_line_features(data_path, dataset_name)

Loads the cell line features for a single-view neural network.

Parameters:
  • data_path (str) – Path to the data

  • dataset_name (str) – name of the dataset

Return type:

FeatureDataset

Returns:

FeatureDataset containing the cell line features

load_drug_features(data_path, dataset_name)

Loads the drug features for a single-view neural network.

Parameters:
  • data_path (str) – Path to the data

  • dataset_name (str) – name of the dataset

Return type:

FeatureDataset | None

Returns:

FeatureDataset containing the drug features

predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)

Predicts the response for the given input.

Parameters:
  • cell_line_ids (ndarray) – IDs of the cell lines to be predicted

  • drug_ids (ndarray) – IDs of the drugs to be predicted

  • cell_line_input (FeatureDataset) – gene expression of the test data

  • drug_input (FeatureDataset | None) – fingerprints of the test data

Return type:

ndarray

Returns:

the predicted drug responses

save(directory)

Save the trained model, hyperparameters, and gene expression scaler to the given directory.

This enables full reconstruction of the model using load.

Files saved:

  • model.pt: PyTorch state_dict of the trained model

  • hyperparameters.json: Dictionary containing all relevant model hyperparameters

  • scaler.pkl: Fitted StandardScaler for gene expression features

Parameters:

directory (str) – Target directory to store all model artifacts

Return type:

None

train(output, cell_line_input, drug_input=None, output_earlystopping=None, model_checkpoint_dir='checkpoints')

Scales the gene expression data, then trains the model.

The gene expression data is first arcsinh-transformed. A StandardScaler is then fitted on the training gene expression data only and used to transform all gene expression data.

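Parameters:
  • output (DrugResponseDataset) – response values for training

  • cell_line_input (FeatureDataset) – cell line features

  • drug_input (FeatureDataset | None) – drug features

  • output_earlystopping (DrugResponseDataset | None) – response values for early stopping

  • model_checkpoint_dir (str) – directory to save the model checkpoints

Raises: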

ValueError – if drug_input (fingerprints) is missing

Return type:

None
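
The preprocessing order described for train (arcsinh transform, then a scaler fitted on the training data only) can be sketched in plain Python. This is an illustration and not the drevalpy implementation: the helper names and expression values are hypothetical, and the scaler imitates StandardScaler by using the population standard deviation and a scale of 1 for constant columns:

```python
import math
from statistics import fmean, pstdev


def arcsinh_rows(rows):
    """Apply arcsinh to every value."""
    return [[math.asinh(value) for value in row] for row in rows]


def fit_standard_scaler(rows):
    """Compute per-column mean and scale from the training rows only."""
    columns = list(zip(*rows))
    means = [fmean(column) for column in columns]
    # Constant columns get a scale of 1 to avoid division by zero.
    scales = [pstdev(column) or 1.0 for column in columns]
    return means, scales


def transform(rows, means, scales):
    """Standardize rows with previously fitted statistics."""
    return [[(v - m) / s for v, m, s in zip(row, means, scales)] for row in rows]


# Hypothetical expression values: rows are cell lines, columns are genes.
train_expr = [[0.0, 10.0], [5.0, 10.0], [50.0, 10.0]]
test_expr = [[5.0, 12.0]]

train_t = arcsinh_rows(train_expr)
means, scales = fit_standard_scaler(train_t)  # fitted on training data only
train_scaled = transform(train_t, means, scales)
test_scaled = transform(arcsinh_rows(test_expr), means, scales)
```

Fitting the statistics on the training split and reusing them for all other data keeps information from the test data out of the preprocessing.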

Multi-OMICS Neural Network

Contains the baseline MultiViewNeuralNetwork model.

class drevalpy.models.SimpleNeuralNetwork.multi_view_neural_network.MultiViewNeuralNetwork

Bases: DRPModel

Simple Feedforward Neural Network model with dropout using multiple omics data.

build_model(hyperparameters)

Builds the model from hyperparameters.

The model is a simple feedforward neural network with dropout. PCA is used to reduce the dimensionality of the methylation data.

Parameters:

hyperparameters (dict) – dictionary containing the hyperparameters units_per_layer, dropout_prob, and methylation_pca_components.

cell_line_views = ['gene_expression', 'methylation', 'mutations', 'copy_number_variation_gistic']

drug_views = ['fingerprints']

early_stopping = True

classmethod get_model_name()

Returns the model name.

Return type:

str

Returns:

MultiViewNeuralNetwork

classmethod load(directory)

Load a trained MultiViewNeuralNetwork instance from disk.

Always required: model.pt, hyperparameters.json, metadata.json. Conditionally required: gene_scaler.pkl (if gene_expression in views), methylation_scaler.pkl and methylation_pca.pkl (if methylation in views).

Parameters:

directory (str) – Directory containing the saved model files

Return type:

MultiViewNeuralNetwork

Returns:

Fully restored MultiViewNeuralNetwork instance

Raises:

FileNotFoundError – if any required file is missing
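
The rule for which files load expects can be written down as a small function. A minimal sketch based only on the file names listed above; required_files is a hypothetical helper and not part of drevalpy:

```python
ALWAYS_REQUIRED = ["model.pt", "hyperparameters.json", "metadata.json"]


def required_files(cell_line_views):
    """Return the files expected on disk for the given cell line views."""
    files = list(ALWAYS_REQUIRED)
    if "gene_expression" in cell_line_views:
        files.append("gene_scaler.pkl")
    if "methylation" in cell_line_views:
        files += ["methylation_scaler.pkl", "methylation_pca.pkl"]
    return files


default_views = [
    "gene_expression",
    "methylation",
    "mutations",
    "copy_number_variation_gistic",
]
print(required_files(default_views))
print(required_files(["mutations"]))
```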

load_cell_line_features(data_path, dataset_name)

Loads the cell line features for a multi-view neural network.

Parameters:
  • data_path (str) – data path e.g. data/

  • dataset_name (str) – dataset name e.g. GDSC1

Return type:

FeatureDataset

Returns:

FeatureDataset containing the cell line omics features

load_drug_features(data_path, dataset_name)

Load the drug features for a multi-view neural network.

Parameters:
  • data_path (str) – path to the drug features, e.g., data/

  • dataset_name (str) – name of the dataset, e.g., GDSC1

Return type:

FeatureDataset | None

Returns:

FeatureDataset containing the drug features

predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)

Applies arcsinh + scaling to gene expression and scaling + PCA to methylation, then predicts.

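Parameters:
  • cell_line_ids (ndarray) – IDs of the cell lines to be predicted

  • drug_ids (ndarray) – IDs of the drugs to be predicted

  • cell_line_input (FeatureDataset) – cell line omics features of the test data

  • drug_input (FeatureDataset | None) – drug features of the test data

Return type: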

ndarray

Returns:

predicted response

save(directory)

Save the trained model, hyperparameters, scalers, PCA object, and feature dimensions to disk.

Files always saved: model.pt, hyperparameters.json, metadata.json. Conditionally saved: gene_scaler.pkl (if gene_expression in views), methylation_scaler.pkl and methylation_pca.pkl (if methylation in views).

Parameters:

directory (str) – Target directory

Return type:

None

train(output, cell_line_input, drug_input=None, output_earlystopping=None, model_checkpoint_dir='')

Fits the PCA and trains the model.

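Parameters:
  • output (DrugResponseDataset) – response values for training

  • cell_line_input (FeatureDataset) – cell line omics features

  • drug_input (FeatureDataset | None) – drug features

  • output_earlystopping (DrugResponseDataset | None) – response values for early stopping

  • model_checkpoint_dir (str) – directory to save the model checkpoints

Raises: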

ValueError – if drug_input is missing

Model utils

Utility functions for the simple neural network models.

class drevalpy.models.SimpleNeuralNetwork.utils.FeedForwardNetwork(hyperparameters, input_dim)

Bases: RegressionMetricsMixin, LightningModule

Feed forward neural network for regression tasks with basic architecture.

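Parameters:
  • hyperparameters – dictionary of hyperparameters, including units_per_layer and dropout_prob

  • input_dim – dimension of the input features

configure_optimizers()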

Overrides configure_optimizers from the LightningModule.

Return type:

Optimizer

Returns:

Adam optimizer

fit(output_train, cell_line_input, drug_input, cell_line_views, drug_views, output_earlystopping=None, trainer_params=None, batch_size=32, patience=5, num_workers=2, model_checkpoint_dir='checkpoints', wandb_project=None)

Fits the model.

First, the data is loaded using a DataLoader. Then, the model is trained using the Lightning Trainer.

Parameters:
  • output_train (DrugResponseDataset) – Response values for training

  • cell_line_input (FeatureDataset) – Cell line features

  • drug_input (FeatureDataset | None) – Drug features

  • cell_line_views (list[str]) – Cell line info needed for this model

  • drug_views (list[str]) – Drug info needed for this model

  • output_earlystopping (DrugResponseDataset | None) – Response values for early stopping

  • trainer_params (dict | None) – custom parameters for the trainer

  • batch_size (int) – batch size for the DataLoader, default is 32

  • patience (int) – patience for early stopping, default is 5

  • num_workers (int) – number of workers for the DataLoader, default is 2

  • model_checkpoint_dir (str) – directory to save the model checkpoints

  • wandb_project (str | None) – optional wandb project name for logging. If provided, uses WandbLogger for PyTorch Lightning training.

Raises:

ValueError – if drug_input is missing

Return type:

None
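
How cell_line_views and drug_views might select features can be sketched as follows, assuming the features of all requested views are concatenated in list order into one input vector. Plain dicts stand in for FeatureDataset, and build_input_vector and all values are hypothetical:

```python
def build_input_vector(cell_line_features, drug_features, cell_line_views, drug_views):
    """Concatenate the requested views, cell line views first, then drug views."""
    vector = []
    for view in cell_line_views:
        vector.extend(cell_line_features[view])
    for view in drug_views:
        vector.extend(drug_features[view])
    return vector


# Hypothetical features for one cell line and one drug.
cell_line_features = {"gene_expression": [0.1, 0.2, 0.3], "mutations": [1, 0]}
drug_features = {"fingerprints": [0, 1, 1, 0]}

x = build_input_vector(
    cell_line_features,
    drug_features,
    cell_line_views=["gene_expression"],
    drug_views=["fingerprints"],
)
print(x, len(x))
```

Under this assumption, the length of the concatenated vector is what a network would receive as input_dim, and views that are not listed (here, mutations) are ignored.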

forward(x)

Forward pass of the model.

Parameters:

x – input data

Return type:

Tensor

Returns:

predicted response

predict(x)

Predicts the response for the given input.

Parameters:

x (ndarray) – input data

Return type:

ndarray

Returns:

predicted response

training_step(batch)

Overrides the training step from the LightningModule.

Runs a forward pass, computes the loss, and logs it.

Parameters:

batch – batch of data

Returns:

loss

validation_step(batch)

Overrides the validation step from the LightningModule.

Runs a forward pass, computes the loss, and logs it.

Parameters:

batch – batch of data

Returns:

loss
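
The patience argument of fit governs early stopping on the validation loss. A simplified sketch of patience-based stopping follows; it is not the PyTorch Lightning implementation, and stopping_epoch and the loss values are hypothetical:

```python
def stopping_epoch(val_losses, patience=5):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Strict improvement resets the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical validation losses: the best value is reached at epoch 1.
losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.8, 0.95]
stop_at = stopping_epoch(losses, patience=5)
print(stop_at)
```

A loss equal to the best value does not count as an improvement here, so the fifth non-improving epoch after epoch 1 triggers the stop.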

class drevalpy.models.SimpleNeuralNetwork.utils.RegressionDataset(output, cell_line_input, drug_input, cell_line_views, drug_views)

Bases: Dataset

Dataset for regression tasks for the data loader.

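Parameters:
  • output (DrugResponseDataset) – response values

  • cell_line_input (FeatureDataset) – cell line features

  • drug_input (FeatureDataset | None) – drug features

  • cell_line_views (list[str]) – cell line info needed for this model

  • drug_views (list[str]) – drug info needed for this model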