DrugGNN
DrugGNN Model
DrugGNN model.
- class drevalpy.models.DrugGNN.drug_gnn.DrugGNN
Bases: DRPModel
DrugGNN model.
- build_model(hyperparameters)
Build the model.
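This reference does not list the keys that `build_model` expects. As a rough illustration only, the sketch below merges user overrides onto the defaults shown in the `DrugGNNModule` signature (`hidden_dim=64`, `dropout=0.2`, `learning_rate=0.001`). The helper `resolve_hyperparameters` is hypothetical, and it is an assumption that `build_model` reads exactly these keys.

```python
# Defaults copied from the DrugGNNModule signature in this reference.
DEFAULTS = {"hidden_dim": 64, "dropout": 0.2, "learning_rate": 0.001}


def resolve_hyperparameters(overrides=None):
    """Hypothetical helper: merge user overrides onto the documented defaults."""
    overrides = overrides or {}
    # Reject misspelled keys early instead of silently ignoring them.
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise KeyError(f"Unknown hyperparameters: {sorted(unknown)}")
    merged = {**DEFAULTS, **overrides}
    if not 0.0 <= merged["dropout"] < 1.0:
        raise ValueError("dropout must be in [0, 1)")
    return merged


# Assumed usage: the resulting dict would be passed to build_model(...).
hyperparameters = resolve_hyperparameters({"hidden_dim": 128})
```

Validating the mapping before building the model keeps typos such as `hiden_dim` from falling back to a default unnoticed.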
- property cell_line_views: list[str]
Return the sources the model needs as input for describing the cell line.
- Returns:
The sources the model needs as input for describing the cell line.
- property drug_views: list[str]
Return the sources the model needs as input for describing the drug.
- Returns:
The sources the model needs as input for describing the drug.
- classmethod get_model_name()
Return the name of the model.
- Returns:
The name of the model.
- load_cell_line_features(data_path, dataset_name)
Loads the cell line features.
- Returns:
FeatureDataset containing the cell line gene expression features.
- load_drug_features(data_path, dataset_name)
Loads the pre-computed drug graph data.
- Raises:
FileNotFoundError – If the drug graph directory is not found.
ValueError – If no drug graphs are loaded.
- Returns:
FeatureDataset containing the drug graphs.
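The two documented failure modes (a missing graph directory and an empty result) can be sketched with the standard library. This is a minimal sketch, not the drevalpy implementation: the function name `list_drug_graph_files`, the `drug_graphs` subdirectory, and the `.pt` suffix are all assumptions made for illustration.

```python
from pathlib import Path


def list_drug_graph_files(data_path, dataset_name, subdir="drug_graphs", suffix=".pt"):
    """Hypothetical helper: map drug IDs to pre-computed graph files.

    The directory layout (data_path/dataset_name/subdir) and the file suffix
    are assumptions, not taken from the drevalpy source.
    """
    graph_dir = Path(data_path) / dataset_name / subdir
    if not graph_dir.is_dir():
        # Mirrors the documented FileNotFoundError.
        raise FileNotFoundError(f"Drug graph directory not found: {graph_dir}")
    # Use the file stem as the drug ID.
    files = {p.stem: p for p in sorted(graph_dir.glob(f"*{suffix}"))}
    if not files:
        # Mirrors the documented ValueError.
        raise ValueError(f"No drug graphs were loaded from {graph_dir}")
    return files
```

Separating the two checks distinguishes a path problem from a directory that exists but holds no usable graphs.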
- load_model(path, drug_name=None)
Load the model.
- predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)
Predict drug response.
- Parameters:
cell_line_ids (ndarray) – The cell line IDs.
drug_ids (ndarray) – The drug IDs.
cell_line_input (FeatureDataset) – The cell line input dataset.
drug_input (FeatureDataset | None) – The drug input dataset.
- Raises:
RuntimeError – If the model has not been trained yet.
ValueError – If drug input is not provided.
- Returns:
The predicted drug response.
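Although `drug_input` defaults to `None` in the signature, the documented `ValueError` means it is effectively required. The guard behaviour can be sketched with a stand-in class. `TinyResponseModel` is hypothetical and not part of drevalpy, and the order in which the two checks run is an assumption.

```python
class TinyResponseModel:
    """Hypothetical stand-in that mimics the documented error contract."""

    def __init__(self):
        self.model = None  # set by train()

    def train(self, drug_input=None):
        if drug_input is None:
            raise ValueError("drug_input is required")
        # Placeholder "model": always predicts 0.0.
        self.model = lambda cell_line_id, drug_id: 0.0

    def predict(self, cell_line_ids, drug_ids, drug_input=None):
        if self.model is None:
            raise RuntimeError("The model has not been trained yet")
        if drug_input is None:
            raise ValueError("drug_input is required")
        # One prediction per (cell line, drug) pair.
        return [self.model(c, d) for c, d in zip(cell_line_ids, drug_ids)]


model = TinyResponseModel()
try:
    model.predict(["CL1"], ["D1"], drug_input={"D1": "graph"})
except RuntimeError as err:
    untrained_error = str(err)

model.train(drug_input={"D1": "graph"})
predictions = model.predict(["CL1", "CL2"], ["D1", "D1"], drug_input={"D1": "graph"})
```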
- save_model(path, drug_name=None)
Save the model.
- Raises:
RuntimeError – If there is no model to save.
- train(output, cell_line_input, drug_input=None, output_earlystopping=None, **kwargs)
Train the model.
- Parameters:
output (DrugResponseDataset) – The output dataset.
cell_line_input (FeatureDataset) – The cell line input dataset.
drug_input (FeatureDataset | None) – The drug input dataset.
output_earlystopping (DrugResponseDataset | None) – The early stopping output dataset.
kwargs – Additional arguments.
- Raises:
ValueError – If drug input is not provided.
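This reference does not say how `output_earlystopping` is used internally. For readers unfamiliar with the idea, the sketch below shows generic patience-based early stopping on a sequence of validation losses. The function name, the `patience` value, and the stopping rule are assumptions and may differ from what the model actually does.

```python
def find_stopping_point(val_losses, patience=3):
    """Generic patience-based early stopping (illustrative only).

    Returns (best_epoch, last_epoch_run). Training stops once `patience`
    consecutive epochs pass without a new best validation loss.
    """
    best_loss = float("inf")
    best_epoch = -1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    # Ran out of epochs before patience was exhausted.
    return best_epoch, len(val_losses) - 1


# Losses measured on a held-out early stopping set, one per epoch.
best_epoch, stopped_at = find_stopping_point([1.0, 0.8, 0.9, 0.85, 0.95, 0.7], patience=3)
```

In this example the best loss is at epoch 1 and training halts at epoch 4, so the later improvement at epoch 5 is never reached. That is the usual trade-off when choosing a small patience.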
- class drevalpy.models.DrugGNN.drug_gnn.DrugGNNModule(num_node_features, num_cell_features, hidden_dim=64, dropout=0.2, learning_rate=0.001)
Bases: RegressionMetricsMixin, LightningModule
The LightningModule for the DrugGNN model.
- configure_optimizers()
Configure the optimizer.
- Returns:
The optimizer.
- forward(batch)
Forward pass of the module.
- Parameters:
batch – The batch.
- Returns:
The output of the model.
- predict_step(batch, batch_idx, dataloader_idx=0)
A single prediction step.
- Parameters:
batch – The batch.
batch_idx – The batch index.
dataloader_idx – The dataloader index.
- Returns:
The output of the model.
- training_step(batch, batch_idx)
A single training step.
- Parameters:
batch – The batch.
batch_idx – The batch index.
- Returns:
The loss.
- validation_step(batch, batch_idx)
A single validation step.
- Parameters:
batch – The batch.
batch_idx – The batch index.
- class drevalpy.models.DrugGNN.drug_gnn.DrugGraphNet(num_node_features, num_cell_features, hidden_dim=64, dropout=0.2)
Bases: Module
Neural network for DrugGNN.
- forward(drug_graph, cell_features)
Forward pass of the network.
- Parameters:
drug_graph – The drug graph.
cell_features – The cell line features.
- Returns:
The output of the network.