In chemoinformatics, QSAR modeling that takes the molecular graph as input is a very hot topic. I think the major implementations are DeepChem and Chainer Chemistry. I am also interested in building graph-based QSAR models. Recently I have been using PyTorch for my deep learning tasks, so I would like to build the model with PyTorch. Fortunately, a very elegant package for PyTorch is available: ‘pytorch_geometric’. PyTorch Geometric is a geometric deep learning extension library for PyTorch.
Today I tried to build a GCN model with the package.
First, I defined a mol-to-graph function that converts a molecule into a graph object. Most of the code is borrowed from DeepChem.
The data structure used by torch_geometric is described at this URL. I defined the molecular graph as an undirected graph, so each bond is stored as two directed edges, one in each direction.
from __future__ import division from __future__ import unicode_literals import numpy as np from rdkit import Chem import multiprocessing import logging import torch from torch_geometric.data import Data # following code was borrowed from deepchem # https://raw.githubusercontent.com/deepchem/deepchem/master/deepchem/feat/graph_features.py def one_of_k_encoding(x, allowable_set): if x not in allowable_set: raise Exception("input {0} not in allowable set{1}:".format( x, allowable_set)) return list(map(lambda s: x == s, allowable_set)) def one_of_k_encoding_unk(x, allowable_set): """Maps inputs not in the allowable set to the last element.""" if x not in allowable_set: x = allowable_set[-1] return list(map(lambda s: x == s, allowable_set)) def get_intervals(l): """For list of lists, gets the cumulative products of the lengths""" intervals = len(l) * [0] # Initalize with 1 intervals[0] = 1 for k in range(1, len(l)): intervals[k] = (len(l[k]) + 1) * intervals[k - 1] return intervals def safe_index(l, e): """Gets the index of e in l, providing an index of len(l) if not found""" try: return l.index(e) except: return len(l) possible_atom_list = [ 'C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br', 'Fe', 'Ca', 'Cu', 'Mc', 'Pd', 'Pb', 'K', 'I', 'Al', 'Ni', 'Mn' ] possible_numH_list = [0, 1, 2, 3, 4] possible_valence_list = [0, 1, 2, 3, 4, 5, 6] possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3] possible_hybridization_list = [ Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2 ] possible_number_radical_e_list = [0, 1, 2] possible_chirality_list = ['R', 'S'] reference_lists = [ possible_atom_list, possible_numH_list, possible_valence_list, possible_formal_charge_list, possible_number_radical_e_list, possible_hybridization_list, possible_chirality_list ] intervals = get_intervals(reference_lists) def get_feature_list(atom): features = 6 * [0] 
features[0] = safe_index(possible_atom_list, atom.GetSymbol()) features[1] = safe_index(possible_numH_list, atom.GetTotalNumHs()) features[2] = safe_index(possible_valence_list, atom.GetImplicitValence()) features[3] = safe_index(possible_formal_charge_list, atom.GetFormalCharge()) features[4] = safe_index(possible_number_radical_e_list, atom.GetNumRadicalElectrons()) features[5] = safe_index(possible_hybridization_list, atom.GetHybridization()) return features def features_to_id(features, intervals): """Convert list of features into index using spacings provided in intervals""" id = 0 for k in range(len(intervals)): id += features[k] * intervals[k] # Allow 0 index to correspond to null molecule 1 id = id + 1 return id def id_to_features(id, intervals): features = 6 * [0] # Correct for null id -= 1 for k in range(0, 6 - 1): # print(6-k-1, id) features[6 - k - 1] = id // intervals[6 - k - 1] id -= features[6 - k - 1] * intervals[6 - k - 1] # Correct for last one features[0] = id return features def atom_to_id(atom): """Return a unique id corresponding to the atom type""" features = get_feature_list(atom) return features_to_id(features, intervals) def atom_features(atom, bool_id_feat=False, explicit_H=False, use_chirality=False): if bool_id_feat: return np.array([atom_to_id(atom)]) else: from rdkit import Chem results = one_of_k_encoding_unk( atom.GetSymbol(), [ 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', # H? 
'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown' ]) + one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + \ one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \ [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \ one_of_k_encoding_unk(atom.GetHybridization(), [ Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType. SP3D, Chem.rdchem.HybridizationType.SP3D2 ]) + [atom.GetIsAromatic()] # In case of explicit hydrogen(QM8, QM9), avoid calling `GetTotalNumHs` if not explicit_H: results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) if use_chirality: try: results = results + one_of_k_encoding_unk( atom.GetProp('_CIPCode'), ['R', 'S']) + [atom.HasProp('_ChiralityPossible')] except: results = results + [False, False ] + [atom.HasProp('_ChiralityPossible')] return np.array(results) def bond_features(bond, use_chirality=False): from rdkit import Chem bt = bond.GetBondType() bond_feats = [ bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.GetIsConjugated(), bond.IsInRing() ] if use_chirality: bond_feats = bond_feats + one_of_k_encoding_unk( str(bond.GetStereo()), ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]) return np.array(bond_feats) ################# # pen added ################# def get_bond_pair(mol): bonds = mol.GetBonds() res = [[],[]] for bond in bonds: res[0] += [bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] res[1] += [bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()] return res def mol2vec(mol): atoms = mol.GetAtoms() bonds = mol.GetBonds() node_f= [atom_features(atom) for atom in atoms] edge_index = get_bond_pair(mol) edge_attr = [bond_features(bond, use_chirality=False) for bond in bonds] for bond in bonds: edge_attr.append(bond_features(bond)) data = 
Data(x=torch.tensor(node_f, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long), edge_attr=torch.tensor(edge_attr,dtype=torch.float) ) return data
With mol2graph defined, I then built the GCN class. The following code was written in an IPython notebook, so it calls some methods for interactive visualization. First, load the packages and the data. I used the solubility dataset that is distributed with RDKit.
%matplotlib inline import matplotlib.pyplot as plt from rdkit import Chem from rdkit.Chem import AllChem import numpy as np import torch import torch.nn.functional as F from torch.nn import Linear from torch.nn import BatchNorm1d from torch.utils.data import Dataset from torch_geometric.nn import GCNConv from torch_geometric.nn import ChebConv from torch_geometric.nn import global_add_pool, global_mean_pool from torch_geometric.data import DataLoader from torch_scatter import scatter_mean import mol2graph from rdkit.Chem.Draw import IPythonConsole from rdkit.Chem import Draw plt.style.use("ggplot") train_mols = [m for m in Chem.SDMolSupplier('solubility.train.sdf')] test_mols = [m for m in Chem.SDMolSupplier('solubility.test.sdf')] sol_cls_dict = {'(A) low':0, '(B) medium':1, '(C) high':2}
Next, I converted the molecules to graphs with the function defined above and attached the labels for training. Then I defined the data loaders.
def _to_labeled_graphs(mols):
    """Featurize each molecule with mol2vec and attach its solubility class as ``y``."""
    graphs = []
    for mol in mols:
        data = mol2graph.mol2vec(mol)
        label = sol_cls_dict[mol.GetProp('SOL_classification')]
        data.y = torch.tensor([label], dtype=torch.long)
        graphs.append(data)
    return graphs


train_X = _to_labeled_graphs(train_mols)
test_X = _to_labeled_graphs(test_mols)

# drop_last=True on the training loader prevents a final batch of a single
# graph, which BatchNorm1d cannot normalize in training mode.
train_loader = DataLoader(train_X, batch_size=64, shuffle=True, drop_last=True)
# The evaluation loader must visit every test molecule exactly once. Dropping
# the last partial batch would silently exclude molecules from test accuracy.
test_loader = DataLoader(test_X, batch_size=64, shuffle=False, drop_last=False)
Then I defined the model architecture for the GCN. The implementation is very different from the original article.
# Length of mol2graph.atom_features(atom) with its default flags.
n_features = 75


# define net
class Net(torch.nn.Module):
    """GCN graph classifier: two GCNConv layers, sum pooling, then an MLP head.

    Args:
        in_channels: number of node features per atom (defaults to ``n_features``).
        n_classes: number of output classes (defaults to 3 solubility classes).

    ``forward`` returns log-probabilities of shape (num_graphs, n_classes),
    suitable for ``F.nll_loss``.
    """

    def __init__(self, in_channels=n_features, n_classes=3):
        super(Net, self).__init__()
        # cached=False because every batch holds different molecules. With
        # cached=True GCNConv would reuse the normalized adjacency computed
        # on its first call, which is only valid for a single fixed graph.
        self.conv1 = GCNConv(in_channels, 128, cached=False)
        self.bn1 = BatchNorm1d(128)
        self.conv2 = GCNConv(128, 64, cached=False)
        self.bn2 = BatchNorm1d(64)
        self.fc1 = Linear(64, 64)
        self.bn3 = BatchNorm1d(64)
        self.fc2 = Linear(64, 64)
        self.fc3 = Linear(64, n_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # Node-level message passing.
        x = F.relu(self.conv1(x, edge_index))
        x = self.bn1(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.bn2(x)
        # Sum node embeddings into one vector per graph in the batch.
        x = global_add_pool(x, data.batch)
        # Graph-level classifier head.
        x = F.relu(self.fc1(x))
        x = self.bn3(x)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
After defining the model, I trained it and evaluated its performance.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
n_epochs = 100


def train(epoch):
    """Run one epoch over ``train_loader`` and return the mean per-graph NLL loss.

    ``epoch`` is unused; it is kept so existing ``train(epoch)`` calls still work.
    """
    model.train()
    loss_all = 0
    n_graphs = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        optimizer.step()
        loss_all += loss.item() * data.num_graphs
        n_graphs += data.num_graphs
    # Average over graphs actually seen. train_loader uses drop_last=True, so
    # dividing by len(train_X) would understate the mean loss.
    return loss_all / max(n_graphs, 1)


def test(loader):
    """Return the classification accuracy of ``model`` over the graphs yielded by ``loader``."""
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            pred = model(data).argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
            total += data.num_graphs
    # Normalize by graphs actually evaluated, not len(loader.dataset). With
    # drop_last=True the skipped graphs would otherwise count as wrong.
    return correct / max(total, 1)


hist = {"loss": [], "acc": [], "test_acc": []}
for epoch in range(1, n_epochs + 1):
    train_loss = train(epoch)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    hist["loss"].append(train_loss)
    hist["acc"].append(train_acc)
    hist["test_acc"].append(test_acc)
    print(f'Epoch: {epoch}, Train loss: {train_loss:.3}, Train_acc: {train_acc:.3}, Test_acc: {test_acc:.3}')

epochs = range(1, n_epochs + 1)
ax = plt.subplot(1, 1, 1)
ax.plot(epochs, hist["loss"], label="train_loss")
ax.plot(epochs, hist["acc"], label="train_acc")
ax.plot(epochs, hist["test_acc"], label="test_acc")
plt.xlabel("epoch")
ax.legend()
After training, I got the following plot.

It doesn't look too bad. PyTorch Geometric has many algorithms for graph-based deep learning.
I think it is a very cool and useful package for chemoinformaticians. ;)
All the code in this post is available at the URL below.
https://nbviewer.jupyter.org/github/iwatobipen/playground/blob/master/gcn.ipynb
Is test_acc mislabelled as test_loss?
Thank you for your comment.
Nice catch! I’ll fix the bug.
Hi, do you have an example with edge_attr?
BR
Guillaume
Hi,
I’m sorry, I don’t have any public examples right now.