Visualize feature importance with marimo #cheminformatics #RDKit #marimo

I recently posted about marimo, a new generation of notebook. It is cool and makes it easy to build an interactive analysis environment with Python.

I’m interested in the package and am thinking about how to use it in chemoinformatics tasks. In QSAR tasks, chemoinformaticians are often asked to explain the reason behind a model’s prediction, so XAI (explainable AI) is an attractive area in the field. rkakamilan shared really useful posts, with code, about visualizing the feature importance of ML models on his blog.

Now I’m writing code to visualize ML weights on compound structures, and most of the code came from his post. There are lots of ways to calculate feature importance; the current code supports only one of them, but I would like to support SHAP values at least in the near future.

Here is my code, named chem_viz (still under development…).

import base64
import functools
import rdkit
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import cm
import matplotlib_inline
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem import Draw
from rdkit.Chem import DataStructs
from rdkit.Chem.Draw import SimilarityMaps
from rdkit.Chem.Draw import rdDepictor, rdMolDraw2D

def red_blue_cmap(x):
    """Map a normalized saliency score onto a blue-white-red ramp.

    Args:
        x (float): normalized saliency score, expected between -1 and 1.
    Returns (tuple): three floats (R, G, B). Positive scores fade from
        white (x=0) to red (x=1); zero and negative scores fade from
        white (x=0) to blue (x=-1).
    """
    if x <= 0:
        # Non-positive score: keep the blue channel full and dim red/green
        # by the magnitude of the score.
        fade = 1.0 - (-x)
        return fade, fade, 1.0
    # Positive score: keep the red channel full and dim green/blue.
    fade = 1.0 - x
    return 1.0, fade, fade

def is_visible(begin, end):
    """Combine the scores of a bond's two end atoms into one score.

    Args:
        begin: score of the bond's begin atom.
        end: score of the bond's end atom.
    Returns: 0 if either score is <= 0; otherwise 1 if either score is >= 1;
        otherwise the mean of the two scores.
    """
    endpoints = (begin, end)
    # A non-positive endpoint takes precedence over a saturated one.
    if any(score <= 0 for score in endpoints):
        return 0
    if any(score >= 1 for score in endpoints):
        return 1
    return 0.5 * (begin + end)
def color_bond(bond, saliency, color_fn):
    """Return the colour for a bond derived from its two atoms' scores.

    Args:
        bond: object exposing GetBeginAtomIdx() and GetEndAtomIdx().
        saliency: per-atom scores, indexable by atom index.
        color_fn: callable turning the combined score into a colour.
    Returns: whatever ``color_fn`` returns for the combined score.
    """
    endpoint_scores = (
        saliency[bond.GetBeginAtomIdx()],
        saliency[bond.GetEndAtomIdx()],
    )
    return color_fn(is_visible(*endpoint_scores))

def label_cat(label):
    """Return the heat-map cell marker for a contribution value.

    Args:
        label: numeric contribution of one cell.
    Returns (str): the mathtext string '$+$' for any non-zero value, and
        a centred dot (``$\\cdot$``) for zero.
    """
    # Raw string: "\c" is not a valid escape sequence, so the non-raw form
    # triggers an invalid-escape warning on current Python versions.
    return '$+$' if label != 0 else r'$\cdot$'

class Mol2Img():
    """Draw a molecule with highlighted atoms and bonds as PNG or SVG.

    Args:
        mol: RDKit molecule to draw.
        atom_colors: colours for atoms 0..len(atom_colors)-1, passed to the
            drawer as ``highlightAtomColors``.
        bond_colors: colours for bonds 0..len(bond_colors)-1, passed to the
            drawer as ``highlightBondColors``.
        molSize (tuple): (width, height) of the drawing canvas.
        kekulize (bool): try to kekulize the drawn copy of the molecule.
    """

    def __init__(self, mol, atom_colors, bond_colors, molSize=(450, 150), kekulize=True):
        self.mol = mol
        self.atom_colors = atom_colors
        self.bond_colors = bond_colors
        self.molSize = molSize
        self.kekulize = kekulize
        # Kekulization is applied to a copy, not to the molecule passed in.
        self.mc = Chem.Mol(self.mol.ToBinary())
        if self.kekulize:
            try:
                Chem.Kekulize(self.mc)
            except Exception:
                # Best effort: if kekulization fails, fall back to a fresh
                # copy. ``Exception`` (not a bare ``except``) so that
                # KeyboardInterrupt / SystemExit still propagate.
                self.mc = Chem.Mol(self.mol.ToBinary())

    def mol2png(self):
        """Return the highlighted drawing from a Cairo (PNG) drawer."""
        drawer = rdMolDraw2D.MolDraw2DCairo(self.molSize[0], self.molSize[1])
        self._getDrawingText(drawer)
        return drawer.GetDrawingText()

    def mol2svg(self):
        """Return the highlighted drawing from an SVG drawer."""
        drawer = rdMolDraw2D.MolDraw2DSVG(self.molSize[0], self.molSize[1])
        self._getDrawingText(drawer)
        return drawer.GetDrawingText()

    def _getDrawingText(self, drawer):
        """Configure ``drawer``, draw the molecule with highlights, and finish.

        The caller retrieves the result with ``drawer.GetDrawingText()``.
        """
        dops = drawer.drawOptions()
        dops.useBWAtomPalette()
        dops.padding = .2
        dops.addAtomIndices = True
        # Every atom/bond index covered by the colour containers is highlighted.
        atom_ids = list(range(len(self.atom_colors)))
        bond_ids = list(range(len(self.bond_colors)))
        drawer.DrawMolecule(
            self.mc,
            highlightAtoms=atom_ids,
            highlightAtomColors=self.atom_colors,
            highlightBonds=bond_ids,
            highlightBondColors=self.bond_colors,
            highlightAtomRadii={i: .5 for i in atom_ids},
        )
        drawer.FinishDrawing()
            
