Simple Neural Network
Flexible Input System
The baseline neural network models support flexible inputs. Rather than hardcoding which omics data type a model uses,
you configure cell_line_views and drug_views directly in the hyperparameters.yaml file.
This replaces the former ChemBERTaNeuralNetwork, which differed from the SimpleNeuralNetwork only in
using ChemBERTa embeddings instead of fingerprints as input.
Configuring the input views
The default SimpleNeuralNetwork configuration uses gene expression and fingerprints:
SimpleNeuralNetwork:
  cell_line_views:
    - gene_expression
  drug_views:
    - fingerprints
  dropout_prob:
    - 0.3
  units_per_layer:
    - - 32
      - 16
      - 8
      - 4
  ...
To train the same SimpleNeuralNetwork with ChemBERTa embeddings instead, change drug_views:
SimpleNeuralNetwork:
  cell_line_views:
    - gene_expression
  drug_views:
    - drug_chemberta_embeddings
  dropout_prob:
    - 0.3
  units_per_layer:
    - - 32
      - 16
      - 8
      - 4
  ...
For more details, see the Flexible Input System section in the documentation of the sklearn models.
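To illustrate why one model class can cover both configurations, the sketch below mirrors the two YAML snippets as Python dictionaries and derives the network input width from the selected views. The feature widths in `VIEW_DIMS` and the helper `input_dim_for` are hypothetical and not part of drevalpy; real widths depend on the loaded feature files.

```python
# Hypothetical feature widths per view; real values depend on the dataset.
VIEW_DIMS = {
    "gene_expression": 1000,
    "fingerprints": 128,
    "drug_chemberta_embeddings": 384,
}


def input_dim_for(config: dict) -> int:
    """Sum the widths of all configured cell line and drug views."""
    views = config["cell_line_views"] + config["drug_views"]
    unknown = [view for view in views if view not in VIEW_DIMS]
    if unknown:
        raise KeyError(f"no feature width known for views: {unknown}")
    return sum(VIEW_DIMS[view] for view in views)


default_config = {
    "cell_line_views": ["gene_expression"],
    "drug_views": ["fingerprints"],
}
# Same model, different drug representation: only drug_views changes.
chemberta_config = {**default_config, "drug_views": ["drug_chemberta_embeddings"]}

print(input_dim_for(default_config))
print(input_dim_for(chemberta_config))
```

Only the drug view entry differs between the two dictionaries, which is the point of the flexible input system: the model code stays the same while the input width follows the configuration.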
Simple Neural Network Model
Contains the SimpleNeuralNetwork model.
- class drevalpy.models.SimpleNeuralNetwork.simple_neural_network.SimpleNeuralNetwork
Bases: DRPModel
Simple feedforward neural network model with dropout using only gene expression data.
- build_model(hyperparameters)
Builds the model from hyperparameters.
- Parameters:
hyperparameters (dict) – includes units_per_layer and dropout_prob.
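As a rough illustration of how units_per_layer and dropout_prob could translate into an architecture, the following sketch lists the layer shapes of a feedforward regressor. The helper `layer_plan`, the placement of dropout after each hidden layer, and the single-unit output layer are assumptions made for illustration; the actual construction is done by FeedForwardNetwork.

```python
def layer_plan(input_dim: int, units_per_layer: list[int], dropout_prob: float) -> list[str]:
    """Describe a feedforward stack: hidden layers with dropout, then one output unit."""
    if not 0.0 <= dropout_prob < 1.0:
        raise ValueError("dropout_prob must be in [0, 1)")
    plan = []
    in_features = input_dim
    for units in units_per_layer:
        # Assumed layout: each hidden layer is followed by dropout.
        plan.append(f"Linear({in_features} -> {units}) + Dropout(p={dropout_prob})")
        in_features = units
    # Assumed regression head: a single predicted response value.
    plan.append(f"Linear({in_features} -> 1)")
    return plan


# 1128 is a hypothetical input width, not a value taken from drevalpy.
for description in layer_plan(1128, [32, 16, 8, 4], 0.3):
    print(description)
```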
- cell_line_views = []
- drug_views = []
- early_stopping = True
- classmethod load(directory)
Load a trained SimpleNeuralNetwork instance from disk.
This includes:
model.pt: PyTorch state_dict of the trained model
hyperparameters.json: Dictionary with model hyperparameters
scaler.pkl: Fitted StandardScaler for gene expression features
- Parameters:
directory (str) – Directory containing the saved model files
- Return type:
- Returns:
An instance of SimpleNeuralNetwork with restored state
- Raises:
FileNotFoundError – if any required file is missing
- load_cell_line_features(data_path, dataset_name)
Loads the cell line features for a single-view neural network.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the cell line features
- load_drug_features(data_path, dataset_name)
Loads the drug features for a single-view neural network.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the drug features
- predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)
Predicts the response for the given input.
- Parameters:
cell_line_ids (ndarray) – IDs of the cell lines to be predicted
drug_ids (ndarray) – IDs of the drugs to be predicted
cell_line_input (FeatureDataset) – gene expression of the test data
drug_input (FeatureDataset|None) – fingerprints of the test data
- Return type:
- Returns:
the predicted drug responses
- save(directory)
Save the trained model, hyperparameters, and gene expression scaler to the given directory.
This enables full reconstruction of the model using load.
Files saved:
model.pt: PyTorch state_dict of the trained model
hyperparameters.json: Dictionary containing all relevant model hyperparameters
scaler.pkl: Fitted StandardScaler for gene expression features
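The file layout shared by save and load can be checked before attempting to restore a model. The sketch below is a minimal, hypothetical helper (`check_saved_model` is not a drevalpy function) that verifies the three files listed above are present and raises FileNotFoundError otherwise, mirroring the error that load documents.

```python
from pathlib import Path

# File names as listed in the save/load documentation.
REQUIRED_FILES = ("model.pt", "hyperparameters.json", "scaler.pkl")


def check_saved_model(directory: str) -> list[Path]:
    """Return the paths of all required files or raise FileNotFoundError."""
    base = Path(directory)
    missing = [name for name in REQUIRED_FILES if not (base / name).is_file()]
    if missing:
        raise FileNotFoundError(f"missing in {base}: {', '.join(missing)}")
    return [base / name for name in REQUIRED_FILES]
```

Calling such a check first gives one error message that names every missing file, rather than failing on whichever file happens to be opened first.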
- train(output, cell_line_input, drug_input=None, output_earlystopping=None, model_checkpoint_dir='checkpoints')
Scales the gene expression data, then trains the model.
The gene expression data is first arcsinh-transformed. A StandardScaler is then fitted on the training gene expression data only and used to transform all gene expression data.
- Parameters:
output (DrugResponseDataset) – training data associated with the response output
cell_line_input (FeatureDataset) – cell line omics features
drug_input (FeatureDataset|None) – drug omics features
output_earlystopping (DrugResponseDataset|None) – optional early stopping dataset
model_checkpoint_dir (str) – directory to save the model checkpoints
- Raises:
ValueError – if drug_input (fingerprints) is missing
- Return type:
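The preprocessing order described for train (arcsinh transform, then a scaler fitted on the training data only) can be sketched with the standard library. This is an illustration of the idea, not the drevalpy code: the helper names and the toy values are made up, and the scaler is a hand-written stand-in that follows StandardScaler's default behaviour (per-feature mean and population standard deviation).

```python
import math
import statistics


def arcsinh_transform(rows):
    """Apply arcsinh element-wise to a list of feature rows."""
    return [[math.asinh(value) for value in row] for row in rows]


def fit_scaler(rows):
    """Compute per-feature mean and population standard deviation."""
    columns = list(zip(*rows))
    means = [statistics.fmean(column) for column in columns]
    # Constant features get a scale of 1.0 to avoid division by zero.
    stds = [statistics.pstdev(column) or 1.0 for column in columns]
    return means, stds


def apply_scaler(rows, means, stds):
    """Standardize rows with previously fitted statistics."""
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]


# Toy expression values: three training cell lines, one test cell line.
train = arcsinh_transform([[0.0, 10.0], [2.0, 30.0], [4.0, 50.0]])
test = arcsinh_transform([[1.0, 20.0]])

means, stds = fit_scaler(train)                 # statistics from training data only
train_scaled = apply_scaler(train, means, stds)
test_scaled = apply_scaler(test, means, stds)   # test data reuses training statistics
```

Fitting on the training split alone keeps information about the test cell lines out of the scaling statistics.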
Multi-OMICS Neural Network
Contains the baseline MultiViewNeuralNetwork model.
- class drevalpy.models.SimpleNeuralNetwork.multi_view_neural_network.MultiViewNeuralNetwork
Bases: DRPModel
Simple feedforward neural network model with dropout using multiple omics data.
- build_model(hyperparameters)
Builds the model from hyperparameters.
The model is a simple feedforward neural network with dropout. PCA is used to reduce the dimensionality of the methylation data.
- Parameters:
hyperparameters (dict) – dictionary containing the hyperparameters units_per_layer, dropout_prob, and methylation_pca_components.
- cell_line_views = ['gene_expression', 'methylation', 'mutations', 'copy_number_variation_gistic']
- drug_views = ['fingerprints']
- early_stopping = True
- classmethod get_model_name()
Returns the model name.
- Return type:
- Returns:
MultiViewNeuralNetwork
- classmethod load(directory)
Load a trained MultiViewNeuralNetwork instance from disk.
Always required: model.pt, hyperparameters.json, metadata.json. Conditionally required: gene_scaler.pkl (if gene_expression in views), methylation_scaler.pkl and methylation_pca.pkl (if methylation in views).
- Parameters:
directory (str) – Directory containing the saved model files
- Return type:
- Returns:
Fully restored MultiViewNeuralNetwork instance
- Raises:
FileNotFoundError – if any required file is missing
- load_cell_line_features(data_path, dataset_name)
Loads the cell line features for a multi-view neural network.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the cell line omics features
- load_drug_features(data_path, dataset_name)
Load the drug features for a multi-view neural network.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the drug features
- predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)
Applies arcsinh + scaling to gene expression and scaling + PCA to methylation, then predicts.
- Parameters:
drug_ids (ndarray) – drug identifiers
cell_line_ids (ndarray) – cell line identifiers
drug_input (FeatureDataset|None) – drug omics features
cell_line_input (FeatureDataset) – cell line omics features
- Return type:
- Returns:
predicted response
- save(directory)
Save the trained model, hyperparameters, scalers, PCA object, and feature dimensions to disk.
Files always saved: model.pt, hyperparameters.json, metadata.json. Conditionally saved: gene_scaler.pkl (if gene_expression in views), methylation_scaler.pkl and methylation_pca.pkl (if methylation in views).
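The conditional file rules can be written down as a small function. The sketch below is a hypothetical helper (`expected_files` does not exist in drevalpy) that derives the expected file list from the configured cell line views, following the rules stated for save and load.

```python
# Files that are written regardless of the configured views.
ALWAYS_SAVED = ["model.pt", "hyperparameters.json", "metadata.json"]


def expected_files(cell_line_views: list[str]) -> list[str]:
    """List the files a saved model directory should contain for the given views."""
    files = list(ALWAYS_SAVED)
    if "gene_expression" in cell_line_views:
        files.append("gene_scaler.pkl")
    if "methylation" in cell_line_views:
        files += ["methylation_scaler.pkl", "methylation_pca.pkl"]
    return files


print(expected_files(["gene_expression", "mutations"]))
```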
- train(output, cell_line_input, drug_input=None, output_earlystopping=None, model_checkpoint_dir='')
Fits the PCA and trains the model.
- Parameters:
output (DrugResponseDataset) – training data associated with the response output
cell_line_input (FeatureDataset) – cell line omics features
drug_input (FeatureDataset|None) – drug omics features
output_earlystopping (DrugResponseDataset|None) – optional early stopping dataset
model_checkpoint_dir (str) – directory to save the model checkpoints
- Raises:
ValueError – if drug_input is missing
Model utils
Utility functions for the simple neural network models.
- class drevalpy.models.SimpleNeuralNetwork.utils.FeedForwardNetwork(hyperparameters, input_dim)
Bases: RegressionMetricsMixin, LightningModule
Feedforward neural network for regression tasks with basic architecture.
- configure_optimizers()
Overrides configure_optimizers from the LightningModule.
- Return type:
Optimizer
- Returns:
Adam optimizer
- fit(output_train, cell_line_input, drug_input, cell_line_views, drug_views, output_earlystopping=None, trainer_params=None, batch_size=32, patience=5, num_workers=2, model_checkpoint_dir='checkpoints', wandb_project=None)
Fits the model.
First, the data is loaded using a DataLoader. Then, the model is trained using the Lightning Trainer.
- Parameters:
output_train (DrugResponseDataset) – Response values for training
cell_line_input (FeatureDataset) – Cell line features
drug_input (FeatureDataset|None) – Drug features
cell_line_views (list[str]) – Cell line info needed for this model
output_earlystopping (DrugResponseDataset|None) – Response values for early stopping
trainer_params (dict|None) – custom parameters for the trainer
batch_size – batch size for the DataLoader, default is 32
patience – patience for early stopping, default is 5
num_workers (int) – number of workers for the DataLoader, default is 2
model_checkpoint_dir (str) – directory to save the model checkpoints
wandb_project (str|None) – optional wandb project name for logging. If provided, uses WandbLogger for PyTorch Lightning training.
- Raises:
ValueError – if drug_input is missing
- Return type:
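The patience parameter follows the usual early stopping idea: training stops once the monitored validation loss has not improved for a given number of consecutive checks. The sketch below shows that logic in plain Python. It is an assumption-laden illustration (the function `stopping_epoch` is hypothetical, and one check per epoch is assumed), not the callback that fit configures internally.

```python
def stopping_epoch(val_losses, patience=5):
    """Return the 1-based epoch at which patience-based early stopping would halt."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best validation loss: reset the counter.
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    # Patience never ran out: training used every available epoch.
    return len(val_losses)
```

A larger patience tolerates noisy validation curves at the cost of extra epochs after the best model has been reached.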
- forward(x)
Forward pass of the model.
- Parameters:
x – input data
- Return type:
Tensor
- Returns:
predicted response
- predict(x)
Predicts the response for the given input.
- training_step(batch)
Overrides the training step from the LightningModule.
Does a forward pass, calculates the loss and logs the loss.
- Parameters:
batch – batch of data
- Returns:
loss
- validation_step(batch)
Overrides the validation step from the LightningModule.
Does a forward pass, calculates the loss and logs the loss.
- Parameters:
batch – batch of data
- Returns:
loss
- class drevalpy.models.SimpleNeuralNetwork.utils.RegressionDataset(output, cell_line_input, drug_input, cell_line_views, drug_views)
Bases: Dataset
Dataset for regression tasks for the data loader.
- Parameters:
output (DrugResponseDataset)
cell_line_input (FeatureDataset)
drug_input (FeatureDataset)
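A dataset of this kind typically turns each (cell line, drug) pair into one input vector by concatenating the selected views. The sketch below shows that idea with plain Python containers. The class `PairDataset` and its dictionary-based feature lookup are stand-ins invented for illustration; the real RegressionDataset works with DrugResponseDataset and FeatureDataset objects.

```python
class PairDataset:
    """Minimal map-style dataset over (cell line, drug) pairs; not the drevalpy class."""

    def __init__(self, responses, cell_line_ids, drug_ids,
                 cell_line_features, drug_features, cell_line_views, drug_views):
        self.responses = responses
        self.cell_line_ids = cell_line_ids
        self.drug_ids = drug_ids
        self.cell_line_features = cell_line_features  # id -> {view name -> values}
        self.drug_features = drug_features            # id -> {view name -> values}
        self.cell_line_views = cell_line_views
        self.drug_views = drug_views

    def __len__(self):
        return len(self.responses)

    def __getitem__(self, index):
        cell_line = self.cell_line_features[self.cell_line_ids[index]]
        drug = self.drug_features[self.drug_ids[index]]
        features = []
        # Concatenate cell line views first, then drug views, in configured order.
        for view in self.cell_line_views:
            features.extend(cell_line[view])
        for view in self.drug_views:
            features.extend(drug[view])
        return features, self.responses[index]
```

Keeping the view order fixed matters because the network's input layer assigns a weight to each position of the concatenated vector.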