Attentive FP with PyG #RDKit #PyG #pytorch_geometric #Chemoinformatics

As you know, PyG is one of the useful packages for graph-based neural networks, like DGL-LifeSci.

Fortunately, recent versions of PyG are easy to install because conda packages are provided, so users don't need to install the related packages such as pytorch-scatter, pytorch-cluster, etc. by hand.
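
For reference, a conda-based installation is roughly a one-liner these days (the exact command depends on your PyTorch/CUDA versions, so it's worth checking the official installation guide first):

conda install pyg -c pyg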

PyG also provides lots of predefined models, which are listed in the official documentation. AttentiveFP, a model for molecular representation learning, is one of them. I wrote a post about AttentiveFP with DGL before, so today I tried the PyG implementation of AttentiveFP.

An example of AttentiveFP is provided in the original repo. However, the example uses the torch_geometric.datasets.MoleculeNet class for data preparation, so the available datasets are limited to those from MoleculeNet. I would like to use a local dataset with the model.

To do that, I modified the original code and tried it; the code is shown below. In the following example I used the ESOL dataset, which comes from MoleculeNet, but the file was downloaded before running the code. I defined a Molecule class which loads a local CSV file and processes it. The difference from the original MoleculeNet class is that this class doesn't download data from the web; the data comes from a local file instead.

# most of the code below comes from the original PyG repo
# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/attentive_fp.py

import os.path as osp
from math import sqrt

import torch
import torch.nn.functional as F
from rdkit import Chem

from torch_geometric.loader import DataLoader
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn.models import AttentiveFP

class GenFeatures(object):
    def __init__(self):
        self.symbols = [
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',
            'Te', 'I', 'At', 'other'
        ]

        self.hybridizations = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            'other',
        ]

        self.stereos = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
        ]

    def __call__(self, data):
        # Generate AttentiveFP features according to Table 1.
        mol = Chem.MolFromSmiles(data.smiles)

        xs = []
        for atom in mol.GetAtoms():
            symbol = [0.] * len(self.symbols)
            symbol[self.symbols.index(atom.GetSymbol())] = 1.
            degree = [0.] * 6
            degree[atom.GetDegree()] = 1.
            formal_charge = atom.GetFormalCharge()
            radical_electrons = atom.GetNumRadicalElectrons()
            hybridization = [0.] * len(self.hybridizations)
            hybridization[self.hybridizations.index(
                atom.GetHybridization())] = 1.
            aromaticity = 1. if atom.GetIsAromatic() else 0.
            hydrogens = [0.] * 5
            hydrogens[atom.GetTotalNumHs()] = 1.
            chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.
            chirality_type = [0.] * 2
            if atom.HasProp('_CIPCode'):
                chirality_type[['R', 'S'].index(atom.GetProp('_CIPCode'))] = 1.

            x = torch.tensor(symbol + degree + [formal_charge] +
                             [radical_electrons] + hybridization +
                             [aromaticity] + hydrogens + [chirality] +
                             chirality_type)
            xs.append(x)

        data.x = torch.stack(xs, dim=0)

        edge_indices = []
        edge_attrs = []
        for bond in mol.GetBonds():
            edge_indices += [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]]
            edge_indices += [[bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]]

            bond_type = bond.GetBondType()
            single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
            double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
            triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
            aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
            conjugation = 1. if bond.GetIsConjugated() else 0.
            ring = 1. if bond.IsInRing() else 0.
            stereo = [0.] * 4
            stereo[self.stereos.index(bond.GetStereo())] = 1.

            edge_attr = torch.tensor(
                [single, double, triple, aromatic, conjugation, ring] + stereo)

            edge_attrs += [edge_attr, edge_attr]

        if len(edge_attrs) == 0:
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            data.edge_attr = torch.zeros((0, 10), dtype=torch.float)
        else:
            data.edge_index = torch.tensor(edge_indices).t().contiguous()
            data.edge_attr = torch.stack(edge_attrs, dim=0)

        return data


x_map = {
    'atomic_num':
    list(range(0, 119)),
    'chirality': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    'degree':
    list(range(0, 11)),
    'formal_charge':
    list(range(-5, 7)),
    'num_hs':
    list(range(0, 9)),
    'num_radical_electrons':
    list(range(0, 5)),
    'hybridization': [
        'UNSPECIFIED',
        'S',
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'OTHER',
    ],
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],
}

e_map = {
    'bond_type': [
        'misc',
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'AROMATIC',
    ],
    'stereo': [
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    'is_conjugated': [False, True],
}


import torch
from torch_geometric.data import (InMemoryDataset, Data)
import re

class Molecule(InMemoryDataset):
    r"""The `MoleculeNet <http://moleculenet.ai/datasets-1>`_ benchmark
    collection  from the `"MoleculeNet: A Benchmark for Molecular Machine
    Learning" <https://arxiv.org/abs/1703.00564>`_ paper, containing datasets
    from physical chemistry, biophysics and physiology.
    All datasets come with the additional node and edge features introduced by
    the `Open Graph Benchmark <https://ogb.stanford.edu/docs/graphprop/>`_.
    Args:
        root_dir (string): Root directory.
        name (string): The name of dataset (csv format)
        smi_idx (integer): index of smiles column
        target_idx (integer): index of target column
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """



    def __init__(self, root_dir, name, smi_idx, target_idx, transform=None, pre_transform=None,
                 pre_filter=None):
        self.root_dir = root_dir
        self.name = name
        self.smi_idx = smi_idx
        self.target_idx = target_idx
        # root is passed as None because raw_dir / processed_dir are overridden below
        super(Molecule, self).__init__(None, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root_dir, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root_dir,'processed')

    @property
    def raw_file_names(self):
        return f'{self.name}'

    @property
    def processed_file_names(self):
        return 'data.pt'


    def process(self):
        from rdkit import Chem
        with open(self.raw_file_names, 'r') as f:
            dataset = f.read().split('\n')[1:-1]
            dataset = [x for x in dataset if len(x) > 0]  # Filter empty lines.

        data_list = []
        for line in dataset:
            line = re.sub(r'\".*\"', '', line)  # Replace ".*" strings.
            line = line.split(',')

            smiles = line[self.smi_idx]
            ys = line[self.target_idx]
            ys = ys if isinstance(ys, list) else [ys]

            ys = [float(y) if len(y) > 0 else float('NaN') for y in ys]
            y = torch.tensor(ys, dtype=torch.float).view(1, -1)

            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            xs = []
            for atom in mol.GetAtoms():
                x = []
                x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
                x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
                x.append(x_map['degree'].index(atom.GetTotalDegree()))
                x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
                x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
                x.append(x_map['num_radical_electrons'].index(
                    atom.GetNumRadicalElectrons()))
                x.append(x_map['hybridization'].index(
                    str(atom.GetHybridization())))
                x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
                x.append(x_map['is_in_ring'].index(atom.IsInRing()))
                xs.append(x)

            x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

            edge_indices, edge_attrs = [], []
            for bond in mol.GetBonds():
                i = bond.GetBeginAtomIdx()
                j = bond.GetEndAtomIdx()

                e = []
                e.append(e_map['bond_type'].index(str(bond.GetBondType())))
                e.append(e_map['stereo'].index(str(bond.GetStereo())))
                e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

                edge_indices += [[i, j], [j, i]]
                edge_attrs += [e, e]

            edge_index = torch.tensor(edge_indices)
            edge_index = edge_index.t().to(torch.long).view(2, -1)
            edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

            # Sort indices.
            if edge_index.numel() > 0:
                perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
                edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                        smiles=smiles)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])

    def __repr__(self):
        return '{}({})'.format(self.name, len(self))
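
By the way, it's worth checking the feature dimensions that GenFeatures produces, because they determine the in_channels and edge_dim arguments of AttentiveFP used below. For example, with a toy molecule such as 'CCO':

# quick check of the feature sizes generated by GenFeatures ('CCO' is just a toy example)
sample = GenFeatures()(Data(smiles='CCO'))
print(sample.x.shape)          # torch.Size([3, 39]) -> 3 heavy atoms, 39 node features
print(sample.edge_attr.shape)  # torch.Size([4, 10]) -> 2 bonds x 2 directions, 10 edge features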


We don't need to define AttentiveFP ourselves because PyG already provides the model; it's easy to use, we just import it ;)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttentiveFP(in_channels=39, hidden_channels=200, out_channels=1,
                    edge_dim=10, num_layers=2, num_timesteps=2,
                    dropout=0.2).to(device)
print(model)
>
AttentiveFP(
  (lin1): Linear(in_features=39, out_features=200, bias=True)
  (atom_convs): ModuleList(
    (0): GATEConv(
      (lin1): Linear(in_features=210, out_features=200, bias=False)
      (lin2): Linear(in_features=200, out_features=200, bias=False)
    )
    (1): GATConv(200, 200, heads=1)
  )
  (atom_grus): ModuleList(
    (0): GRUCell(200, 200)
    (1): GRUCell(200, 200)
  )
  (mol_conv): GATConv(200, 200, heads=1)
  (mol_gru): GRUCell(200, 200)
  (lin2): Linear(in_features=200, out_features=1, bias=True)
)

Then load the data and split it into train/val/test sets.

dataset = Molecule(root_dir='/home/iwatobipen/dev/data/AFP_Mol/esol/testf',
                  name='/home/iwatobipen/dev/data/AFP_Mol/esol/testf/delaney-processed.csv',
                  smi_idx=-1,
                  target_idx=-2,
                  pre_transform=GenFeatures()).shuffle()

N = len(dataset) // 10
val_dataset = dataset[:N]
test_dataset = dataset[N:2 * N]
train_dataset = dataset[2 * N:]

train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=200)
test_loader = DataLoader(test_dataset, batch_size=200)
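
Each mini-batch yielded by these loaders is one big disconnected graph, and the extra batch vector maps every atom back to its molecule; this vector is what gets passed as data.batch in the training loop below. A quick way to see what a batch looks like:

batch = next(iter(train_loader))
print(batch.num_graphs)   # number of molecules in the mini-batch (up to 200)
print(batch.x.shape)      # [total number of atoms in the batch, 39]
print(batch.batch[:10])   # per-atom molecule index, e.g. tensor([0, 0, 0, 0, 1, 1, ...])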

We're almost there; now let's train the model.

optimizer = torch.optim.Adam(model.parameters(), lr=10**-2.5,
                             weight_decay=10**-5)

def train():
    model.train()
    total_loss = total_examples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
        total_examples += data.num_graphs
    return sqrt(total_loss / total_examples)

@torch.no_grad()
def test(loader):
    model.eval()
    mse = []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        mse.append(F.mse_loss(out, data.y, reduction='none').cpu())
    return float(torch.cat(mse, dim=0).mean().sqrt())

for epoch in range(1, 20):
    train_rmse = train()
    val_rmse = test(val_loader)
    test_rmse = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
          f'Test: {test_rmse:.4f}')
>
Epoch: 001, Loss: 3.4306 Val: 2.7534 Test: 2.4873
Epoch: 002, Loss: 2.4330 Val: 2.2135 Test: 1.9825
Epoch: 003, Loss: 1.7889 Val: 2.0010 Test: 2.0716
Epoch: 004, Loss: 1.7881 Val: 1.8208 Test: 1.8639
Epoch: 005, Loss: 1.7114 Val: 1.7611 Test: 1.7465
Epoch: 006, Loss: 1.6452 Val: 1.7461 Test: 1.7135
Epoch: 007, Loss: 1.5740 Val: 1.6480 Test: 1.6811
Epoch: 008, Loss: 1.5113 Val: 1.4189 Test: 1.4966
Epoch: 009, Loss: 1.3412 Val: 1.2268 Test: 1.3886
Epoch: 010, Loss: 1.2381 Val: 1.1057 Test: 1.2172
Epoch: 011, Loss: 1.1652 Val: 1.0242 Test: 1.0864
Epoch: 012, Loss: 1.1103 Val: 1.0130 Test: 1.0764
Epoch: 013, Loss: 1.0821 Val: 0.9757 Test: 0.9694
Epoch: 014, Loss: 1.0460 Val: 1.0927 Test: 1.0020
Epoch: 015, Loss: 1.0466 Val: 1.0505 Test: 0.9732
Epoch: 016, Loss: 1.0152 Val: 0.9447 Test: 0.9786
Epoch: 017, Loss: 0.9574 Val: 0.8512 Test: 0.9221
Epoch: 018, Loss: 0.9336 Val: 0.7625 Test: 0.8277
Epoch: 019, Loss: 0.9327 Val: 0.8616 Test: 0.9142

Now that training is finished, let's try to use the model for prediction. To do that, I defined a process_mol function which converts SMILES strings into graph data.

def process_mol(smiles_list, pre_transform, pre_filter=None):
    from rdkit import Chem
    data_list = []
    for smi in smiles_list:
        smi = re.sub(r'\".*\"', '', smi)  # Replace ".*" strings.
        smiles = smi.split(',')[0]
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        xs = []
        for atom in mol.GetAtoms():
            x = []
            x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
            x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
            x.append(x_map['degree'].index(atom.GetTotalDegree()))
            x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
            x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
            x.append(x_map['num_radical_electrons'].index(
                atom.GetNumRadicalElectrons()))
            x.append(x_map['hybridization'].index(
                str(atom.GetHybridization())))
            x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
            x.append(x_map['is_in_ring'].index(atom.IsInRing()))
            xs.append(x)
        x = torch.tensor(xs, dtype=torch.long).view(-1, 9)
        edge_indices, edge_attrs = [], []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            e = []
            e.append(e_map['bond_type'].index(str(bond.GetBondType())))
            e.append(e_map['stereo'].index(str(bond.GetStereo())))
            e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))
            edge_indices += [[i, j], [j, i]]
            edge_attrs += [e, e]
        edge_index = torch.tensor(edge_indices)
        edge_index = edge_index.t().to(torch.long).view(2, -1)
        edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)
        # Sort indices.
        if edge_index.numel() > 0:
            perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
            edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=None,
                    smiles=smiles)
        if pre_filter is not None and not pre_filter(data):
            continue
        if pre_transform is not None:
            data = pre_transform(data)
        data_list.append(data)
    return data_list

Let's make a small dataset and predict the properties of the molecules.

dataset = process_mol(['CCCC', 'OCCO', 'c1ccccc1'], pre_transform=GenFeatures())

model.eval()
for data in dataset:
    data = data.to(device)
    # the batch vector needs one entry per atom; all zeros means a single molecule (graph 0)
    batch = torch.zeros(data.num_nodes, dtype=torch.long, device=device)
    print(model(data.x, data.edge_index, data.edge_attr, batch))
>

tensor([[-1.6434]], grad_fn=<AddmmBackward>)
tensor([[0.0233]], grad_fn=<AddmmBackward>)
tensor([[-1.5886]], grad_fn=<AddmmBackward>)

The model predicts that ethylene glycol is soluble and benzene isn't. It seems reasonable.
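
For more than a handful of molecules it is easier to wrap the list returned by process_mol in a DataLoader and let it build the batch vector automatically. A minimal sketch using the trained model and device from above (pred_loader is just an illustrative name):

model.eval()
pred_loader = DataLoader(dataset, batch_size=64)
with torch.no_grad():
    for data in pred_loader:
        data = data.to(device)
        preds = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # data.smiles is collated into a list of SMILES strings by the loader
        for smi, pred in zip(data.smiles, preds.view(-1).tolist()):
            print(smi, round(pred, 3))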

In summary, recent versions of PyG seem more chemoinformatics friendly, but there is no native function to read molecules and convert them to graph objects.

I would like to write helper functions for PyG for chemoinformatics.