class XMol():
    """Build a highlighted depiction and a contribution figure for a molecule.

    Args:
        mol: RDKit molecule to explain.
        weight_fn: callable taking ``mol`` and returning one contribution per
            atom; only called when ``weights`` is None.
        weights: optional precomputed per-atom contributions.
        atoms: element symbols used as heat-map columns; defaults to
            ``DEFAULT_ATOMS``.
        drawingfmt (str): 'svg' or 'png'.
    """

    # Immutable default so instances never share one mutable list.
    DEFAULT_ATOMS = ('C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br')

    def __init__(self, mol, weight_fn, weights=None, atoms=None, drawingfmt='svg'):
        self.mol = mol
        self.weight_fn = weight_fn
        self.weights = weights
        self.atoms = list(self.DEFAULT_ATOMS) if atoms is None else list(atoms)
        self.drawingfmt = drawingfmt

    def make_explainable_image(self, kekulize=True, molSize=(450, 150)):
        """Compute atom colours, render the molecule and build the figure.

        Sets ``weights``, ``vmax``, ``vmin``, ``drawingtxt``, ``fig``,
        ``grid``, ``ax``, ``im`` and ``symbols`` on the instance.

        Args:
            kekulize (bool): forwarded to ``Mol2Img``.
            molSize (tuple): (width, height) forwarded to ``Mol2Img``.
        Raises:
            ValueError: if ``drawingfmt`` is neither 'svg' nor 'png'.
        """
        num_atoms = self.mol.GetNumAtoms()
        rows = range(num_atoms)
        elements = [self.mol.GetAtomWithIdx(i).GetSymbol() for i in rows]
        row_labels = [f'{sym}_{i}' for i, sym in enumerate(elements)]

        if self.weights is None:
            contribs = self.weight_fn(self.mol)
        else:
            contribs = self.weights

        # Elements missing from ``self.atoms`` get their own column instead of
        # raising ValueError from ``list.index``; ``self.atoms`` is untouched.
        columns = list(self.atoms)
        for sym in elements:
            if sym not in columns:
                columns.append(sym)

        # One row per atom, one column per element; each row holds the atom's
        # contribution in the column of its own element.
        arr = np.zeros((num_atoms, len(columns)))
        for i, sym in enumerate(elements):
            arr[i, columns.index(sym)] = contribs[i]
        df = pd.DataFrame(arr, index=row_labels, columns=columns)

        # NOTE(review): ``self.weights`` is replaced by the standardized
        # weights here, so a second call feeds those values back in as the
        # contributions - confirm whether callers rely on this.
        self.weights, self.vmax = SimilarityMaps.GetStandardizedWeights(contribs)
        self.vmin = -self.vmax
        atom_colors = {i: red_blue_cmap(e) for i, e in enumerate(self.weights)}
        bond_colors = {i: color_bond(bond, self.weights, red_blue_cmap) for i, bond in enumerate(self.mol.GetBonds())}
        viz = Mol2Img(self.mol, atom_colors, bond_colors, molSize=molSize, kekulize=kekulize)
        if self.drawingfmt == 'svg':
            matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
            self.drawingtxt = viz.mol2svg()
        elif self.drawingfmt == 'png':
            # Store under the same attribute as the svg branch; the misspelled
            # ``drawingtext`` is kept as an alias for existing callers.
            self.drawingtxt = viz.mol2png()
            self.drawingtext = self.drawingtxt
            matplotlib_inline.backend_inline.set_matplotlib_formats('png')
        else:
            raise ValueError("Please select drawingfmt form 'svg' or 'png'")

        self.fig = plt.figure(figsize=(18, 9))
        self.grid = plt.GridSpec(15, 10)

        # Right-hand column: per-atom bars, positive and negative parts
        # drawn separately in different colours.
        self.ax = self.fig.add_subplot(self.grid[1:, -1])
        self.ax.barh(rows, np.maximum(0, df.values).sum(axis=1), color='C3')
        self.ax.barh(rows, np.minimum(0, df.values).sum(axis=1), color='C0')
        self.ax.set_yticks(rows)
        self.ax.set_ylim(-.5, num_atoms-0.5)
        self.ax.axvline(0, color='k', linestyle='-', linewidth=.5)
        self.ax.spines['top'].set_visible(False)
        self.ax.spines['right'].set_visible(False)
        self.ax.spines['left'].set_visible(False)
        self.ax.tick_params(axis='both', which='both', left=False, labelleft=False)

        # Main panel: atom-by-element heat map sharing the y axis with the bars.
        self.ax = self.fig.add_subplot(self.grid[1:, :-1], sharey=self.ax)
        self.im = self.ax.imshow(df.values, cmap='bwr', vmin=self.vmin, vmax=self.vmax, aspect='auto')
        self.ax.set_yticks(rows)
        self.ax.set_ylim(num_atoms -.5, -.5)
        self.symbols = {i: f'${sym}_{{{i}}}$' for i, sym in enumerate(elements)}
        self.ax.set_yticklabels(self.symbols.values())
        self.ax.set_xticks(range(len(columns)))
        self.ax.set_xlim(-.5, len(columns) -.5)
        self.ax.set_xticklabels(columns, rotation=90)
        self.ax.set_ylabel('Node')

        # Mark every cell: '+' for non-zero contributions, a dot for zero.
        for (j, i), label in np.ndenumerate(df.values):
            self.ax.text(i, j, label_cat(label), ha='center', va='center')
        self.ax.tick_params(axis='both', which='both', bottom=True, labelbottom=True, top=False, labeltop=False)

        # Top strip: horizontal colour bar for the heat map.
        self.ax = self.fig.add_subplot(self.grid[0, :-1])
        self.fig.colorbar(mappable=self.im, cax=self.ax, orientation='horizontal')
        self.ax.tick_params(axis='both', which='both', bottom=False, labelbottom=False, top=True, labeltop=True)
        

