Predict probability distributions with NGBoost #NGBoost #RDKit #QSAR #Chemoinformatics

Recently, a novel gradient boosting method called NGBoost was published by Andrew Ng's group. Interestingly, NGBoost can output not only point predictions but also full probability distributions. This is useful for QSAR, because we want to know not only the predicted value/class but also the uncertainty of the prediction.

Fortunately, NGBoost is available from Python! It can be installed with conda or pip (e.g. pip install ngboost).

You can find the source code at the following URL.
https://github.com/stanfordmlgroup/ngboost/tree/9027df3e0594c3cd4bb87da4678e053846d9cf40

As usual, I tried to use NGBoost for chemoinformatics tasks.

I tested NGBoost with the solubility dataset that is provided in the RDKit installation directory.

OK, let's write some code. At first I tried classification.

%matplotlib inline
import matplotlib.pyplot as plt
import os
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
import numpy as np

from sklearn.ensemble import AdaBoostClassifier
from sklearn.svm import SVC
from ngboost import NGBClassifier
from ngboost.ngboost import NGBoost
from ngboost.learners import default_tree_learner
from ngboost.distns import Normal, LogNormal
from ngboost.distns import k_categorical
from sklearn.metrics import classification_report
plt.style.use('ggplot')
from rdkit.Chem import RDConfig
# Load the solubility training and test sets shipped with the RDKit documentation
train = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.train.sdf')
train_mols = [m for m in Chem.SDMolSupplier(train) if m is not None]
test = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.test.sdf')
test_mols = [m for m in Chem.SDMolSupplier(test) if m is not None]
# Map the solubility class labels to integer indices (sorted for reproducible ordering)
cls = sorted(set([m.GetProp('SOL_classification') for m in train_mols]))
clsdict = {c: i for i, c in enumerate(cls)}

I converted each mol to a fingerprint and built the datasets for training and test.

def mol2arr(mol, radi=2, nBits=1024):
    # Convert an RDKit Mol to a Morgan fingerprint stored in a numpy array
    arr = np.zeros((0,))
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radi, nBits)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr
trainX = np.array([mol2arr(m) for m in train_mols])
trainY = np.array([clsdict[m.GetProp('SOL_classification')] for m in train_mols])
trainY2 = np.array([float(m.GetProp('SOL')) for m in train_mols])

testX = np.array([mol2arr(m) for m in test_mols])
testY = np.array([clsdict[m.GetProp('SOL_classification')] for m in test_mols])
testY2 = np.array([float(m.GetProp('SOL')) for m in test_mols])

Let's run the classification task!

# Train AdaBoost, SVC and NGBoost classifiers on the same fingerprints
adbcls = AdaBoostClassifier()
svc = SVC(probability=True)
ngbc = NGBClassifier(Dist=k_categorical(3))  # the solubility data has three classes
clsset = [adbcls, svc, ngbc]
for clsfier in clsset:
    clsfier.fit(trainX, trainY)
# Predict the test set with each classifier
testres = []
for clsfier in clsset:
    res = clsfier.predict(testX)
    testres.append(res)
N=20
plt.clf()
plt.rcParams["font.size"] = 18
from tqdm import tqdm
plt.figure(figsize=(25, 129))
for idx in tqdm(np.arange(testX[:N].shape[0])):
    plt.subplot(N, 3, idx+1)
    plt.vlines(testres[0][idx]+0.05, 0, 1, 'r', label='adaboost')
    plt.vlines(testres[1][idx]+0.1, 0, 1, 'b', label='svc')
    plt.vlines(testres[2][idx]+0.15, 0, 1, 'g', label='ngbclassifier')
    plt.vlines(testY[idx], 0, 1, "pink", label="ground truth")
    plt.legend(loc="best")
    plt.title(f'mol_{idx}, {testY[idx]}')
plt.tight_layout()
plt.show()

I plotted some of the results as vline plots; some compounds were assigned a different class by each model.

However, the overall performance of the three models does not differ much.
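To check this quantitatively, classification_report (already imported above) can be used to compare the models on the test set; a minimal sketch, where the model names are simply labels for printing:

# Compare the three classifiers on the held-out test set
for name, res in zip(['AdaBoost', 'SVC', 'NGBClassifier'], testres):
    print(name)
    print(classification_report(testY, res))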

Next, let's try regression. In the regression task, NGBoost can calculate a probability distribution for each prediction, which seems useful for chemoinformaticians.

# Fit an NGBoost regressor with a Normal output distribution
ngbr = NGBoost(Dist=Normal, Base=default_tree_learner)
ngbr.fit(trainX, trainY2)
predv = ngbr.predict(testX)
disty = ngbr.pred_dist(testX)  # predicted distribution for each test molecule
yrange = np.linspace(-8.5, 0.5).reshape(-1, 1)
distval = disty.pdf(yrange).transpose()  # evaluate each pdf on the common y grid
plt.scatter(testY2, predv, c='b', alpha=0.5)
plt.plot(np.linspace(-8,0), np.linspace(-8,0), c='r', alpha=0.5)
plt.xlabel('ground truth')
plt.ylabel('ngboost')
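To put a number on the fit, we can also compute R2 and RMSE; a minimal sketch using scikit-learn metrics (these two imports are my addition and are not part of the original snippet):

from sklearn.metrics import r2_score, mean_squared_error
# Quantify the regression performance on the test set
print('R2  :', r2_score(testY2, predv))
print('RMSE:', np.sqrt(mean_squared_error(testY2, predv)))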

Not such a good model… but let's move forward. I plotted the first 50 test molecules together with their predicted distributions.

# Plot the predicted distribution (blue), predicted mean (red) and true value (green)
# for the first 50 test molecules
S = 0
E = 50
plt.clf()
plt.rcParams["font.size"] = 18
plt.figure(figsize=(25, 120))
for idx, _ in tqdm(enumerate(testX[S:E])):
    plt.subplot(E-S, 5, idx+1)
    plt.plot(yrange, distval[idx], c='b')
    plt.vlines(predv[idx], 0, max(distval[idx]), 'r', label='NGB')
    plt.vlines(testY2[idx], 0, max(distval[idx]), 'g', label='true label')
    plt.legend(loc="best")
    plt.title(f'mol {idx}')
plt.tight_layout()
plt.show()

NGBoost used a normal distribution for learning, so the predicted value is the mean of each predicted distribution.

Plotting the results this way lets us check the model performance graphically and see which molecules are predicted with low accuracy.
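Since the predicted distribution is a Normal, we can also pull out its mean and standard deviation and rank the molecules by error and by predicted uncertainty. A minimal sketch, assuming the distribution object returned by pred_dist exposes its parameters through a params dictionary with 'loc' and 'scale' keys (check the NGBoost version you have installed):

# Mean and standard deviation of each predicted Normal distribution.
# NOTE: 'params', 'loc' and 'scale' are assumed names; verify against your NGBoost version.
means = disty.params['loc']
stds = disty.params['scale']
abs_err = np.abs(testY2 - means)
# List the five molecules with the largest absolute error, together with the predicted uncertainty
worst = np.argsort(abs_err)[::-1][:5]
for idx in worst:
    print(f'mol {idx}: true {testY2[idx]:.2f}, pred {means[idx]:.2f} +/- {stds[idx]:.2f}')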

Drawing molecules with RDKit is very easy, so let's check the molecules (a drawing sketch follows below).
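The original post shows the structures as images. As a sketch of how such a grid can be produced, the snippet below draws the five molecules with the largest errors found above (which molecules to draw is of course up to you):

from rdkit.Chem import Draw
# Draw the test molecules with the largest absolute errors (indices taken from the sketch above)
worst_mols = [test_mols[i] for i in worst]
legends = [f'true {testY2[i]:.2f} / pred {predv[i]:.2f}' for i in worst]
Draw.MolsToGridImage(worst_mols, molsPerRow=5, legends=legends)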

Two of the molecules showed large errors while the others were predicted well, yet both sets of compounds have ortho Cl substituents and aromatic features…

From this dataset alone it is difficult to explain the difference. It is a good starting point for digging deeper into the models and the molecules.

I also still need to optimize these models.
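As a starting point, a simple scan over a few learning rates and tree counts could look like this; a minimal sketch, assuming the NGBoost constructor accepts n_estimators, learning_rate and verbose arguments (check the API of the version pinned above):

from sklearn.metrics import mean_squared_error
# Very rough hyperparameter scan for the regressor (no cross-validation, just the held-out test set).
# n_estimators / learning_rate / verbose are assumed constructor arguments; verify against your NGBoost version.
for lr in [0.01, 0.05, 0.1]:
    for n_est in [200, 500]:
        model = NGBoost(Dist=Normal, Base=default_tree_learner,
                        n_estimators=n_est, learning_rate=lr, verbose=False)
        model.fit(trainX, trainY2)
        rmse = np.sqrt(mean_squared_error(testY2, model.predict(testX)))
        print(f'lr={lr}, n_estimators={n_est}, RMSE={rmse:.3f}')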

I uploaded today's code to Gist and GitHub. Thanks!


https://github.com/iwatobipen/playground/blob/master/ngboost.ipynb

