DrugGNN

DrugGNN Model

DrugGNN model.

class drevalpy.models.DrugGNN.drug_gnn.DrugGNN

Bases: DRPModel

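DrugGNN model: predicts drug response from a graph representation of the drug combined with the cell line's gene expression features.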

build_model(hyperparameters)

Build the model.

Parameters:

hyperparameters (dict[str, Any]) – The hyperparameters.

Return type:

None
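The hyperparameter keys accepted by build_model are not listed here. The following is a minimal sketch that assumes the dictionary uses the same names and defaults as the DrugGNNModule constructor documented below (hidden_dim=64, dropout=0.2, learning_rate=0.001). The helper resolve_hyperparameters is hypothetical and not part of drevalpy.

```python
from typing import Any

# Defaults copied from the DrugGNNModule signature documented below; using
# them as build_model keys is an assumption, not confirmed by the docs.
DEFAULTS: dict[str, Any] = {"hidden_dim": 64, "dropout": 0.2, "learning_rate": 0.001}


def resolve_hyperparameters(hyperparameters: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: overlay user values on the defaults and validate them."""
    resolved = {**DEFAULTS, **hyperparameters}  # extra keys are passed through untouched
    if not 0.0 <= resolved["dropout"] < 1.0:
        raise ValueError("dropout must be in [0, 1)")
    if resolved["hidden_dim"] <= 0 or resolved["learning_rate"] <= 0:
        raise ValueError("hidden_dim and learning_rate must be positive")
    return resolved


print(resolve_hyperparameters({"hidden_dim": 128}))
# {'hidden_dim': 128, 'dropout': 0.2, 'learning_rate': 0.001}
```

Passing unknown keys through, rather than rejecting them, avoids guessing which other settings (for example training epochs) the real model reads.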

property cell_line_views: list[str]

Return the sources the model needs as input for describing the cell line.

Returns:

The sources the model needs as input for describing the cell line.

property drug_views: list[str]

Return the sources the model needs as input for describing the drug.

Returns:

The sources the model needs as input for describing the drug.
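Declaring the required views makes it possible to check the inputs before training starts. The following is a minimal sketch of such a check. It assumes features are stored per entity ID as a mapping from view name to values, which is a simplification of FeatureDataset. The view names "gene_expression" and "drug_graph" and the function missing_views are hypothetical.

```python
def missing_views(
    features: dict[str, dict[str, object]], required_views: list[str]
) -> dict[str, list[str]]:
    """Return, for each entity ID, the required views it lacks (hypothetical helper)."""
    missing: dict[str, list[str]] = {}
    for entity_id, views in features.items():
        absent = [view for view in required_views if view not in views]
        if absent:
            missing[entity_id] = absent
    return missing


# Toy data: one cell line has gene expression, the other has no views at all.
cell_lines = {"CL-1": {"gene_expression": [0.1, 2.3]}, "CL-2": {}}
print(missing_views(cell_lines, ["gene_expression"]))  # {'CL-2': ['gene_expression']}
```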

classmethod get_model_name()

Return the name of the model.

Return type:

str

Returns:

The name of the model.

load_cell_line_features(data_path, dataset_name)

Loads the cell line features.

Parameters:
  • data_path (str) – Path to the data directory containing the gene expression and landmark gene data.

  • dataset_name (str) – Name of the dataset.

Return type:

FeatureDataset

Returns:

FeatureDataset containing the cell line gene expression features.
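The data_path description indicates that gene expression is combined with a list of landmark genes. The following is a minimal sketch of restricting expression to those genes. It assumes the expression data is already in memory as a nested mapping {cell_line: {gene: value}}, and it omits file parsing. The name select_landmark_genes is hypothetical. Filling missing genes with 0.0 is only one possible policy.

```python
def select_landmark_genes(
    expression: dict[str, dict[str, float]],
    landmark_genes: list[str],
    fill_value: float = 0.0,
) -> dict[str, list[float]]:
    """Build one fixed-order feature vector per cell line (hypothetical helper).

    Every vector follows the order of ``landmark_genes`` so that position i
    always refers to the same gene; genes absent for a cell line get fill_value.
    """
    return {
        cell_line: [genes.get(gene, fill_value) for gene in landmark_genes]
        for cell_line, genes in expression.items()
    }


expression = {
    "CL-1": {"TP53": 5.2, "EGFR": 1.1, "MYC": 7.4},
    "CL-2": {"TP53": 4.8, "MYC": 6.9},
}
print(select_landmark_genes(expression, ["EGFR", "TP53"]))
# {'CL-1': [1.1, 5.2], 'CL-2': [0.0, 4.8]}
```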

load_drug_features(data_path, dataset_name)

Loads the pre-computed drug graph data.

Parameters:
  • data_path (str) – Path to the data directory.

  • dataset_name (str) – Name of the dataset.

Raises:
Return type:

FeatureDataset

Returns:

FeatureDataset containing the drug graphs.
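A drug graph pairs per-node features with the bond structure. The following is a minimal sketch that assumes a PyTorch Geometric-style layout. In that layout, edges are stored as a 2 x E edge_index in coordinate (COO) form, and each undirected bond appears once per direction. DrugGraph and from_bonds are hypothetical names. Real graphs would be derived from the molecule rather than written by hand. The per-node feature count is presumably what DrugGNNModule receives as num_node_features.

```python
from dataclasses import dataclass


@dataclass
class DrugGraph:
    """Hypothetical container mirroring a PyTorch Geometric ``Data`` object."""

    x: list[list[float]]  # node feature matrix, one row per atom
    edge_index: list[list[int]]  # [sources, targets], shape 2 x num_edges

    @property
    def num_node_features(self) -> int:
        return len(self.x[0]) if self.x else 0


def from_bonds(node_features: list[list[float]], bonds: list[tuple[int, int]]) -> DrugGraph:
    """Store each undirected bond (i, j) as the two directed edges i->j and j->i."""
    sources: list[int] = []
    targets: list[int] = []
    for i, j in bonds:
        sources += [i, j]
        targets += [j, i]
    return DrugGraph(x=node_features, edge_index=[sources, targets])


# Toy three-atom chain 0-1-2 with two features per atom.
graph = from_bonds([[6.0, 0.0], [6.0, 1.0], [8.0, 0.0]], [(0, 1), (1, 2)])
print(graph.edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(graph.num_node_features)  # 2
```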

load_model(path, drug_name=None)

Load the model.

Parameters:
  • path (str | Path) – The path to load the model from.

  • drug_name – The name of the drug.

predict(cell_line_ids, drug_ids, cell_line_input, drug_input=None)

Predict drug response.

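Parameters:
  • cell_line_ids – The cell line IDs.

  • drug_ids – The drug IDs.

  • cell_line_input – The cell line features.

  • drug_input – The drug features (drug graphs).

Raises:
Return type: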

ndarray

Returns:

The predicted drug response.
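cell_line_ids and drug_ids are presumably aligned element-wise, so prediction i belongs to the pair (cell_line_ids[i], drug_ids[i]). The following is a minimal sketch of that pairing. A hypothetical score function stands in for the trained network, and plain dictionaries stand in for FeatureDataset.

```python
from typing import Callable, Sequence

Features = dict[str, list[float]]


def predict_pairs(
    cell_line_ids: Sequence[str],
    drug_ids: Sequence[str],
    cell_features: Features,
    drug_features: Features,
    score: Callable[[list[float], list[float]], float],
) -> list[float]:
    """Return one prediction per (cell line, drug) pair, in input order."""
    if len(cell_line_ids) != len(drug_ids):
        raise ValueError("cell_line_ids and drug_ids must have the same length")
    return [
        score(cell_features[cell_id], drug_features[drug_id])
        for cell_id, drug_id in zip(cell_line_ids, drug_ids)
    ]


def toy_score(cell: list[float], drug: list[float]) -> float:
    # Stand-in for the network: a dot product of the two feature vectors.
    return sum(c * d for c, d in zip(cell, drug))


cells = {"CL-1": [1.0, 2.0], "CL-2": [0.5, 0.0]}
drugs = {"D-1": [1.0, 1.0], "D-2": [2.0, 0.0]}
print(predict_pairs(["CL-1", "CL-1", "CL-2"], ["D-1", "D-2", "D-2"], cells, drugs, toy_score))
# [3.0, 2.0, 1.0]
```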

save_model(path, drug_name=None)

Save the model.

Parameters:
  • path (str | Path) – The path to save the model to.

  • drug_name – The name of the drug.

Raises:

RuntimeError – If there is no model to save.
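save_model raises RuntimeError when there is no model, and both save_model and load_model accept an optional drug_name. The following is a minimal round-trip sketch. It assumes only the hyperparameters are stored, as JSON. The docs do not say how drug_name is used, so appending it to the file name is a hypothetical choice. ToyModel is a hypothetical stand-in, and a real implementation would also persist the network weights.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Optional, Union


class ToyModel:
    """Hypothetical stand-in showing the save/load contract, not DrugGNN itself."""

    def __init__(self) -> None:
        self.hyperparameters: Optional[dict[str, Any]] = None  # None means "no model yet"

    @staticmethod
    def _file(path: Union[str, Path], drug_name: Optional[str]) -> Path:
        # Hypothetical layout: <path>/model[_<drug_name>].json
        suffix = f"_{drug_name}" if drug_name else ""
        return Path(path) / f"model{suffix}.json"

    def save_model(self, path: Union[str, Path], drug_name: Optional[str] = None) -> None:
        if self.hyperparameters is None:
            raise RuntimeError("There is no model to save.")
        target = self._file(path, drug_name)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(json.dumps(self.hyperparameters))

    def load_model(self, path: Union[str, Path], drug_name: Optional[str] = None) -> None:
        self.hyperparameters = json.loads(self._file(path, drug_name).read_text())


with tempfile.TemporaryDirectory() as tmp:
    model = ToyModel()
    model.hyperparameters = {"hidden_dim": 64, "dropout": 0.2}
    model.save_model(tmp)
    restored = ToyModel()
    restored.load_model(tmp)
    print(restored.hyperparameters == model.hyperparameters)  # True
```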

train(output, cell_line_input, drug_input=None, output_earlystopping=None, **kwargs)

Train the model.

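Parameters:
  • output – The training drug response data.

  • cell_line_input – The cell line features.

  • drug_input – The drug features (drug graphs). Required.

  • output_earlystopping – Optional drug response data used for early stopping.

  • **kwargs – Additional keyword arguments.

Raises: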

ValueError – If drug input is not provided.
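train raises ValueError when drug_input is missing. The output_earlystopping argument can supply held-out data for early stopping. The following is a minimal sketch of both behaviours. It assumes a generic patience rule on the validation loss, which is the scheme PyTorch Lightning's EarlyStopping callback implements. The docs do not state whether DrugGNN uses exactly this rule. check_inputs and stopping_epoch are hypothetical helpers that work on precomputed losses.

```python
from typing import Optional, Sequence


def check_inputs(drug_input: Optional[object]) -> None:
    """Mirror the documented guard: training needs drug features."""
    if drug_input is None:
        raise ValueError("drug_input is required")


def stopping_epoch(val_losses: Sequence[float], patience: int = 3, min_delta: float = 0.0) -> int:
    """Return the epoch index at which a patience-based rule would stop.

    Training stops once the loss has failed to improve on the best value by
    more than min_delta for `patience` consecutive epochs; if that never
    happens, the last epoch index is returned.
    """
    if not val_losses:
        raise ValueError("val_losses must not be empty")
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses) - 1


# Loss improves for three epochs, then plateaus.
print(stopping_epoch([1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69], patience=3))  # 5
```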

class drevalpy.models.DrugGNN.drug_gnn.DrugGNNModule(num_node_features, num_cell_features, hidden_dim=64, dropout=0.2, learning_rate=0.001)

Bases: RegressionMetricsMixin, LightningModule

The LightningModule for the DrugGNN model.

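Parameters:
  • num_node_features (int) – Number of features per node in the drug graph.

  • num_cell_features (int) – Number of cell line features.

  • hidden_dim (int) – Hidden dimension of the network.

  • dropout (float) – Dropout rate.

  • learning_rate (float) – Learning rate used by the optimizer.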

configure_optimizers()

Configure the optimizer.

Returns:

The optimizer.
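The optimizer type is not documented. The following is a minimal sketch of one update rule, driven by the learning_rate argument (default 0.001). It assumes Adam, which is a common default for Lightning modules, but the source does not confirm this. In practice, configure_optimizers would typically return torch.optim.Adam(self.parameters(), lr=...) rather than hand-written code. The scalar helper adam_minimize is hypothetical and only illustrates the arithmetic.

```python
import math
from typing import Callable


def adam_minimize(
    grad: Callable[[float], float],
    x0: float,
    learning_rate: float = 0.001,
    steps: int = 100,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> float:
    """Run Adam on a single scalar parameter and return its final value."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g  # first-moment (mean) estimate
        v = beta2 * v + (1 - beta2) * g * g  # second-moment estimate
        m_hat = m / (1 - beta1**t)  # bias corrections for the zero initialisation
        v_hat = v / (1 - beta2**t)
        x -= learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Minimise f(x) = x**2 (gradient 2x) from x = 1; each step moves by roughly learning_rate.
print(adam_minimize(lambda x: 2 * x, x0=1.0))
```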

forward(batch)

Forward pass of the module.

Parameters:

batch – The batch.

Returns:

The output of the model.
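forward receives a whole batch, but the batch layout is not documented. The following sketch assumes the drug graphs are batched the way PyTorch Geometric's DataLoader does it. In that scheme, the graphs are merged into one disjoint graph: node features are concatenated, and each graph's edge indices are shifted by the number of nodes placed before it. A batch vector then records which graph each node came from. The sketch is in pure Python, and collate_graphs is a hypothetical helper.

```python
def collate_graphs(
    graphs: list[tuple[list[list[float]], list[list[int]]]],
) -> tuple[list[list[float]], list[list[int]], list[int]]:
    """Merge (x, edge_index) pairs into one disjoint graph plus a batch vector."""
    x: list[list[float]] = []
    sources: list[int] = []
    targets: list[int] = []
    batch: list[int] = []
    for graph_id, (node_features, (src, dst)) in enumerate(graphs):
        offset = len(x)  # nodes already placed by earlier graphs
        x.extend(node_features)
        sources.extend(i + offset for i in src)
        targets.extend(j + offset for j in dst)
        batch.extend([graph_id] * len(node_features))
    return x, [sources, targets], batch


# Two toy graphs: a 2-node graph with one bond and a 3-node chain.
g1 = ([[1.0], [2.0]], [[0, 1], [1, 0]])
g2 = ([[3.0], [4.0], [5.0]], [[0, 1, 1, 2], [1, 0, 2, 1]])
x, edge_index, batch = collate_graphs([g1, g2])
print(edge_index)  # [[0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3]]
print(batch)  # [0, 0, 1, 1, 1]
```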

predict_step(batch, batch_idx, dataloader_idx=0)

A single prediction step.

Parameters:
  • batch – The batch.

  • batch_idx – The batch index.

  • dataloader_idx – The dataloader index.

Returns:

The output of the model.

training_step(batch, batch_idx)

A single training step.

Parameters:
  • batch – The batch.

  • batch_idx – The batch index.

Returns:

The loss.
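The loss function is not specified. The module mixes in RegressionMetricsMixin, so a regression loss such as mean squared error is a plausible choice. The following sketch assumes MSE and shows, in plain Python, how a training step would reduce a batch of predictions and targets to a single value. mse_loss is a hypothetical helper. A Lightning implementation would typically call torch.nn.functional.mse_loss on tensors instead.

```python
from typing import Sequence


def mse_loss(predictions: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error over one batch."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("predictions and targets must be non-empty and equally long")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)


print(mse_loss([0.5, 1.0, 2.0], [0.0, 1.0, 1.0]))  # 1.25 / 3, about 0.4167
```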

validation_step(batch, batch_idx)

A single validation step.

Parameters:
  • batch – The batch.

  • batch_idx – The batch index.

class drevalpy.models.DrugGNN.drug_gnn.DrugGraphNet(num_node_features, num_cell_features, hidden_dim=64, dropout=0.2)

Bases: Module

Neural network for DrugGNN. It combines a drug graph with the cell line features to produce a drug response prediction.

forward(drug_graph, cell_features)

Forward pass of the network.

Parameters:
  • drug_graph – The drug graph.

  • cell_features – The cell line features.

Returns:

The output of the network.
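The layer types inside DrugGraphNet are not documented. The following is a minimal sketch of a common drug-response GNN pattern, under four assumptions about the network:

1. It updates each node by averaging it with its neighbours' features.
2. It mean-pools the node embeddings into one drug vector.
3. It concatenates that vector with the cell line features.
4. It maps the result to a scalar response with a linear head.

Plain lists replace tensors, learned weight matrices in the message-passing step are omitted, and all names are hypothetical.

```python
def message_pass(x: list[list[float]], edge_index: list[list[int]]) -> list[list[float]]:
    """One round of mean aggregation over each node and its incoming neighbours."""
    sums = [row[:] for row in x]  # start with the node's own features (self-loop)
    counts = [1] * len(x)
    for src, dst in zip(*edge_index):
        sums[dst] = [a + b for a, b in zip(sums[dst], x[src])]
        counts[dst] += 1
    return [[value / counts[i] for value in row] for i, row in enumerate(sums)]


def mean_pool(x: list[list[float]]) -> list[float]:
    """Average node embeddings into a single graph-level vector."""
    return [sum(column) / len(x) for column in zip(*x)]


def forward(
    x: list[list[float]],
    edge_index: list[list[int]],
    cell_features: list[float],
    head_weights: list[float],
    head_bias: float = 0.0,
) -> float:
    """Graph embedding + cell line features -> scalar drug response."""
    drug_embedding = mean_pool(message_pass(x, edge_index))
    combined = drug_embedding + cell_features  # list concatenation
    if len(head_weights) != len(combined):
        raise ValueError("head_weights must match the combined feature length")
    return sum(w * v for w, v in zip(head_weights, combined)) + head_bias


# Toy 3-node chain with one feature per node and two cell line features.
x = [[1.0], [2.0], [3.0]]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
print(forward(x, edge_index, cell_features=[0.5, -1.0], head_weights=[1.0, 2.0, 1.0]))  # 2.0
```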