I added SVG support to the original code.

Then let’s use the code from marimo! First, launch a marimo notebook with the marimo edit command.

import marimo as mo
import chemviz
# NOTE: shap is imported below, but this code does not use SHAP values yet.
import functools
import os
import warnings
# Silence all warnings before the remaining imports run.
warnings.filterwarnings(action='ignore')

from matplotlib import cm
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rdkit
from rdkit import Chem, RDPaths
from rdkit.Chem import AllChem,  DataStructs, Draw, rdBase, rdCoordGen, rdDepictor
from rdkit.Chem.Draw import IPythonConsole, rdMolDraw2D, SimilarityMaps
from rdkit.ML.Descriptors import MoleculeDescriptors
# Report library versions for reproducibility.
print(f'RDKit: {rdBase.rdkitVersion}')
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import shap
print(f'SHAP: {shap.__version__}')

def mol2fp(mol,radius=2, nBits=1024):
    """Compute a Morgan bit-vector fingerprint for ``mol`` as a NumPy array.

    Args:
        mol: RDKit molecule.
        radius (int): Morgan radius passed to the fingerprint call.
        nBits (int): fingerprint length passed to the fingerprint call.
    Returns (tuple): (array filled from the fingerprint, bitInfo dict that
        was handed to the fingerprint call).
    """
    bit_info = {}
    bitvect = AllChem.GetMorganFingerprintAsBitVect(
        mol, radius=radius, nBits=nBits, bitInfo=bit_info
    )
    # Filled in place by the conversion call below
    # (presumably resized to the fingerprint length - TODO confirm).
    features = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(bitvect, features)
    return features, bit_info

# Prefer the solubility data under RDKit's docs directory so the notebook is
# not tied to one machine; fall back to the author's local conda package path
# only when the portable location does not exist.
train_path = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.train.sdf')
test_path = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.test.sdf')
if not os.path.exists(train_path):
    train_path='/home/iwatobipen/miniforge3/pkgs/rdkit-2023.09.3-py311h4c2f14b_1/share/RDKit/Docs/Book/data/solubility.train.sdf'
if not os.path.exists(test_path):
    test_path='/home/iwatobipen/miniforge3/pkgs/rdkit-2023.09.3-py311h4c2f14b_1/share/RDKit/Docs/Book/data/solubility.test.sdf'
train_mols = Chem.SDMolSupplier(train_path)
test_mols = Chem.SDMolSupplier(test_path)
print(len(train_mols), len(test_mols))

# Integer target for each value of the SDF 'SOL_classification' property.
sol_classes = {'(A) low': 0, '(B) medium': 1, '(C) high': 2}

