
shahram-boshra/ne_class


Node/Edge-Level Molecular Graph Neural Network (MGNN) for Molecular Property Classification

This repository contains the code for a Molecular Graph Neural Network (MGNN) that classifies molecular properties at both the node (atom) level and the edge (bond) level. It uses RDKit for molecular processing and PyTorch Geometric for graph representation and learning.


Introduction

This project aims to provide a flexible framework for predicting molecular property classes with graph neural networks. It handles both atom-level and bond-level classification, which are common tasks in cheminformatics. The MGNN architecture combines Graph Convolutional Networks (GCNs) with linear layers to learn from molecular graphs.

Installation

  1. Clone the repository:

    git clone https://github.com/shahram-boshra/ne_class.git  # or: git clone git@github.com:shahram-boshra/ne_class.git
    cd ne_class
  2. Create a virtual environment (recommended):

    python -m venv venv
    source venv/bin/activate  # On Linux/macOS
    venv\Scripts\activate  # On Windows
  3. Install the required dependencies:

    pip install -r requirements.txt

    Note: Make sure PyTorch and PyTorch Geometric are installed correctly; depending on your system, they may need builds that match your CUDA version.

Usage

  1. Prepare your molecular data:

    • Ensure your molecular data is in a format that RDKit can read (e.g., SMILES or SDF).
    • Create CSV files for node targets and edge targets, indexed by molecule name plus atom index (node targets) or bond index (edge targets).
  2. Run the training script:

    python main.py --config config.yaml
    • Modify the config.yaml file to adjust hyperparameters, data paths, and other settings.
  3. Evaluate the model:

    • After training, the script will output performance metrics and generate plots showing training and validation progress.
    • You can also modify the main.py script to perform further evaluations or predictions on new data.
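
The target files from step 1 could look like the sketch below. It is a minimal example under assumptions: the column names (mol_name, atom_idx, bond_idx, label) and the helpers load_targets and targets_for are hypothetical, not the repository's actual schema, so check the data paths and loaders referenced in config.yaml. The only requirement assumed here is that the atom and bond indices match RDKit's own numbering (Atom.GetIdx() and Bond.GetIdx()) for each parsed molecule.

```python
import io

import pandas as pd

# Hypothetical node-target CSV: one row per atom, keyed by (molecule name, atom index).
NODE_CSV = """mol_name,atom_idx,label
ethanol,0,0
ethanol,1,0
ethanol,2,1
water,0,1
"""

# Hypothetical edge-target CSV: one row per bond, keyed by (molecule name, bond index).
EDGE_CSV = """mol_name,bond_idx,label
ethanol,0,0
ethanol,1,1
"""


def load_targets(csv_text: str, index_cols: list) -> pd.DataFrame:
    """Read a target CSV into a DataFrame with a sorted (molecule, index) MultiIndex."""
    return pd.read_csv(io.StringIO(csv_text), index_col=index_cols).sort_index()


def targets_for(df: pd.DataFrame, mol_name: str) -> list:
    """Return one molecule's labels ordered by atom/bond index."""
    # xs() selects the molecule on the first index level and drops that level.
    return df.xs(mol_name, level=0)["label"].sort_index().tolist()


node_targets = load_targets(NODE_CSV, ["mol_name", "atom_idx"])
edge_targets = load_targets(EDGE_CSV, ["mol_name", "bond_idx"])
print(targets_for(node_targets, "ethanol"))  # [0, 0, 1]
print(targets_for(edge_targets, "ethanol"))  # [0, 1]
```

Keying by (molecule, index) rather than by row position means the label order survives shuffling or filtering of the CSV, as long as RDKit assigns the same indices when the molecule is parsed.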

File Structure

    mgnn-molecular-prediction/
    ├── config.yaml          # Configuration file for training
    ├── data/                # Directory for storing molecular data
    ├── exceptions.py        # Custom exception classes
    ├── feature_encoder.py   # Class for encoding atom and bond features
    ├── graph_utils.py       # Utility functions for graph processing
    ├── main.py              # Main script for training and evaluation
    ├── metrics.py           # Class for calculating and plotting metrics
    ├── model.py             # Definition of the MGNN model
    ├── trainer.py           # Class for training and validation logic
    ├── earlystopping.py     # Class for early stopping
    ├── requirements.txt     # List of required Python packages
    └── README.md            # This file

Model Architecture

The MGModel architecture consists of:

  • GCN Blocks: Two GCN layers (GCNConv) for learning node representations.
  • Linear Blocks: Linear layers with batch normalization and dropout for processing node and edge features.
  • Output Layers: Linear layers for predicting node and edge properties.

The model takes node features, edge indices, and edge features as input and outputs predictions for node and edge targets.
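
Because chemical bonds are undirected while message passing runs over directed edges, a common convention (the one PyTorch Geometric uses) is to store each bond twice in a 2 × E edge index, once per direction, and to repeat its features to match. The plain-Python sketch below shows that conversion; the helper names bonds_to_edge_index and repeat_edge_features are hypothetical and do not necessarily correspond to anything in graph_utils.py. The bond tuples would typically come from RDKit's Bond.GetBeginAtomIdx() and Bond.GetEndAtomIdx().

```python
from typing import List, Sequence, Tuple


def bonds_to_edge_index(bonds: Sequence[Tuple[int, int]]) -> List[List[int]]:
    """Convert undirected bonds (begin_atom, end_atom) into a 2 x (2 * num_bonds)
    COO edge index that lists every bond in both directions."""
    src: List[int] = []
    dst: List[int] = []
    for begin, end in bonds:
        src += [begin, end]
        dst += [end, begin]
    return [src, dst]


def repeat_edge_features(bond_features: Sequence[Sequence[float]]) -> List[List[float]]:
    """Duplicate each bond's feature vector so row k lines up with column k of the edge index."""
    return [list(f) for f in bond_features for _ in range(2)]


# Heavy-atom graph of ethanol: C0-C1 and C1-O2, with a toy 1-dim bond feature.
edge_index = bonds_to_edge_index([(0, 1), (1, 2)])
edge_attr = repeat_edge_features([[0.0], [1.0]])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

With PyTorch Geometric these lists would then be converted with torch.tensor(edge_index, dtype=torch.long). For edge-level classification, this layout gives each bond two predictions (one per direction), so bond targets either need to be duplicated the same way or the two directional predictions combined.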

import torch.nn as nn

class MGModel(nn.Module):
    ...  # full model definition in model.py
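
To make the architecture above concrete, here is a minimal PyTorch sketch of a model with the same overall shape: two GCN layers, linear blocks with batch normalization and dropout, and separate node and edge output heads. It is not the repository's MGModel. Assumptions: to avoid a PyTorch Geometric dependency, the GCN layer is hand-written with the standard symmetric normalization D^-1/2 (A + I) D^-1/2 in place of GCNConv; edge representations are built by concatenating the two endpoint embeddings with the edge features; the names SimpleGCNLayer and MGModelSketch, the layer sizes, the ReLU activations, and the dropout rate are all illustrative.

```python
import torch
import torch.nn as nn


class SimpleGCNLayer(nn.Module):
    """Hand-written GCN layer, H' = D^-1/2 (A + I) D^-1/2 H W + b.
    Stand-in for torch_geometric.nn.GCNConv in this sketch."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        n = x.size(0)
        # Add self-loops so every node also aggregates its own features.
        loops = torch.arange(n, device=x.device)
        src = torch.cat([edge_index[0], loops])
        dst = torch.cat([edge_index[1], loops])
        # In-degree of every node including its self-loop (always >= 1).
        deg = torch.zeros(n, dtype=x.dtype, device=x.device)
        deg.index_add_(0, dst, torch.ones(dst.size(0), dtype=x.dtype, device=x.device))
        # Symmetric normalization 1 / sqrt(deg(src) * deg(dst)) for each edge.
        norm = deg[src].rsqrt() * deg[dst].rsqrt()
        h = self.lin(x)
        # Sum the normalized messages from source nodes into destination nodes.
        out = torch.zeros_like(h)
        out.index_add_(0, dst, h[src] * norm.unsqueeze(-1))
        return out + self.bias


class MGModelSketch(nn.Module):
    """Two GCN layers, then node and edge linear blocks, then class-logit heads."""

    def __init__(self, node_in: int, edge_in: int, hidden: int,
                 n_node_classes: int, n_edge_classes: int, dropout: float = 0.2):
        super().__init__()
        self.gcn1 = SimpleGCNLayer(node_in, hidden)
        self.gcn2 = SimpleGCNLayer(hidden, hidden)
        self.node_block = nn.Sequential(
            nn.Linear(hidden, hidden), nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout))
        # Edge input = source embedding + destination embedding + raw edge features.
        self.edge_block = nn.Sequential(
            nn.Linear(2 * hidden + edge_in, hidden), nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout))
        self.node_out = nn.Linear(hidden, n_node_classes)
        self.edge_out = nn.Linear(hidden, n_edge_classes)

    def forward(self, x, edge_index, edge_attr):
        h = torch.relu(self.gcn1(x, edge_index))
        h = torch.relu(self.gcn2(h, edge_index))
        node_logits = self.node_out(self.node_block(h))
        src, dst = edge_index
        e = torch.cat([h[src], h[dst], edge_attr], dim=-1)
        edge_logits = self.edge_out(self.edge_block(e))
        return node_logits, edge_logits


# Toy molecule: 5 atoms in a chain, 8 directed edges (4 bonds x 2 directions).
torch.manual_seed(0)
x = torch.randn(5, 8)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3]])
edge_attr = torch.randn(8, 4)
model = MGModelSketch(node_in=8, edge_in=4, hidden=16, n_node_classes=3, n_edge_classes=2)
model.eval()
node_logits, edge_logits = model(x, edge_index, edge_attr)
print(node_logits.shape, edge_logits.shape)  # torch.Size([5, 3]) torch.Size([8, 2])
```

Both heads return raw logits, which is the input nn.CrossEntropyLoss expects for node-level and edge-level classification losses.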

Metrics

The metrics.py file provides a Metrics class that calculates and plots the following performance metrics:

  • Mean Absolute Error (MAE)
  • Mean Squared Error (MSE)
  • R-squared (R2)
  • Pearson correlation coefficient

These metrics are used to evaluate the model's performance on the validation and test datasets.


class Metrics:
    ...  # metric calculation and plotting, defined in metrics.py
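
The four metrics listed above have standard definitions, sketched here with NumPy. The function name compute_metrics is hypothetical, and the repository's Metrics class may compute these values differently (for example via scikit-learn or SciPy). R2 is taken as 1 - SS_res / SS_tot, and Pearson's r comes from np.corrcoef; both are undefined when the targets are constant.

```python
import numpy as np


def compute_metrics(y_true, y_pred) -> dict:
    """MAE, MSE, R-squared and Pearson r for 1-D arrays of targets and predictions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    ss_res = np.sum(err ** 2)                       # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return {
        "mae": float(np.mean(np.abs(err))),
        "mse": float(np.mean(err ** 2)),
        "r2": float(1.0 - ss_res / ss_tot),         # undefined if y_true is constant
        "pearson": float(np.corrcoef(y_true, y_pred)[0, 1]),
    }


metrics = compute_metrics([1, 2, 3, 4], [1.5, 2, 2.5, 4])
print(metrics)  # mae 0.25, mse 0.125, r2 0.9, pearson ≈ 0.956
```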

Training

The trainer.py file contains the ModelTrainer class, which handles the training and validation logic:

  • Training epoch: the train_epoch method trains the model for one epoch.
  • Validation epoch: the validate_epoch method evaluates the model on the validation set.
  • Testing: the test method evaluates the model on the test set.
  • Early stopping: implemented in earlystopping.py to halt training when validation performance stops improving, preventing overfitting.


class ModelTrainer:
    ...  # training, validation, and test logic, defined in trainer.py
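
The train/validate cycle with early stopping follows a common pattern, sketched below under assumptions: early stopping monitors validation loss with a patience counter and an optional min_delta threshold; EarlyStopping and fit are hypothetical names (earlystopping.py may differ, for instance by also checkpointing the best weights); and train_epoch / validate_epoch are passed in as callables returning that epoch's loss, so the control flow can be shown without a real model.

```python
from typing import Callable, List, Optional


class EarlyStopping:
    """Hypothetical patience-based early stopping on validation loss."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss: Optional[float] = None
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True once training should stop."""
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss   # improvement: remember it and reset patience
            self.counter = 0
        else:
            self.counter += 1           # no improvement this epoch
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop


def fit(train_epoch: Callable[[], float], validate_epoch: Callable[[], float],
        max_epochs: int, patience: int) -> List[float]:
    """Alternate training and validation epochs until max_epochs or early stopping."""
    stopper = EarlyStopping(patience=patience)
    history: List[float] = []
    for _ in range(max_epochs):
        train_epoch()
        val_loss = validate_epoch()
        history.append(val_loss)
        if stopper.step(val_loss):
            break
    return history


# Toy run: the best validation loss (0.7) occurs at epoch 2 (0-based).
val_losses = iter([1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.69])
history = fit(lambda: 0.0, lambda: next(val_losses), max_epochs=7, patience=3)
print(len(history))  # 6: epochs 3, 4 and 5 fail to beat 0.7, so training stops
```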

Dependencies

  • torch
  • torch-geometric
  • rdkit-pypi
  • numpy
  • pandas
  • matplotlib
  • pyyaml

You can install all of these with pip install -r requirements.txt.

Contributing
Contributions are welcome! If you find any issues or have suggestions for improvements, please open an issue or submit a pull request.   

1.  Fork the repository.
2.  Create a new branch for your feature or bug fix.
3.  Make your changes and commit them.
4.  Submit a pull request.   

License
This project is licensed under the MIT License.