AMES classification with WL graph kernel #RDKit

I often find it difficult to implement an algorithm from scratch… I need more practice. ;-)
BTW, I have recently come across many articles about applying graph theory to chemoinformatics.
Some of them are interesting, and their authors provide useful packages on GitHub!

Today, I tried a library named GraKeL.
You can find the original article at the URL below.
https://arxiv.org/abs/1806.02193

I used the package and compared its performance with a traditional SVC. An open Ames dataset is used for the following test.
My code is below. The GraKeL package offers many algorithms and makes graph kernel calculation easy. I calculated the WL (Weisfeiler-Lehman) graph kernel from the adjacency matrix given by RDKit and built a predictive model. For comparison, I also built a traditional SVC model with ECFP4 (Morgan fingerprint, radius=2).

Comparing the results, I find it interesting that the WL graph kernel worked well even though it carries no detailed information about the molecules, such as charge or number of hydrogens.
Does this mean that graph-based models are powerful? This is only a single experiment with this descriptor.
I would like to try other datasets.
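
To make concrete what the WL subtree kernel sees, here is a minimal pure-Python sketch. It is not GraKeL's implementation: the function names are hypothetical, labels are kept as readable strings instead of being compressed to integers, and bond orders are ignored (any non-zero adjacency entry counts as an edge). The input uses the same [adjacency matrix, {index: atom symbol}] layout as the script below, written by hand here so RDKit is not needed.

```python
from collections import Counter


def wl_feature_counts(adj, labels, n_iter=2):
    """Count WL subtree labels of one graph (illustrative helper, not GraKeL API)."""
    n = len(adj)
    # Assumption: any non-zero entry is an edge, so bond orders are ignored.
    neighbors = [[j for j in range(n) if adj[i][j]] for i in range(n)]
    current = {i: str(labels[i]) for i in range(n)}
    counts = Counter(current.values())
    for _ in range(n_iter):
        # New label = own label + sorted multiset of the neighbors' labels.
        current = {
            i: current[i] + "(" + ",".join(sorted(current[j] for j in neighbors[i])) + ")"
            for i in range(n)
        }
        counts.update(current.values())
    return counts


def wl_kernel(g1, g2, n_iter=2):
    """Dot product of the two label-count vectors."""
    c1 = wl_feature_counts(g1[0], g1[1], n_iter)
    c2 = wl_feature_counts(g2[0], g2[1], n_iter)
    return sum(c1[k] * c2[k] for k in c1.keys() & c2.keys())


def wl_kernel_normalized(g1, g2, n_iter=2):
    """Cosine-style normalization, so that k(g, g) == 1."""
    k12 = wl_kernel(g1, g2, n_iter)
    return k12 / (wl_kernel(g1, g1, n_iter) * wl_kernel(g2, g2, n_iter)) ** 0.5


# Hand-written heavy-atom graphs: ethanol (C-C-O) and dimethyl ether (C-O-C).
path3 = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
ethanol = [path3, {0: "C", 1: "C", 2: "O"}]
ether = [path3, {0: "C", 1: "O", 2: "C"}]

# After one iteration only the initial labels C and O are shared: 2*2 + 1*1.
print(wl_kernel(ethanol, ether, n_iter=1))  # 5
print(wl_kernel_normalized(ethanol, ether, n_iter=1))
```

The two molecules have the same atoms and the same path-shaped skeleton, yet every refined label differs after a single iteration. So even without charges or hydrogen counts, the kernel picks up each atom's local environment, somewhat like a circular fingerprint does.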

These models are based on ligand features only and do not include protein information. In the real world, the drug discovery process needs a lot of information, not only about ligands but also about proteins.

I would like to explore the possibilities of graph-based approaches for chemoinformatics.

from grakel import GraphKernel
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
import numpy as np
import argparse
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split


def getparser():
    parser = argparse.ArgumentParser('argparser')
    parser.add_argument('input', help='file path and name of input')
    parser.add_argument('prop', help='properties for predict')
    return parser.parse_args()

def molg_from_rdkit(mol):
    # Graph representation passed to GraKeL: [adjacency matrix, {node index: node label}].
    # Node labels are atom symbols; with useBO=True the matrix entries are bond orders.
    atom_with_idx = {i: atom.GetSymbol() for i, atom in enumerate(mol.GetAtoms())}
    adj_m = Chem.GetAdjacencyMatrix(mol, useBO=True).tolist()
    return [adj_m, atom_with_idx]

def molg_from_smi(smiles):
    # Same graph representation, starting from a SMILES string.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError('could not parse SMILES: {}'.format(smiles))
    return molg_from_rdkit(mol)

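def mol2fp(mol):
    # ECFP4-like fingerprint: Morgan radius 2, default bit-vector length.
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2)
    arr = np.zeros((1,))
    # ConvertToNumpyArray resizes arr to the fingerprint length and fills it in place.
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr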


if __name__=='__main__':
    args = getparser()
    mols = [mol for mol in Chem.SDMolSupplier(args.input) if mol is not None]
    X = [molg_from_rdkit(mol) for mol in mols]
    Ames_dict = {'mutagen':1, 'nonmutagen':0}
    Y = [ Ames_dict[mol.GetProp('Ames test categorisation')] for mol in mols]
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y)

    gk = GraphKernel(kernel=[{"name": "weisfeiler_lehman", "niter": 5},{"name":"subtree_wl"}], normalize=True)
    K_train = gk.fit_transform(X_train)
    K_test = gk.transform(X_test)

    gclf = SVC(kernel='precomputed')
    gclf.fit(K_train, Y_train)
    y_pred_g = gclf.predict(K_test)

    from sklearn.metrics import classification_report
    from sklearn.metrics import confusion_matrix
    rep = classification_report(Y_test, y_pred_g)
    print('WL graph kernel')
    print(confusion_matrix(Y_test, y_pred_g))
    print(rep)

    print('\n')
    print('ECFP4')

    X = [mol2fp(mol) for mol in mols]
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y)
    clf = SVC(C=20.)
    clf.fit(X_train, Y_train)
    y_pred = clf.predict(X_test)
    rep = classification_report(Y_test, y_pred)
    print(confusion_matrix(Y_test, y_pred))
    print(rep)
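
One caveat about the script above: train_test_split is called twice without a fixed random_state, so the two models are trained and tested on different random subsets of the molecules. A minimal sketch of one way around this is to split the indices once and reuse them for both representations. The helpers below are hypothetical and use only the standard library; the 0.25 test fraction mirrors scikit-learn's default.

```python
import random


def split_indices(n_samples, test_fraction=0.25, seed=0):
    """Shuffle sample indices once and return (train_indices, test_indices)."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_test = int(round(n_samples * test_fraction))
    return indices[n_test:], indices[:n_test]


def take(items, indices):
    """Pick the items at the given positions, keeping the order of indices."""
    return [items[i] for i in indices]


# Toy stand-ins for the graph inputs, fingerprint inputs and labels.
X_graph = ["g0", "g1", "g2", "g3", "g4", "g5", "g6", "g7"]
X_fp = ["f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7"]
Y = [0, 1, 0, 1, 1, 0, 1, 0]

train_idx, test_idx = split_indices(len(Y))

# Both representations now refer to exactly the same molecules.
Xg_train, Xg_test = take(X_graph, train_idx), take(X_graph, test_idx)
Xf_train, Xf_test = take(X_fp, train_idx), take(X_fp, test_idx)
Y_train, Y_test = take(Y, train_idx), take(Y, test_idx)
print(Xg_test, Xf_test, Y_test)
```

With scikit-learn itself, the same effect should be achievable by passing both feature lists to a single call, e.g. train_test_split(X_graph, X_fp, Y, random_state=0), or by giving both calls the same random_state.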

WL graph kernel
[[381 125]
 [ 85 494]]
             precision    recall  f1-score   support

          0       0.82      0.75      0.78       506
          1       0.80      0.85      0.82       579

avg / total       0.81      0.81      0.81      1085

ECFP4
[[381 110]
 [111 483]]
             precision    recall  f1-score   support

          0       0.77      0.78      0.78       491
          1       0.81      0.81      0.81       594

avg / total       0.80      0.80      0.80      1085

real 0m40.446s
user 1m49.922s
sys 0m3.074s
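
As a rough check on how different the two results really are, the confusion matrices above can be turned into overall accuracies with a simple binomial standard error. This is only a back-of-the-envelope sketch (the helper names are hypothetical, and the two test sets are not identical), but it gives a feeling for the noise level with 1085 test molecules.

```python
def accuracy(cm):
    """Fraction of correct predictions: diagonal sum over total count."""
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total


def binomial_se(p, n):
    """Standard error of a proportion p estimated from n samples."""
    return (p * (1.0 - p) / n) ** 0.5


# Confusion matrices copied from the output above.
wl_cm = [[381, 125], [85, 494]]
ecfp_cm = [[381, 110], [111, 483]]

for name, cm in [("WL graph kernel", wl_cm), ("ECFP4", ecfp_cm)]:
    n = sum(sum(row) for row in cm)
    acc = accuracy(cm)
    print("{}: accuracy {:.3f} +/- {:.3f}".format(name, acc, binomial_se(acc, n)))
```

The accuracies come out at about 0.806 and 0.796, with a standard error of roughly 0.012 each, so on this single run the two models look about equally good.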




		