# Fingerprints become the feature matrix, class labels the target vector.
train_fps = [mol2fp(m)[0] for m in train_mols]
train_labels = [sol_classes[m.GetProp('SOL_classification')] for m in train_mols]
X_train = np.array(train_fps)
y_train = np.array(train_labels, dtype=np.int_)

test_fps = [mol2fp(m)[0] for m in test_mols]
test_labels = [sol_classes[m.GetProp('SOL_classification')] for m in test_mols]
X_test = np.array(test_fps)
y_test = np.array(test_labels, dtype=np.int_)
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

# Fixed random_state keeps the fitted forest reproducible.
clf = RandomForestClassifier(random_state=20191215)
clf.fit(X_train, y_train)

def get_proba(fp, proba_fn, class_id):
    """Return one class's score for a single fingerprint.

    Args:
        fp: a single fingerprint.
        proba_fn: callable taking a batch of fingerprints and returning one
            indexable row of scores per fingerprint.
        class_id: index of the wanted class within a row.
    Returns: ``proba_fn``'s score for ``class_id`` on ``fp``.
    """
    # proba_fn works on batches, so wrap the fingerprint in a 1-tuple.
    batch = (fp,)
    first_row = proba_fn(batch)[0]
    return first_row[class_id]

def fp_partial(nBits):
    """Return SimilarityMaps.GetMorganFingerprint with ``nBits`` pre-bound.

    Args:
        nBits (int): value bound to the ``nBits`` keyword.
    Returns: a ``functools.partial`` wrapping the fingerprint function.
    """
    morgan_fp = SimilarityMaps.GetMorganFingerprint
    return functools.partial(morgan_fp, nBits=nBits)

def show_pred_results(mol, model):
    """Print the recorded and the predicted solubility class of ``mol``.

    Args:
        mol: RDKit molecule carrying a 'SOL_classification' property.
        model: fitted estimator exposing ``predict``.
    """
    # predict() receives a single-row 2-D array.
    features = mol2fp(mol)[0].reshape((1,-1))
    y_pred = model.predict(features)
    # Invert the name -> id mapping to turn the predicted id back into a name.
    class_names = {idx: name for name, idx in sol_classes.items()}
    actual = mol.GetProp('SOL_classification')
    predicted = class_names[y_pred[0]]
    print(f"True: {actual} vs Predicted: {predicted}")

def makeimg(smi):
    """Return the highlighted drawing of ``smi`` for the high-solubility class.

    Atom weights come from SimilarityMaps.GetAtomicWeightsForModel using the
    fitted ``clf``'s predicted probability for class id 2.

    Args:
        smi (str): SMILES string, e.g. the current value of the input form.
    Returns: the drawing text produced by ``chemviz.XMol``, or an empty
        string when ``smi`` cannot be parsed or contains no atoms.
    """
    mol = Chem.MolFromSmiles(smi)
    # An unparsable SMILES gives None; an empty form value is expected to give
    # a molecule without atoms. Neither has atom weights to draw, so return
    # nothing instead of failing further down.
    if mol is None or mol.GetNumAtoms() == 0:
        return ''
    weights = SimilarityMaps.GetAtomicWeightsForModel(mol,
                                                      fp_partial(1024),
                                                      lambda x:get_proba(x, clf.predict_proba, 2))
    xmol=chemviz.XMol(mol, weight_fn=None, weights=weights)
    xmol.make_explainable_image()
    return xmol.drawingtxt


In this case I build the ML model in the notebook, but the model can also be imported from an external source, which would make the notebook more flexible and readable.

After preparing the required data (molecule, fingerprint, helper functions, ML model…), I make a form for input and a field to show the result, as shown below.


# Text box for the SMILES to explain; the bare name displays the widget.
form = mo.ui.text(label="input SMILES")
form

# Render the drawing for the current form value. The f-string opens with
# exactly three quotes: a fourth one would put a literal '"' in the HTML.
mo.Html(f"""{makeimg(form.value)}""")

Marimo provides an interactive view, so I can get the explainable image interactively.

As a result, I could make an interactive visualization app for a QSAR model with very few lines of code. It’s interesting for me.

BTW, XAI for QSAR often gives results that are difficult for chemists to understand; they feel or say that the images don’t fit the current SAR, so they can’t agree with the results.

I think there is room for improvement in this area. And I would like to hear opinions about the usability of XAI in QSAR from readers who have lots of experience in the field ;)

I uploaded today’s code to my GitHub repo.

https://github.com/iwatobipen/chem_viz

Published by iwatobipen

I'm a medicinal chemist at a mid-size pharmaceutical company. I love chemoinformatics, coding, organic synthesis, and my family.

One thought on “Visualize feature importance with marimo #cheminformatics #RDKit #marimo”

Leave a comment

This site uses Akismet to reduce spam. Learn how your comment data is processed.