Visualize atom weight of AttentiveFP #DGL #RDKit #Chemoinformatics

Yesterday, I posted an example of DGL (almost same as original example code).

And I could make a regression model with my own dataset. Fortunately, the DGL developers provide code for visualizing the atom weights of a trained model.

It means that, after building a model with AttentiveFP, you can visualize the atom weights of a given molecule, which show how much each atom contributes to the target value.

I have seen many examples of this approach but had never tried it myself. So I tried it today.

Following code is same as yesterday.

%matplotlib inline 
import matplotlib.pyplot as plt
import os
from rdkit import Chem
from rdkit import RDPaths

import dgl
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
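from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dgl import model_zoo

from dgl.data.chem.utils import mol_to_complete_graph, mol_to_bigraph

from dgl.data.chem.utils import atom_type_one_hot
from dgl.data.chem.utils import atom_degree_one_hot
from dgl.data.chem.utils import atom_formal_charge
from dgl.data.chem.utils import atom_num_radical_electrons
from dgl.data.chem.utils import atom_hybridization_one_hot
from dgl.data.chem.utils import atom_total_num_H_one_hot
from dgl.data.chem.utils import one_hot_encoding
from dgl.data.chem import CanonicalAtomFeaturizer
from dgl.data.chem import CanonicalBondFeaturizer
from dgl.data.chem import ConcatFeaturizer
from dgl.data.chem import BaseAtomFeaturizer
from dgl.data.chem import BaseBondFeaturizer

from dgl.data.chem import one_hot_encoding
from dgl.data.utils import split_dataset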

from functools import partial
from sklearn.metrics import roc_auc_score
def chirality(atom):
    """Encode the chirality of an atom as a list of three flags.

    Parameters
    ----------
    atom : object exposing ``GetProp`` and ``HasProp`` (an RDKit atom)
        The atom to featurize.

    Returns
    -------
    list
        ``[is_R, is_S, chirality_possible]``. The first two entries are the
        one-hot encoding of the ``_CIPCode`` property over ``['R', 'S']``;
        both are ``False`` when the atom carries no ``_CIPCode`` property.
        The last entry is the result of ``HasProp('_ChiralityPossible')``.
    """
    chirality_possible = [atom.HasProp('_ChiralityPossible')]
    try:
        cip_code = atom.GetProp('_CIPCode')
    except KeyError:
        # GetProp raises KeyError when the property is absent, i.e. the atom
        # has no assigned CIP label.
        return [False, False] + chirality_possible
    return one_hot_encoding(cip_code, ['R', 'S']) + chirality_possible
def collate_molgraphs(data):
    """Batch a list of datapoints for a dataloader.

    Parameters
    ----------
    data : list of 3-tuples or 4-tuples
        Each tuple is for a single datapoint, consisting of
        a SMILES, a DGLGraph, all-task labels and optionally
        a binary mask indicating the existence of labels.

    Returns
    -------
    smiles : list
        List of smiles
    bg : BatchedDGLGraph
        Batched DGLGraphs
    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    masks : Tensor of dtype float32 and shape (B, T)
        Batched datapoint binary mask, indicating the
        existence of labels. If binary masks are not
        provided, return a tensor with ones.
    """
    assert len(data[0]) in [3, 4], \
        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
    if len(data[0]) == 3:
        smiles, graphs, labels = map(list, zip(*data))
        masks = None
    else:
        smiles, graphs, labels, masks = map(list, zip(*data))

    bg = dgl.batch(graphs)
    labels = torch.stack(labels, dim=0)
    if masks is None:
        # No masks supplied: treat every label as present.
        masks = torch.ones(labels.shape)
    else:
        masks = torch.stack(masks, dim=0)
    return smiles, bg, labels, masks

# Atom featurizer: concatenates the per-atom encoders below and stores the
# result under the 'hv' key.
# NOTE(review): the model is later built with node_feat_size=39, so the
# concatenated vector must have length 39; that only adds up when the element
# one-hot also reserves an "unknown" slot (encode_unknown=True) -- confirm
# against the installed DGL version.
atom_featurizer = BaseAtomFeaturizer(
    {'hv': ConcatFeaturizer([
        partial(atom_type_one_hot,
                allowable_set=[
                    'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As',
                    'Se', 'Br', 'Te', 'I', 'At'],
                encode_unknown=True),
        partial(atom_degree_one_hot, allowable_set=list(range(6))),
        atom_formal_charge,
        atom_num_radical_electrons,
        partial(atom_hybridization_one_hot, encode_unknown=True),
        lambda atom: [0],  # A placeholder for aromatic information
        atom_total_num_H_one_hot,
        chirality,
    ])}
)

# Bond featurizer: every bond gets a constant 10-dimensional zero vector,
# stored under the 'he' key.
bond_featurizer = BaseBondFeaturizer({
    'he': lambda bond: [0 for _ in range(10)],
})

# Solubility data set shipped with the RDKit documentation.
train = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.train.sdf')
test = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.test.sdf')

# Read the molecules once and skip records RDKit could not parse (the
# supplier yields None for those).
train_mols = [mol for mol in Chem.SDMolSupplier(train) if mol is not None]
test_mols = [mol for mol in Chem.SDMolSupplier(test) if mol is not None]

train_smi = [Chem.MolToSmiles(mol) for mol in train_mols]
test_smi = [Chem.MolToSmiles(mol) for mol in test_mols]

# Regression target as a (num_molecules, 1) float tensor.
# NOTE(review): assumes the value is stored in the 'SOL' SD property -- confirm
# against the SDF files.
train_sol = torch.tensor(
    [float(mol.GetProp('SOL')) for mol in train_mols]).reshape(-1, 1)
test_sol = torch.tensor(
    [float(mol.GetProp('SOL')) for mol in test_mols]).reshape(-1, 1)

# One graph per molecule, featurized with the atom/bond featurizers above.
train_graph = [mol_to_bigraph(mol,
                              atom_featurizer=atom_featurizer,
                              bond_featurizer=bond_featurizer) for mol in train_mols]

test_graph = [mol_to_bigraph(mol,
                             atom_featurizer=atom_featurizer,
                             bond_featurizer=bond_featurizer) for mol in test_mols]

def run_a_train_epoch(n_epochs, epoch, model, data_loader, loss_criterion, optimizer):
    """Train ``model`` for one pass over ``data_loader``.

    Parameters
    ----------
    n_epochs : int
        Total number of epochs; only used in the progress message.
    epoch : int
        Zero-based index of the current epoch; only used in the progress message.
    model : callable
        Called as ``model(bg, bg.ndata['hv'], bg.edata['he'])``; must also
        provide ``train()``.
    data_loader : iterable
        Yields ``(smiles, bg, labels, masks)`` batches.
    loss_criterion : callable
        Element-wise loss ``loss_criterion(prediction, labels)``; entries whose
        mask is zero are multiplied by zero before averaging.
    optimizer : optimizer
        Provides ``zero_grad()`` and ``step()``.

    Returns
    -------
    float
        Mean of the per-batch losses of this epoch.
    """
    model.train()
    losses = []
    for smiles, bg, labels, masks in data_loader:
        if torch.cuda.is_available():
            # Move the graph features and the targets to the GPU.
            bg.to(torch.device('cuda:0'))
            labels = labels.to('cuda:0')
            masks = masks.to('cuda:0')

        prediction = model(bg, bg.ndata['hv'], bg.edata['he'])
        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    total_score = np.mean(losses)
    print('epoch {:d}/{:d}, training {:.4f}'.format(epoch + 1, n_epochs, total_score))
    return total_score

# AttentiveFP regressor. node_feat_size / edge_feat_size must match the
# lengths of the 'hv' and 'he' features produced by the featurizers.
# NOTE(review): graph_feat_size and dropout follow the DGL AttentiveFP example
# configuration -- confirm or tune for your data.
model = model_zoo.chem.AttentiveFP(node_feat_size=39,
                                   edge_feat_size=10,
                                   num_layers=2,
                                   num_timesteps=2,
                                   graph_feat_size=200,
                                   output_size=1,
                                   dropout=0.2)
# Only move to the GPU when one is present, matching the guards used by the
# training and drawing code.
if torch.cuda.is_available():
    model = model.to('cuda:0')

# Each dataset item is a (smiles, graph, label) triple; collate_molgraphs
# batches the graphs and stacks the labels.
train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)), batch_size=128, collate_fn=collate_molgraphs)
test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)), batch_size=128, collate_fn=collate_molgraphs)

# reduction='none' keeps the per-element loss so it can be masked inside
# run_a_train_epoch before averaging.
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0),)
n_epochs = 100
epochs = []
scores = []
for e in range(n_epochs):
    score = run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn, optimizer)
    # Keep the learning curve so it can be inspected or plotted afterwards.
    epochs.append(e)
    scores.append(score)

OK, I built the predictive model. (Of course, the built model can be saved with the torch.save(model.state_dict(), PATH) method.)

Let’s visualize molecule with atom weights!

At first, import packages for molecule visualization.

import copy
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
from IPython.display import display
import matplotlib
import matplotlib.cm as cm

Then define the visualization function. The following code is borrowed from the original repository, thanks a lot. The DGL model has a get_node_weight option, which returns the node weights of the graph. The model has two layers of GRU, so timestep must be 0 or 1; in the following code I used 0 as the timestep.

def drawmol(idx, dataset, timestep):
    """Draw one molecule with its atoms coloured by the model's atom weights.

    Uses the module-level ``model``.

    Parameters
    ----------
    idx : int
        Index of the datapoint in ``dataset``.
    dataset : sequence
        Items are ``(smiles, graph, label)`` triples.
    timestep : int
        Which readout round's atom weights to show; must be smaller than the
        number of weight tensors returned by the model.

    Returns
    -------
    tuple
        ``(mol, atom_weights, svg)``: a fresh RDKit molecule built from the
        SMILES, the min-max scaled atom weights as a NumPy array, and the SVG
        text of the highlighted drawing.
    """
    # Imported locally so the function does not rely on module-level aliases.
    from matplotlib import cm, colors

    smiles, graph, _ = dataset[idx]
    bg = dgl.batch([graph])
    atom_feats, bond_feats = bg.ndata['hv'], bg.edata['he']
    if torch.cuda.is_available():
        print('use cuda')
        bg.to(torch.device('cuda:0'))
        atom_feats = atom_feats.to('cuda:0')
        bond_feats = bond_feats.to('cuda:0')

    # Run in evaluation mode without gradients so that train-only behaviour
    # (e.g. dropout) does not perturb the weights; restore the previous mode
    # afterwards so later training is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)
    finally:
        model.train(was_training)

    assert timestep < len(atom_weights), 'Unexpected id for the readout round'
    atom_weights = atom_weights[timestep]

    # Min-max scale to [0, 1]; when all weights are equal the range is zero,
    # so fall back to zeros instead of dividing by zero.
    min_value = torch.min(atom_weights)
    value_range = torch.max(atom_weights) - min_value
    if value_range > 0:
        atom_weights = (atom_weights - min_value) / value_range
    else:
        atom_weights = torch.zeros_like(atom_weights)

    # vmax above 1 keeps the largest weights away from the darkest end of the
    # colormap.
    norm = colors.Normalize(vmin=0, vmax=1.28)
    plt_colors = cm.ScalarMappable(norm=norm, cmap='bwr')
    num_atoms = bg.number_of_nodes()
    atom_colors = {i: plt_colors.to_rgba(atom_weights[i].item())
                   for i in range(num_atoms)}

    mol = Chem.MolFromSmiles(smiles)
    rdDepictor.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(280, 280)
    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    drawer.DrawMolecule(mol,
                        highlightAtoms=list(range(num_atoms)),
                        highlightBonds=[],
                        highlightAtomColors=atom_colors)
    # The SVG text is only complete after FinishDrawing().
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()
    svg = svg.replace('svg:', '')

    return (Chem.MolFromSmiles(smiles), atom_weights.detach().cpu().numpy(), svg)

Draw the test dataset molecules. The model predicts solubility, and the color indicates the effect on it: red is a positive effect on solubility and blue is a negative impact.

# Render every molecule of the test set, coloured by the atom weights of the
# first readout round (timestep 0).
target = test_loader.dataset
for i in range(len(target)):
    mol, aw, svg = drawmol(i, target, 0)
    # Show the drawing inline in the notebook.
    display(SVG(svg))

Personally, I think a hydroxyl group has a positive effect on solubility, but the model shows that this is not always true. Hmm, is something wrong with my code? Or do I need to think more about the details of the model?
I would like to try more predictive tasks and write helper code for DGL's AttentiveFP to make molecular visualization and model building more convenient.

Today’s whole code is uploaded below.

Any comments or suggestions will be highly appreciated. ;)