Visualize atom weight of AttentiveFP #DGL #RDKit #Chemoinformatics

Yesterday, I posted an example of DGL (almost the same as the original example code).

I was also able to build a regression model with my own dataset. Fortunately, the DGL developers provide code for visualizing the atom weights of a trained model.

This means that, after building a model with AttentiveFP, you can visualize the atom weights of a given molecule, which show how much each atom contributes to the target value.

I have seen many examples of this approach but had never tried it myself, so I tried it today.

The following code is the same as yesterday's.

%matplotlib inline 
import matplotlib.pyplot as plt
import os
from rdkit import Chem
from rdkit import RDPaths

import dgl
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dgl import model_zoo

from dgl.data.chem.utils import mol_to_complete_graph, mol_to_bigraph

from dgl.data.chem.utils import atom_type_one_hot
from dgl.data.chem.utils import atom_degree_one_hot
from dgl.data.chem.utils import atom_formal_charge
from dgl.data.chem.utils import atom_num_radical_electrons
from dgl.data.chem.utils import atom_hybridization_one_hot
from dgl.data.chem.utils import atom_total_num_H_one_hot
from dgl.data.chem.utils import one_hot_encoding
from dgl.data.chem import CanonicalAtomFeaturizer
from dgl.data.chem import CanonicalBondFeaturizer
from dgl.data.chem import ConcatFeaturizer
from dgl.data.chem import BaseAtomFeaturizer
from dgl.data.chem import BaseBondFeaturizer

from dgl.data.chem import one_hot_encoding
from dgl.data.utils import split_dataset

from functools import partial
from sklearn.metrics import roc_auc_score
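
def chirality(atom):
    # Encode the CIP label (R/S) plus a flag for potential stereocenters.
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except KeyError:
        # Atoms without an assigned CIP code have no '_CIPCode' property.
        return [False, False] + [atom.HasProp('_ChiralityPossible')]
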
def collate_molgraphs(data):
    """Batching a list of datapoints for dataloader.
    Parameters
    ----------
    data : list of 3-tuples or 4-tuples.
        Each tuple is for a single datapoint, consisting of
        a SMILES, a DGLGraph, all-task labels and optionally
        a binary mask indicating the existence of labels.
    Returns
    -------
    smiles : list
        List of smiles
    bg : BatchedDGLGraph
        Batched DGLGraphs
    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    masks : Tensor of dtype float32 and shape (B, T)
        Batched datapoint binary mask, indicating the
        existence of labels. If binary masks are not
        provided, return a tensor with ones.
    """
    assert len(data[0]) in [3, 4], \
        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
    if len(data[0]) == 3:
        smiles, graphs, labels = map(list, zip(*data))
        masks = None
    else:
        smiles, graphs, labels, masks = map(list, zip(*data))

    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)
    
    if masks is None:
        masks = torch.ones(labels.shape)
    else:
        masks = torch.stack(masks, dim=0)
    return smiles, bg, labels, masks

atom_featurizer = BaseAtomFeaturizer(
                 {'hv': ConcatFeaturizer([
                  partial(atom_type_one_hot, allowable_set=[
                          'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
                    encode_unknown=True),
                  partial(atom_degree_one_hot, allowable_set=list(range(6))),
                  atom_formal_charge, atom_num_radical_electrons,
                  partial(atom_hybridization_one_hot, encode_unknown=True),
                  lambda atom: [0], # A placeholder for aromatic information,
                    atom_total_num_H_one_hot, chirality
                 ],
                )})
bond_featurizer = BaseBondFeaturizer({
                                     'he': lambda bond: [0 for _ in range(10)]
    })
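As a sanity check on the featurizer above, the length of the 'hv' vector can be tallied by hand. This is only a sketch: the sizes for hybridization and total hydrogen count assume DGL's default allowable sets (5 hybridization types and 0-4 hydrogens), which the post does not state.

```python
# Sizes of each block concatenated into the 'hv' atom feature vector.
# The two sizes marked "default" rely on assumed DGL defaults and are not
# stated in the post.
feature_blocks = {
    "atom type (15 elements + unknown)": 15 + 1,
    "degree (0-5)": 6,
    "formal charge": 1,
    "radical electrons": 1,
    "hybridization (default: 5 types, + unknown)": 5 + 1,
    "aromatic placeholder": 1,
    "total H count (default: 0-4)": 5,
    "chirality (R, S, possible)": 3,
}

node_feat_size = sum(feature_blocks.values())
edge_feat_size = len([0 for _ in range(10)])  # placeholder bond features

print(node_feat_size, edge_feat_size)
```

Under these assumptions the totals agree with the node_feat_size=39 and edge_feat_size=10 passed to the model.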

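train=os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.train.sdf')
test=os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.test.sdf')

# The loading step is missing from this listing. The lines below are a reconstruction
# (assumption: the target value is stored in the 'SOL' property of each SDF record).
train_mols = [m for m in Chem.SDMolSupplier(train) if m is not None]
train_smi = [Chem.MolToSmiles(m) for m in train_mols]
train_sol = torch.tensor([float(m.GetProp('SOL')) for m in train_mols]).reshape(-1, 1)

test_mols = [m for m in Chem.SDMolSupplier(test) if m is not None]
test_smi = [Chem.MolToSmiles(m) for m in test_mols]
test_sol = torch.tensor([float(m.GetProp('SOL')) for m in test_mols]).reshape(-1, 1)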

train_graph =[mol_to_bigraph(mol,
                           atom_featurizer=atom_featurizer, 
                           bond_featurizer=bond_featurizer) for mol in train_mols]

test_graph =[mol_to_bigraph(mol,
                           atom_featurizer=atom_featurizer, 
                           bond_featurizer=bond_featurizer) for mol in test_mols]

def run_a_train_epoch(n_epochs, epoch, model, data_loader,loss_criterion, optimizer):
    model.train()
    total_loss = 0
    losses = []
    
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        if torch.cuda.is_available():
            bg.to(torch.device('cuda:0'))
            labels = labels.to('cuda:0')
            masks = masks.to('cuda:0')
        
        prediction = model(bg, bg.ndata['hv'], bg.edata['he'])
        loss = (loss_criterion(prediction, labels)*(masks != 0).float()).mean()
        #loss = loss_criterion(prediction, labels)
        #print(loss.shape)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.data.item())
        
    #total_score = np.mean(train_meter.compute_metric('rmse'))
    total_score = np.mean(losses)
    print('epoch {:d}/{:d}, training {:.4f}'.format( epoch + 1, n_epochs,  total_score))
    return total_score
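The masked loss used in the training loop can be checked on a toy batch. The sketch below uses NumPy as a stand-in for nn.MSELoss(reduction='none'), and all numbers are made up for illustration.

```python
import numpy as np

# Toy batch: 3 molecules, 2 tasks. A mask entry of 0 marks a missing label.
prediction = np.array([[1.0, 2.0], [0.5, 0.0], [3.0, 1.0]])
labels = np.array([[1.5, 0.0], [0.5, 1.0], [2.0, 0.0]])
masks = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 0.0]])

# NumPy stand-in for nn.MSELoss(reduction='none'): one squared error per entry.
elementwise = (prediction - labels) ** 2

# Same form as the training loop: zero out masked entries, then average
# over ALL entries (masked ones still count in the denominator).
loss_as_in_post = (elementwise * (masks != 0)).mean()

# Alternative (not what the post does): average over labelled entries only.
loss_labelled_only = (elementwise * masks).sum() / masks.sum()

print(loss_as_in_post, loss_labelled_only)
```

In this post every molecule has a label, so collate_molgraphs returns a mask of ones and the two variants give the same value.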

model = model_zoo.chem.AttentiveFP(node_feat_size=39,
                                  edge_feat_size=10,
                                  num_layers=2,
                                  num_timesteps=2,
                                  graph_feat_size=200,
                                  output_size=1,
                                  dropout=0.2)
if torch.cuda.is_available():
    model = model.to('cuda:0')

train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)), batch_size=128, collate_fn=collate_molgraphs)
test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)), batch_size=128, collate_fn=collate_molgraphs)

loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0),)
n_epochs = 100
epochs = []
scores = []
for e in range(n_epochs):
    score = run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn, optimizer)
    epochs.append(e)
    scores.append(score)
model.eval()

OK, I built the predictive model. (Of course, the trained model can be saved with torch.save(model.state_dict(), PATH).)
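A minimal sketch of the save and restore round trip follows. It uses a small nn.Linear as a stand-in for the trained AttentiveFP model, and the file name is hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model; in the post this would be the trained AttentiveFP instance.
model = nn.Linear(4, 1)

# Hypothetical file name, written to a temporary directory for this sketch.
path = os.path.join(tempfile.mkdtemp(), "attentivefp_demo.pt")
torch.save(model.state_dict(), path)

# To restore, build a model with the same architecture, then load the weights.
restored = nn.Linear(4, 1)
restored.load_state_dict(torch.load(path))
restored.eval()  # disable dropout before prediction or visualization

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)
```

Only the parameters are stored, so the restoring side has to rebuild the model with the same constructor arguments before calling load_state_dict.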

Let’s visualize molecules with atom weights!

First, import the packages for molecule visualization.

import copy
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
from IPython.display import display
import matplotlib
import matplotlib.cm as cm

Then define the visualization function. The following code is borrowed from the original repository, thanks a lot. The DGL model has a get_node_weight option, which returns the node weights of the graph. The model was built with num_timesteps=2, so it has two readout rounds (each with a GRU) and timestep must be 0 or 1. In the following code I used 0 as the timestep.

def drawmol(idx, dataset, timestep):
    smiles, graph, _ = dataset[idx]
    print(smiles)
    bg = dgl.batch([graph])
    atom_feats, bond_feats = bg.ndata['hv'], bg.edata['he']
    if torch.cuda.is_available():
        print('use cuda')
        bg.to(torch.device('cuda:0'))
        atom_feats = atom_feats.to('cuda:0')
        bond_feats = bond_feats.to('cuda:0')
    
    _, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)
    assert timestep < len(atom_weights), 'Unexpected id for the readout round'
    atom_weights = atom_weights[timestep]
    min_value = torch.min(atom_weights)
    max_value = torch.max(atom_weights)
    atom_weights = (atom_weights - min_value) / (max_value - min_value)
    
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1.28)
    cmap = cm.get_cmap('bwr')
    plt_colors = cm.ScalarMappable(norm=norm, cmap=cmap)
    atom_colors = {i: plt_colors.to_rgba(atom_weights[i].data.item()) for i in range(bg.number_of_nodes())}

    mol = Chem.MolFromSmiles(smiles)
    rdDepictor.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(280, 280)
    drawer.SetFontSize(1)
    op = drawer.drawOptions()
    
    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    drawer.DrawMolecule(mol, highlightAtoms=range(bg.number_of_nodes()),
                             highlightBonds=[],
                             highlightAtomColors=atom_colors)
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()
    svg = svg.replace('svg:', '')
    if torch.cuda.is_available():
        atom_weights = atom_weights.to('cpu')
    return (Chem.MolFromSmiles(smiles), atom_weights.data.numpy(), svg)
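To see what the scaling and color mapping inside drawmol do, the same two steps can be run on their own. This is a sketch with made-up weights, not model output. It reuses the post's Normalize(vmin=0, vmax=1.28) and the 'bwr' colormap.

```python
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors

# Made-up raw weights for a 5-atom molecule (not output of the model).
raw = np.array([0.05, 0.10, 0.40, 0.25, 0.20])

# Min-max scaling as in drawmol: smallest weight -> 0, largest -> 1.
scaled = (raw - raw.min()) / (raw.max() - raw.min())

# Same normalization and colormap as drawmol.
mappable = cm.ScalarMappable(norm=mcolors.Normalize(vmin=0, vmax=1.28), cmap="bwr")
rgba = mappable.to_rgba(scaled)  # one (r, g, b, a) row per atom

for w, s, color in zip(raw, scaled, rgba):
    print(f"raw={w:.2f} scaled={s:.2f} rgb={np.round(color[:3], 2)}")

# Color at the midpoint of the normalized range (0.64): close to white.
midpoint_rgba = mappable.to_rgba(np.array([0.64]))[0]
```

Because the scaling is done per molecule, every molecule gets one fully blue atom (its smallest weight), and its largest weight lands at 1.0 / 1.28 of the color range, which is a light red rather than full red. The colors are therefore relative within a molecule and cannot be compared directly across molecules.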

Draw the test dataset molecules. The model predicts solubility, and the colors indicate the effect on solubility: red is a positive effect and blue is a negative one.

target = test_loader.dataset
for i in range(len(target)):
    mol, aw, svg = drawmol(i, target, 0)
    display(SVG(svg))

Personally, I think a hydroxyl group has a positive effect on solubility, but the model shows that this is not always true. Hmm, is something wrong with my code? Or do I need to think more about the details of the model?
I would like to try more prediction tasks and write helper code for DGL’s AttentiveFP to make molecular visualization and model building more convenient.

Today’s whole code is uploaded below.
https://gist.github.com/iwatobipen/72a2d9dd616322f1f20469a152f2bb58

Any comments or suggestions will be highly appreciated. ;)

12 thoughts on “Visualize atom weight of AttentiveFP #DGL #RDKit #Chemoinformatics”

  1. Great job, as always. For counter-intuitive attention weights, there are two possible explanations:
    1. After all, the models are playing with statistics and there can be a wrong correlation when the dataset cannot fully represent the underlying distribution of compounds in the chemical space.
    2. With mechanism like graph convolution, back propagation, etc., the node features not only encode the particular atoms, but also a local context.

    • Thanks for useful comment.
      If the dominant reason is 1., a much larger training dataset would help, I think. I used a very small amount of data in the post.
      Considering explanation 2., the model learned its own concept of functional groups. That seems interesting to me. I’d like to try the approach on other datasets.

  2. That’s a great post! As for your discussion, I have a question about how to do scaling of atomic weights. In your post, the raw weights were converted into the scaled ones in the same way as MinMaxScaler of sklearn. But if the raw atomic weight can take a negative value, the sign may be reversed. The sign (positive or negative; red or blue) is supposed to be important for the intuitive visualization. In this regard, I guess it could be better to do the scaling just with dividing each value by the maximum absolute weights. ( [w / max(abs(weights)) for w in weights]), and it may alleviate the counter-intuition a little in some cases.

    • Hi, thank you for your nice suggestion!
      I agree with your opinion, and I checked the source code. The weights returned by AttentiveFP are calculated with a softmax function (if I understood DGL correctly).
      So all returned values will be positive.
      This means that the current scaling process will not change the sign.
      However, I’ll try your proposal later. Now I’m writing this comment from my mobile phone…
      Thanks

  3. “The model has two layers of GRU so timestep must be 0 or 1 following code I used 0 as timestep.”?

    is “0 or 1” the GRU’s state? what will happen if you use 1 as timestep?

  4. Hi, iwatobipen-san, and Mufei-san,

    Thanks for your quick and valuable comments! I didn’t look at the implementation, but I got it that it’s by sigmoid/softmax to give 0-1 scores.

    I am still confused about how to distinguish “positive” and “negative” for each weight. This is not only the case for this example, but also for others in which explainability methods are applied to QSAR models. I guess the following may somehow address the counter-intuition (the data size and/or the algorithm itself also matter, as you discussed).

    If the raw weights are calculated by sigmoid/softmax, how about scaling with 0.5, the middle point of such functions, as the center: like [ (w - 0.5) / max(abs(w - 0.5)) for w in weights]?
    Or,
    How about using a Sequential colormap rather than a Diverging colormap (see https://matplotlib.org/tutorials/colors/colormaps.html), as in the original implementation of the DGL example (cmap was set by “cm.get_cmap(‘Oranges’)” there)? That way, only how large or small each weight is would be considered.

    Thanks in advance!

    • Thanks for your suggestion.
      I modified my code to print min-max value of each molecule in visualization process.
      URL is below.

      In the case of AttentiveFP, the min and max values are often smaller than 0.5, so scaling around 0.5 will not work.

      The middle point of the atom weights depends on each molecule. The magnitude of each atom's contribution to the molecular property differs from molecule to molecule.
      I’m not sure that 0.5 is the middle point of the weights in this case.
      I’ll check other GCN visualization code.

      • Many thanks! Your reply clearly answered what I wanted to know! Based on these results, my first suggestion doesn’t seem suitable here. I also found that how to do the scaling is a problem that needs further consideration.

        It was a very useful discussion! Please continue to make great posts!
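The scaling options raised in the discussion above can be compared on made-up numbers. This is a minimal sketch, assuming the weights come from a softmax over the atoms of one molecule as suggested in the thread. The logits are invented and the softmax helper is a plain NumPy stand-in, not DGL code.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()


# Made-up attention logits for a 6-atom molecule.
logits = np.array([1.0, -1.0, 0.5, 0.0, 0.8, -0.5])
w = softmax(logits)  # all positive, sums to 1

# 1) Min-max scaling, as used in the post.
min_max = (w - w.min()) / (w.max() - w.min())

# 2) Division by the maximum absolute weight (first suggestion).
max_abs = w / np.max(np.abs(w))

# 3) Centering at 0.5 before scaling (second suggestion).
centered = (w - 0.5) / np.max(np.abs(w - 0.5))

for name, values in [("softmax", w), ("min-max", min_max),
                     ("max-abs", max_abs), ("centered", centered)]:
    print(f"{name:9s}", np.round(values, 3))
```

With these numbers every softmax weight is below 0.5, so the centered variant comes out negative for all atoms, which matches the observation in the reply above. Min-max and max-abs scaling both keep the ordering of the atoms and differ only in where the smallest weight lands.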
