Simple Neural Network
Simple Neural Network Model and ChemBERTaNeuralNetwork
Contains the SimpleNeuralNetwork and ChemBERTaNeuralNetwork models.
- class drevalpy.models.SimpleNeuralNetwork.simple_neural_network.ChemBERTaNeuralNetwork
Bases: SimpleNeuralNetwork
ChemBERTa Neural Network model using gene expression and ChemBERTa drug embeddings.
- drug_views = ['chemberta_embeddings']
- classmethod get_model_name()
Returns the model name.
- Return type:
- Returns:
ChemBERTaNeuralNetwork
- load_drug_features(data_path, dataset_name)
Loads the ChemBERTa embeddings.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the ChemBERTa embeddings
- Raises:
FileNotFoundError – if the ChemBERTa embeddings file is not found
- class drevalpy.models.SimpleNeuralNetwork.simple_neural_network.SimpleNeuralNetwork
Bases: DRPModel
Simple Feedforward Neural Network model with dropout using only gene expression data.
- build_model(hyperparameters)
Builds the model from hyperparameters.
- Parameters:
hyperparameters (dict) – includes units_per_layer and dropout_prob.
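As a rough illustration of how such a hyperparameter dictionary could translate into a network layout, the sketch below derives (input, output) sizes for each linear layer. The helper `layer_shapes`, the single regression output, and all numeric values are assumptions made for this example, not taken from the library.

```python
def layer_shapes(input_dim, units_per_layer, output_dim=1):
    """Return (in_features, out_features) pairs for a stack of linear layers.

    Hypothetical helper: assumes one hidden layer per entry in units_per_layer
    and a single output unit for the predicted response.
    """
    sizes = [input_dim, *units_per_layer, output_dim]
    return list(zip(sizes[:-1], sizes[1:]))


# Placeholder values; only the key names come from the documentation above.
hyperparameters = {"units_per_layer": [32, 16], "dropout_prob": 0.2}

shapes = layer_shapes(
    input_dim=100,
    units_per_layer=hyperparameters["units_per_layer"],
)
print(shapes)
```

In a layout like this, dropout with probability dropout_prob would typically be applied between the hidden layers.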
- cell_line_views = ['gene_expression']
- drug_views = ['fingerprints']
- early_stopping = True
- classmethod load(directory)
Load a trained SimpleNeuralNetwork instance from disk.
This includes:
  - model.pt: PyTorch state_dict of the trained model
  - hyperparameters.json: Dictionary with model hyperparameters
  - scaler.pkl: Fitted StandardScaler for gene expression features
- Parameters:
directory (str) – Directory containing the saved model files
- Return type:
- Returns:
An instance of SimpleNeuralNetwork with restored state
- Raises:
FileNotFoundError – if any required file is missing
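The file check behind the FileNotFoundError can be pictured with a minimal sketch like the one below. The three file names come from the description above; the helper `check_model_files`, the placeholder files, and the hyperparameter values are hypothetical and only illustrate the idea, not the library's actual loading code.

```python
import json
import tempfile
from pathlib import Path

# File names as listed in the documentation; everything else is illustrative.
REQUIRED_FILES = ("model.pt", "hyperparameters.json", "scaler.pkl")


def check_model_files(directory: str) -> dict:
    """Verify that all required files exist and return the stored hyperparameters."""
    base = Path(directory)
    missing = [name for name in REQUIRED_FILES if not (base / name).is_file()]
    if missing:
        raise FileNotFoundError(f"Missing model files in {base}: {', '.join(missing)}")
    with open(base / "hyperparameters.json", encoding="utf-8") as handle:
        return json.load(handle)


# Demonstration with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)
    (tmp_path / "model.pt").write_bytes(b"")
    (tmp_path / "scaler.pkl").write_bytes(b"")
    (tmp_path / "hyperparameters.json").write_text(
        json.dumps({"units_per_layer": [32, 16], "dropout_prob": 0.2}),
        encoding="utf-8",
    )
    restored = check_model_files(tmp)

    # Removing one file should trigger the error.
    (tmp_path / "scaler.pkl").unlink()
    try:
        check_model_files(tmp)
        missing_detected = False
    except FileNotFoundError:
        missing_detected = True
```

Checking for all files up front reports every missing file at once instead of failing on the first one.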
- load_cell_line_features(data_path, dataset_name)
Loads the cell line features.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the cell line gene expression features, filtered through the landmark genes
- load_drug_features(data_path, dataset_name)
Loads the fingerprint data.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the fingerprints
- predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)
Predicts the response for the given input.
- Parameters:
cell_line_ids (ndarray) – IDs of the cell lines to be predicted
drug_ids (ndarray) – IDs of the drugs to be predicted
cell_line_input (FeatureDataset) – gene expression of the test data
drug_input (FeatureDataset | None) – fingerprints of the test data
- Return type:
- Returns:
the predicted drug responses
- save(directory)
Save the trained model, hyperparameters, and gene expression scaler to the given directory.
This enables full reconstruction of the model using load.
Files saved:
  - model.pt: PyTorch state_dict of the trained model
  - hyperparameters.json: Dictionary containing all relevant model hyperparameters
  - scaler.pkl: Fitted StandardScaler for gene expression features
- train(output, cell_line_input, drug_input=None, output_earlystopping=None, model_checkpoint_dir='checkpoints')
First scales the gene expression data and trains the model.
The gene expression data is first arcsinh-transformed. Afterward, a StandardScaler is fitted on the training gene expression data only and then used to transform all gene expression data.
- Parameters:
output (DrugResponseDataset) – training data associated with the response output
cell_line_input (FeatureDataset) – cell line omics features
drug_input (FeatureDataset | None) – drug omics features
output_earlystopping (DrugResponseDataset | None) – optional early stopping dataset
model_checkpoint_dir (str) – directory to save the model checkpoints
- Raises:
ValueError – if drug_input (fingerprints) is missing
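The preprocessing order described for train (arcsinh transform, then a scaler fitted on the training data only) can be sketched in plain Python as follows. This is an illustration under assumptions: the model itself uses scikit-learn's StandardScaler, while the helper functions and the toy expression values below are made up for this example.

```python
import math
from statistics import fmean, pstdev


def arcsinh_transform(matrix):
    """Apply arcsinh element-wise to a list of rows."""
    return [[math.asinh(value) for value in row] for row in matrix]


def fit_scaler(train_matrix):
    """Compute per-feature mean and population standard deviation on training data."""
    columns = list(zip(*train_matrix))
    means = [fmean(column) for column in columns]
    # Constant features get a scale of 1.0 to avoid division by zero.
    stds = [pstdev(column) or 1.0 for column in columns]
    return means, stds


def apply_scaler(matrix, means, stds):
    """Standardize each feature with statistics learned from the training data."""
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in matrix]


# Toy data: rows are cell lines, columns are genes (hypothetical values).
train_expr = [[0.0, 10.0], [2.0, 30.0], [4.0, 50.0]]
test_expr = [[1.0, 20.0]]

train_t = arcsinh_transform(train_expr)
test_t = arcsinh_transform(test_expr)

means, stds = fit_scaler(train_t)  # fitted on training data only
train_scaled = apply_scaler(train_t, means, stds)
test_scaled = apply_scaler(test_t, means, stds)  # reuses the training statistics
```

Fitting the scaler on the training data alone keeps information about the test or early stopping data out of the preprocessing step.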
Multi-OMICS Neural Network
Contains the baseline MultiOmicsNeuralNetwork model.
- class drevalpy.models.SimpleNeuralNetwork.multiomics_neural_network.MultiOmicsNeuralNetwork
Bases: DRPModel
Simple Feedforward Neural Network model with dropout using multiple omics data.
- build_model(hyperparameters)
Builds the model from hyperparameters.
The model is a simple feedforward neural network with dropout. PCA is used to reduce the dimensionality of the methylation data.
- Parameters:
hyperparameters (dict) – dictionary containing the hyperparameters units_per_layer, dropout_prob, and methylation_pca_components.
- cell_line_views = ['gene_expression', 'methylation', 'mutations', 'copy_number_variation_gistic']
- drug_views = ['fingerprints']
- early_stopping = True
- classmethod get_model_name()
Returns the model name.
- Return type:
- Returns:
MultiOmicsNeuralNetwork
- classmethod load(directory)
Load a trained MultiOmicsNeuralNetwork instance from disk.
Required files:
  - model.pt
  - hyperparameters.json
  - gene_scaler.pkl
  - methylation_scaler.pkl
  - methylation_pca.pkl
  - metadata.json
- Parameters:
directory (str) – Directory containing the saved model files
- Return type:
- Returns:
Fully restored MultiOmicsNeuralNetwork instance
- Raises:
FileNotFoundError – if any required file is missing
- load_cell_line_features(data_path, dataset_name)
Loads the cell line features.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the cell line omics features, filtered through the drug target genes
- load_drug_features(data_path, dataset_name)
Load the drug features.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the drug fingerprint features
- predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)
Applies arcsinh + scaling to gene expression and scaling + PCA to methylation, then predicts.
- Parameters:
drug_ids (ndarray) – drug identifiers
cell_line_ids (ndarray) – cell line identifiers
drug_input (FeatureDataset | None) – drug omics features
cell_line_input (FeatureDataset) – cell line omics features
- Return type:
- Returns:
predicted response
- save(directory)
Save the trained model, hyperparameters, scalers, PCA object, and feature dimensions to disk.
Files saved:
  - model.pt
  - hyperparameters.json
  - gene_scaler.pkl
  - methylation_scaler.pkl
  - methylation_pca.pkl
  - metadata.json
- train(output, cell_line_input, drug_input=None, output_earlystopping=None, model_checkpoint_dir='')
Fits the PCA and trains the model.
- Parameters:
output (DrugResponseDataset) – training data associated with the response output
cell_line_input (FeatureDataset) – cell line omics features
drug_input (FeatureDataset | None) – drug omics features
output_earlystopping (DrugResponseDataset | None) – optional early stopping dataset
model_checkpoint_dir (str) – directory to save the model checkpoints
- Raises:
ValueError – if drug_input (fingerprints) is missing
Model utils
Utility functions for the simple neural network models.
- class drevalpy.models.SimpleNeuralNetwork.utils.FeedForwardNetwork(hyperparameters, input_dim)
Bases: LightningModule
Feed forward neural network for regression tasks with basic architecture.
- configure_optimizers()
Overrides configure_optimizers from the LightningModule.
- Return type:
Optimizer
- Returns:
Adam optimizer
- fit(output_train, cell_line_input, drug_input, cell_line_views, drug_views, output_earlystopping=None, trainer_params=None, batch_size=32, patience=5, num_workers=2, model_checkpoint_dir='checkpoints')
Fits the model.
First, the data is loaded using a DataLoader. Then, the model is trained using the Lightning Trainer.
- Parameters:
output_train (DrugResponseDataset) – Response values for training
cell_line_input (FeatureDataset) – Cell line features
drug_input (FeatureDataset | None) – Drug features
cell_line_views (list[str]) – Cell line info needed for this model
drug_views (list[str]) – Drug info needed for this model
output_earlystopping (DrugResponseDataset | None) – Response values for early stopping
trainer_params (dict | None) – custom parameters for the trainer
batch_size – batch size for the DataLoader, default is 32
patience – patience for early stopping, default is 5
num_workers (int) – number of workers for the DataLoader, default is 2
model_checkpoint_dir (str) – directory to save the model checkpoints
- Raises:
ValueError – if drug_input is missing
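To make the patience parameter concrete, here is a simplified sketch of patience-based early stopping on a sequence of validation losses. It assumes the usual semantics (stop after `patience` consecutive checks without a new best loss). The function name and the loss values are hypothetical, and the actual training presumably relies on Lightning's early stopping machinery, which offers further options.

```python
def early_stopping_epoch(val_losses, patience=5):
    """Return the epoch index at which training would stop, or None if it never stops.

    Simplified assumption: an epoch counts as an improvement only if its
    validation loss is strictly lower than the best loss seen so far.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical validation losses, one per epoch.
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.73, 0.74, 0.69]
stop_epoch = early_stopping_epoch(losses, patience=5)
print(stop_epoch)
```

With these losses, the best value is reached at epoch 2, and the fifth consecutive non-improving epoch is epoch 7, so training would stop before the later improvement at epoch 8 is seen. A larger patience trades longer training for a lower risk of stopping too early.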
- forward(x)
Forward pass of the model.
- Parameters:
x – input data
- Return type:
Tensor
- Returns:
predicted response
- predict(x)
Predicts the response for the given input.
- training_step(batch)
Overrides the training step from the LightningModule.
Does a forward pass, calculates the loss, and logs it.
- Parameters:
batch – batch of data
- Returns:
loss
- validation_step(batch)
Overrides the validation step from the LightningModule.
Does a forward pass, calculates the loss, and logs it.
- Parameters:
batch – batch of data
- Returns:
loss
- class drevalpy.models.SimpleNeuralNetwork.utils.RegressionDataset(output, cell_line_input, drug_input, cell_line_views, drug_views)
Bases: Dataset
Dataset for regression tasks for the data loader.
- Parameters:
output (DrugResponseDataset)
cell_line_input (FeatureDataset)
drug_input (FeatureDataset)
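The idea of a map-style dataset over (cell line, drug) pairs can be sketched without any dependencies as shown below. This is not the library's implementation: the class name `PairRegressionDataset`, the use of plain dictionaries instead of DrugResponseDataset and FeatureDataset objects, the identifiers, and the choice to concatenate all views into one feature vector are assumptions made for illustration.

```python
class PairRegressionDataset:
    """Hypothetical stand-in for a map-style dataset of (features, response) pairs."""

    def __init__(self, responses, cell_line_ids, drug_ids,
                 cell_line_features, drug_features, cell_line_views, drug_views):
        self.responses = responses
        self.cell_line_ids = cell_line_ids
        self.drug_ids = drug_ids
        self.cell_line_features = cell_line_features
        self.drug_features = drug_features
        self.cell_line_views = cell_line_views
        self.drug_views = drug_views

    def __len__(self):
        return len(self.responses)

    def __getitem__(self, idx):
        cell_line = self.cell_line_features[self.cell_line_ids[idx]]
        drug = self.drug_features[self.drug_ids[idx]]
        features = []
        # Assumption: all requested views are concatenated into one input vector.
        for view in self.cell_line_views:
            features.extend(cell_line[view])
        for view in self.drug_views:
            features.extend(drug[view])
        return features, self.responses[idx]


# Toy data with made-up identifiers and values.
dataset = PairRegressionDataset(
    responses=[0.5, 1.5],
    cell_line_ids=["CL1", "CL2"],
    drug_ids=["D1", "D1"],
    cell_line_features={
        "CL1": {"gene_expression": [0.1, 0.2]},
        "CL2": {"gene_expression": [0.3, 0.4]},
    },
    drug_features={"D1": {"fingerprints": [1, 0, 1]}},
    cell_line_views=["gene_expression"],
    drug_views=["fingerprints"],
)
first_features, first_response = dataset[0]
```

Implementing `__len__` and `__getitem__` is what allows a data loader to index and batch the samples.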