SuperFELTR
SuperFELTR Model
Contains the SuperFELTR model.
Regression extension of Super.FELT: supervised feature extraction learning using triplet loss for drug response prediction with multi-omics data. Very similar to MOLI. Differences:
- In MOLI, the encoders and the classifier were trained jointly. Super.FELT trains them independently.
- MOLI was trained without feature selection (except for the Variance Threshold on the gene expression). Super.FELT uses feature selection for all omics data.
The input remains the same: somatic mutation, copy number variation, and gene expression data. Original authors of Super.FELT: Park, Soh & Lee (2021, 10.1186/s12859-021-04146-z). Code adapted from their GitHub repository, https://github.com/DMCB-GIST/Super.FELT, and from Hauptmann et al. (2023, 10.1186/s12859-023-05166-7), https://github.com/kramerlab/Multi-Omics_analysis.
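To make the feature-selection difference concrete, here is a minimal sketch of a per-omics variance filter: for each data type, only the features whose variance across cell lines exceeds that type's threshold are kept. The helper `select_by_variance` and the threshold values are hypothetical; the sketch only assumes that feature selection behaves like a variance threshold (as the per-omics variance-threshold hyperparameters suggest) and is not the drevalpy implementation.

```python
import numpy as np


def select_by_variance(matrix: np.ndarray, threshold: float) -> np.ndarray:
    """Return indices of columns whose variance is strictly above threshold.

    Hypothetical helper illustrating a variance-threshold filter,
    not the exact drevalpy code.
    """
    variances = matrix.var(axis=0)  # population variance of each feature (column)
    return np.flatnonzero(variances > threshold)


# Toy omics matrices: rows = cell lines, columns = features.
omics = {
    "gene_expression": np.array([[1.0, 5.0, 2.0], [1.0, 7.0, 4.0], [1.0, 9.0, 3.0]]),
    "mutations": np.array([[0, 1], [0, 0], [0, 1]], dtype=float),
    "copy_number_variation_gistic": np.array([[0, 2], [1, 2], [-1, 2]], dtype=float),
}
# Hypothetical thresholds, one per omics type.
thresholds = {
    "gene_expression": 0.5,
    "mutations": 0.1,
    "copy_number_variation_gistic": 0.1,
}

# Indices of the features kept for each omics type.
selected = {name: select_by_variance(x, thresholds[name]) for name, x in omics.items()}
```

Constant features (such as the first gene expression column) are always dropped, which is the main effect of such a filter on sparse mutation data.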
- class drevalpy.models.SuperFELTR.superfeltr.SuperFELTR
Bases: DRPModel
Regression extension of Super.FELT.
- build_model(hyperparameters)
Builds the model from hyperparameters.
- Parameters:
hyperparameters – dictionary containing the hyperparameters for the model. Contains mini_batch, dropout_rate, weight_decay, out_dim_expr_encoder, out_dim_mutation_encoder, out_dim_cnv_encoder, epochs, the variance thresholds for gene expression, mutation, and copy number variation, margin, and learning rate.
- Return type:
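As an illustration only, a hyperparameter dictionary with the documented keys might look like the sketch below. The values are made up, and the key names for the variance thresholds, the margin, and the learning rate are hypothetical placeholders, since only their meaning (not their names) is given above.

```python
# Illustrative hyperparameters for build_model(); all values are made up.
hyperparameters = {
    "mini_batch": 16,
    "dropout_rate": 0.3,
    "weight_decay": 0.01,
    "out_dim_expr_encoder": 256,
    "out_dim_mutation_encoder": 32,
    "out_dim_cnv_encoder": 64,
    "epochs": 10,
    "expression_var_threshold": 0.1,  # hypothetical key name
    "mutation_var_threshold": 0.05,   # hypothetical key name
    "cnv_var_threshold": 0.1,         # hypothetical key name
    "margin": 1.0,                    # triplet-loss margin, hypothetical key name
    "learning_rate": 0.001,           # hypothetical key name
}

# Keys whose names are given explicitly in the documentation above.
DOCUMENTED_KEYS = {
    "mini_batch",
    "dropout_rate",
    "weight_decay",
    "out_dim_expr_encoder",
    "out_dim_mutation_encoder",
    "out_dim_cnv_encoder",
    "epochs",
}
missing = DOCUMENTED_KEYS - set(hyperparameters)  # empty if every documented key is present
```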
- cell_line_views = ['gene_expression', 'mutations', 'copy_number_variation_gistic']
- drug_views = []
- early_stopping = True
- is_single_drug_model = True
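Because `is_single_drug_model` is `True`, one model is fitted per drug. The sketch below uses a hypothetical helper `split_by_drug` over plain Python lists (not a drevalpy `DrugResponseDataset`) to group response records by drug, so that each group could be passed to its own model.

```python
from collections import defaultdict


def split_by_drug(cell_line_ids, drug_ids, responses):
    """Group (cell line, response) pairs by drug id (hypothetical helper)."""
    per_drug = defaultdict(lambda: {"cell_line_ids": [], "responses": []})
    for cell_line, drug, response in zip(cell_line_ids, drug_ids, responses):
        per_drug[drug]["cell_line_ids"].append(cell_line)
        per_drug[drug]["responses"].append(response)
    return dict(per_drug)


groups = split_by_drug(
    ["CL1", "CL2", "CL1", "CL3"],
    ["drugA", "drugA", "drugB", "drugB"],
    [0.5, 1.2, -0.3, 0.8],
)
# One SuperFELTR model would be trained for each key of `groups`.
```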
- load_cell_line_features(data_path, dataset_name)
Loads the cell line features: gene expression, mutations, and copy number variation.
- Parameters:
- Return type:
- Returns:
FeatureDataset containing the cell line gene expression features, mutations, and copy number variation.
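The three views end up as separate matrices, one per omics type. Assuming a nested mapping `{cell_line_id: {view_name: feature_vector}}` (an assumed layout for illustration, not a documented `FeatureDataset` API), the per-view matrices for a list of cell lines could be assembled as follows; `stack_view` is a hypothetical helper.

```python
import numpy as np

CELL_LINE_VIEWS = ["gene_expression", "mutations", "copy_number_variation_gistic"]

# Assumed nested layout: cell line id -> view name -> feature vector.
features = {
    "CL1": {
        "gene_expression": np.array([0.1, 0.2, 0.3]),
        "mutations": np.array([1.0, 0.0]),
        "copy_number_variation_gistic": np.array([0.0, -1.0]),
    },
    "CL2": {
        "gene_expression": np.array([0.4, 0.5, 0.6]),
        "mutations": np.array([0.0, 1.0]),
        "copy_number_variation_gistic": np.array([1.0, 0.0]),
    },
}


def stack_view(features, cell_line_ids, view):
    """Stack one view into an (n_cell_lines, n_features) matrix, in the given order."""
    return np.stack([features[cl][view] for cl in cell_line_ids])


# One matrix per omics view, rows ordered as in the requested cell line list.
matrices = {view: stack_view(features, ["CL2", "CL1"], view) for view in CELL_LINE_VIEWS}
```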
- load_drug_features(data_path, dataset_name)
Returns None, as drug features are not needed for SuperFELTR.
- Parameters:
- Return type:
- Returns:
None
- predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)
Predicts the drug response.
If there is no training data, NA is predicted. If there was not enough training data, predictions are made with the randomly initialized model.
- Parameters:
cell_line_ids (ndarray) – cell line ids
drug_ids (ndarray) – drug ids
cell_line_input (FeatureDataset) – cell line omics features
drug_input (FeatureDataset | None) – drug omics features, not needed
- Return type:
- Returns:
predicted drug response
- Raises:
ValueError – if drug_input is not None
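The documented edge cases of `predict` can be sketched as follows. Here `predict_or_nan` is a hypothetical stand-in: `model` is `None` when no training data existed for the drug, and otherwise any callable mapping one cell line's features to a response. Only the NaN fallback and the `ValueError` for `drug_input` come from the description above.

```python
import numpy as np


def predict_or_nan(model, cell_line_ids, drug_ids, features, drug_input=None):
    """Hypothetical sketch of the documented predict() edge cases.

    model: None if there was no training data for this drug, otherwise a
    callable mapping one cell line's feature vector to a predicted response.
    drug_ids is unused because the model belongs to a single drug.
    """
    if drug_input is not None:
        raise ValueError("drug_input must be None: SuperFELTR uses no drug features.")
    if model is None:
        # No training data: predict NA (NaN) for every requested cell line.
        return np.full(len(cell_line_ids), np.nan)
    return np.array([model(features[cl]) for cl in cell_line_ids], dtype=float)
```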
- train(output, cell_line_input, drug_input=None, output_earlystopping=None, model_checkpoint_dir='superfeltr_checkpoints')
Does feature selection, trains the encoders sequentially, and then trains the regressor.
If there is not enough training data, the randomly initialized model is used. If there is no training data at all, the model is skipped and NA is predicted later on.
- Parameters:
output (DrugResponseDataset) – training data associated with the response output
cell_line_input (FeatureDataset) – cell line omics features
drug_input (FeatureDataset | None) – not needed, as it is a single drug model
output_earlystopping (DrugResponseDataset | None) – optional early stopping dataset
model_checkpoint_dir (str) – not needed
- Raises:
ValueError – if drug_input is not None
- Return type:
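The control flow of `train` described above might be sketched as an ordered list of steps. The cut-off `min_samples` for "enough" training data is an assumption (the actual criterion is not stated here), as is the encoder order, which simply follows `cell_line_views`; `plan_training` is a hypothetical helper.

```python
def plan_training(n_samples: int, min_samples: int = 10) -> list[str]:
    """Return the ordered training steps for one drug (hypothetical helper).

    min_samples is an assumed cut-off for "enough" training data.
    """
    if n_samples == 0:
        return []  # model is skipped; predict() later returns NA
    if n_samples < min_samples:
        return ["init_random_model"]  # randomly initialized model is used as-is
    return [
        "feature_selection",
        # Encoders are trained one after another, each on its own omics view.
        "train_encoder:gene_expression",
        "train_encoder:mutations",
        "train_encoder:copy_number_variation_gistic",
        # The regressor is trained last, on top of the trained encoders.
        "train_regressor",
    ]
```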
Model utils
Utility functions for the SuperFELTR model.
- class drevalpy.models.SuperFELTR.utils.SuperFELTEncoder(input_size, hpams, omic_type, ranges)
Bases: LightningModule
SuperFELT encoder definition for a single omic type, i.e., gene expression, mutation, or copy number variation.
Very similar to MOLIEncoder, but with BatchNorm1d before ReLU.
- Parameters:
- configure_optimizers()
Override the configure_optimizers method to use the Adam optimizer.
- Return type:
Optimizer
- Returns:
Adam optimizer
- forward(x)
Forward pass of the SuperFELTEncoder.
- Parameters:
x (Tensor) – input tensor
- Return type:
Tensor
- Returns:
encoded tensor
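To show what "BatchNorm1d before ReLU" means for the forward pass, here is a NumPy sketch of one encoder layer: a linear map, then batch normalization using the batch statistics (training-mode behaviour, with scale 1 and shift 0), then ReLU. Dropout is left out for brevity, and `encoder_forward` with its weight shapes is an assumption, not the actual `SuperFELTEncoder` code.

```python
import numpy as np


def encoder_forward(x, weight, bias, eps=1e-5):
    """Linear -> BatchNorm (batch statistics, scale 1, shift 0) -> ReLU.

    Hypothetical NumPy sketch; dropout is omitted.
    """
    h = x @ weight.T + bias                # linear layer, weight has shape (out, in)
    mean = h.mean(axis=0)                  # per-feature batch mean
    var = h.var(axis=0)                    # per-feature biased batch variance
    h = (h - mean) / np.sqrt(var + eps)    # batch normalization
    return np.maximum(h, 0.0)              # ReLU applied after normalization


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 5))       # batch of 8 cell lines, 5 selected features
weight = rng.normal(size=(3, 5))  # out_dim = 3
bias = np.zeros(3)
z = encoder_forward(x, weight, bias)
```

Normalizing before the ReLU centres each pre-activation around zero, so roughly half of each unit's activations stay positive regardless of the scale of the raw omics features.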
- training_step(batch, batch_idx)
Override the training_step method to compute the triplet loss.
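The triplet loss pulls an anchor's embedding towards a positive sample and pushes it away from a negative one by at least a margin (presumably the margin hyperparameter). A NumPy sketch with Euclidean distances, similar to PyTorch's `TripletMarginLoss` with its defaults, is shown below; how anchors, positives, and negatives are chosen from continuous drug responses is not covered here, and `triplet_loss` is a hypothetical helper.

```python
import numpy as np


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Mean triplet margin loss: max(d(a, p) - d(a, n) + margin, 0)."""
    d_pos = np.linalg.norm(anchor - positive, axis=1)  # anchor-positive distances
    d_neg = np.linalg.norm(anchor - negative, axis=1)  # anchor-negative distances
    return np.maximum(d_pos - d_neg + margin, 0.0).mean()


anchor = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[1.0, 0.0], [0.0, 1.0]])
negative = np.array([[3.0, 0.0], [0.0, 0.5]])
loss = triplet_loss(anchor, positive, negative, margin=1.0)
```

The first triplet already satisfies the margin and contributes zero; only the second, whose negative is closer than its positive, produces a gradient.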
- class drevalpy.models.SuperFELTR.utils.SuperFELTRegressor(input_size, hpams, encoders)
Bases: RegressionMetricsMixin, LightningModule
SuperFELT regressor definition.
Very similar to the SuperFELT classifier, but with a regression loss and without the last sigmoid layer.
- Parameters:
input_size (int)
encoders (tuple[SuperFELTEncoder, SuperFELTEncoder, SuperFELTEncoder])
- configure_optimizers()
Override the configure_optimizers method to use the Adagrad optimizer.
- Return type:
Optimizer
- Returns:
Adagrad optimizer
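For reference, one Adagrad update can be sketched as below: squared gradients are accumulated per parameter and each step is scaled by the inverse square root of that sum, so parameters with large past gradients take smaller steps. The sketch follows the update rule of `torch.optim.Adagrad` with no learning-rate decay and a zero initial accumulator; the learning rate and whether weight_decay is passed to this optimizer are assumptions, and `adagrad_step` is a hypothetical helper.

```python
import numpy as np


def adagrad_step(param, grad, state_sum, lr=0.01, weight_decay=0.0, eps=1e-10):
    """One Adagrad update; returns the new parameters and squared-gradient sum."""
    if weight_decay:
        grad = grad + weight_decay * param  # L2 penalty folded into the gradient
    state_sum = state_sum + grad ** 2       # accumulate squared gradients
    param = param - lr * grad / (np.sqrt(state_sum) + eps)
    return param, state_sum


param = np.array([1.0, 1.0])
grad = np.array([2.0, 0.5])
state = np.zeros(2)
p1, state = adagrad_step(param, grad, state, lr=0.1)
p2, state = adagrad_step(p1, grad, state, lr=0.1)
```

Both coordinates move by the same amount despite a fourfold difference in gradient size, which illustrates Adagrad's per-parameter rescaling.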
- forward(x)
Forward pass of the SuperFELTRegressor.
- Parameters:
x (Tensor) – input tensor
- Return type:
Tensor
- Returns:
predicted response
- predict(data_expr, data_mut, data_cnv)
Predicts the response for the given input.
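A possible shape of the prediction path: each omics matrix goes through its own encoder, the embeddings are concatenated, and a linear output layer produces the response without a final sigmoid. This NumPy sketch assumes that input_size equals the sum of the three encoder output dimensions and that the output layer is a single linear map; `regressor_predict` is hypothetical.

```python
import numpy as np


def regressor_predict(data_expr, data_mut, data_cnv, encoders, weight, bias):
    """Encode each omics view, concatenate, and apply a linear output layer.

    encoders: three callables (expression, mutation, CNV).
    weight: vector whose length is the sum of the embedding dimensions (assumed).
    No sigmoid is applied because the target is a continuous response.
    """
    embeddings = [enc(x) for enc, x in zip(encoders, (data_expr, data_mut, data_cnv))]
    combined = np.concatenate(embeddings, axis=1)  # (n_samples, sum of out dims)
    return combined @ weight + bias


identity = lambda x: x  # stand-in for a trained encoder
pred = regressor_predict(
    np.array([[1.0, 2.0], [0.0, 0.0]]),  # expression embedding input, 2 features
    np.array([[1.0], [0.0]]),            # mutation input, 1 feature
    np.array([[-1.0], [2.0]]),           # CNV input, 1 feature
    encoders=(identity, identity, identity),
    weight=np.ones(4),
    bias=0.5,
)
```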
- training_step(batch, batch_idx)
Override the training_step method to compute the regression loss.
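The regression loss is not named here; assuming mean squared error, the usual choice for a continuous target, it amounts to the following (`mse_loss` is a hypothetical helper):

```python
import numpy as np


def mse_loss(predictions, targets):
    """Mean squared error between predicted and observed responses."""
    predictions = np.asarray(predictions, dtype=float)
    targets = np.asarray(targets, dtype=float)
    return float(np.mean((predictions - targets) ** 2))


loss = mse_loss([3.5, 2.5], [3.0, 3.0])
```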
- drevalpy.models.SuperFELTR.utils.train_superfeltr_model(model, hpams, output_train, cell_line_input, output_earlystopping=None, patience=5, model_checkpoint_dir='superfeltr_checkpoints', wandb_project=None)
Trains one encoder or the regressor.
First, the dataset and loaders are created. Then, the model is trained with the Lightning trainer.
- Parameters:
model (SuperFELTEncoder | SuperFELTRegressor) – either one of the encoders or the regressor
hpams (dict[str, int | float | dict]) – hyperparameters for the model
output_train (DrugResponseDataset) – response data for training
cell_line_input (FeatureDataset) – cell line omics features
output_earlystopping (DrugResponseDataset | None) – response data for early stopping
patience (int) – patience for early stopping, defaults to 5
model_checkpoint_dir (str) – directory to save the model checkpoints
wandb_project (str | None) – optional wandb project name for logging. If provided, a WandbLogger is used for PyTorch Lightning training.
- Return type:
ModelCheckpoint
- Returns:
checkpoint callback with the best model
- Raises:
ValueError – if the epochs and mini_batch are not integers
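The early stopping and checkpointing handled through the Lightning trainer can be approximated as follows: training stops once the validation loss has failed to improve for `patience` consecutive epochs, and the epoch with the lowest loss is the one whose checkpoint is kept. Both helpers are hypothetical; `check_training_hpams` mirrors the documented `ValueError` for non-integer `epochs` and `mini_batch`, and checking once per epoch is an assumption.

```python
def check_training_hpams(hpams):
    """Raise ValueError unless epochs and mini_batch are integers."""
    if not isinstance(hpams["epochs"], int) or not isinstance(hpams["mini_batch"], int):
        raise ValueError("epochs and mini_batch must be integers")


def best_epoch_with_early_stopping(val_losses, patience=5):
    """Return (best_epoch, last_epoch_run) for a sequence of validation losses."""
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best: this is the checkpoint that would be kept.
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch  # stop training early
    return best_epoch, len(val_losses) - 1


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9, 0.74, 0.6]
result = best_epoch_with_early_stopping(losses, patience=5)
```

With a patience of 5, training stops at epoch 7 and never reaches the lower loss at epoch 8; a larger patience trades training time for the chance of such late improvements.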