mol encoder with Pytorch

A Variational Autoencoder (VAE) is a method for learning latent representations. A VAE encodes a discrete representation, such as a one-hot encoded molecule, into a continuous vector in latent space. There are lots of example implementations on GitHub.
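To make the idea of a continuous latent space concrete, here is a minimal standard-library sketch of the two pieces a VAE adds on top of a plain autoencoder: sampling a latent vector with the reparameterization trick, z = mu + exp(log_var / 2) * eps with eps drawn from N(0, 1), and the KL-divergence term that pulls the latent distribution toward a standard normal. The `mu` and `log_var` values are hypothetical stand-ins for encoder outputs, not numbers from a trained model.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), element-wise."""
    return [m + math.exp(lv / 2.0) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_divergence(mu, log_var):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


# Hypothetical encoder outputs for a 4-dimensional latent space.
mu = [0.5, -1.0, 0.0, 2.0]
log_var = [0.0, -2.0, 0.0, 1.0]

rng = random.Random(0)  # seeded so the sample is reproducible
z = reparameterize(mu, log_var, rng)
print(len(z), round(kl_divergence(mu, log_var), 4))
```

Because the randomness is isolated in `eps`, gradients can flow through `mu` and `log_var` during training; the KL term is zero only when every dimension has mu = 0 and log_var = 0.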
In 2016, Alán Aspuru-Guzik and co-workers reported a new de novo design method using a VAE. The approach represents molecules as SMILES strings, which are converted into one-hot vectors that are then used as input for deep learning.
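The SMILES-to-one-hot step can be sketched in plain Python. This sketch assumes each SMILES string is right-padded with spaces to a fixed width of 120 characters (the default input length `i=120` of `MolEncoder` quoted in the comments) and that the space character is part of the charset; `CHARSET` here is a hypothetical small subset for illustration, not the ChEMBL charset. Decoding takes the index of the largest value in each row, which is what `np.argmax` does in the script below.

```python
CHARSET = [" ", "(", ")", "1", "=", "C", "N", "O", "c", "n"]  # hypothetical subset
WIDTH = 120  # fixed sequence length, matching MolEncoder's default i=120


def one_hot(smiles, charset=CHARSET, width=WIDTH):
    """Return a width x len(charset) list of 0/1 rows for a right-padded SMILES."""
    padded = smiles.ljust(width)
    if len(padded) > width:
        raise ValueError("SMILES longer than %d characters" % width)
    index = {ch: i for i, ch in enumerate(charset)}
    rows = []
    for ch in padded:
        row = [0] * len(charset)
        row[index[ch]] = 1  # raises KeyError if a character is missing from the charset
        rows.append(row)
    return rows


def decode(rows, charset=CHARSET):
    """Invert one_hot by taking the arg-max of each row and stripping the padding."""
    chars = [charset[max(range(len(row)), key=row.__getitem__)] for row in rows]
    return "".join(chars).strip()


matrix = one_hot("c1ccccn1")
print(len(matrix), len(matrix[0]), decode(matrix))
```

Any character that is absent from the charset makes encoding fail, which is why the charset has to cover every symbol in the training set.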
Maxhodak then implemented the molecular autoencoder in Keras (keras-molecules)! It is nice work, and I think it has had a strong impact on the chemoinformatics area.
The repository is maxhodak/keras-molecules on GitHub.

The code depends on Keras 1.1.1, Python 2.7, and some other packages. The current Keras version is 2.x, and most code now supports Python 3.x, so I wanted to test the molecular VAE in a Python 3.x environment. I am also learning PyTorch, so I wanted to port the code from Keras to PyTorch.
I spent several days coding, only to find that someone had already done it!
A researcher in Pande's group had published a nice PyTorch implementation, molencoder.

molencoder works with PyTorch on Python 3.5 and 3.6 with CUDA 8.0. The code is almost the same as keras-molecules, but there are some differences.
The main difference is the activation function: the original code uses ReLU, whereas molencoder uses SELU.
I have not checked how this difference affects the results.
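For reference, the two activations differ as follows. ReLU is max(0, x); SELU (scaled exponential linear unit) is scale * x for x > 0 and scale * alpha * (exp(x) - 1) otherwise, with fixed constants alpha of about 1.6733 and scale of about 1.0507 (the values `torch.nn.SELU` uses). A minimal standard-library sketch:

```python
import math

# Fixed SELU constants (the same values used by torch.nn.SELU).
SELU_ALPHA = 1.6732632423543772848170429916717
SELU_SCALE = 1.0507009873554804934193349852946


def relu(x):
    return max(0.0, x)


def selu(x):
    # Positive inputs are scaled linearly; negative inputs saturate toward -scale * alpha.
    return SELU_SCALE * x if x > 0 else SELU_SCALE * SELU_ALPHA * (math.exp(x) - 1.0)


for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={x:+.1f}  relu={relu(x):.4f}  selu={selu(x):.4f}")
```

Unlike ReLU, SELU passes a nonzero signal for negative inputs (bounded below by about -1.7581), which is meant to keep activations near zero mean and unit variance; whether that improves this molecular VAE remains untested here.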
I wrote an interpolation script using molencoder.
The code is below.

Description of the arguments:
SOURCE: the starting point of the interpolation.
DEST: the end point of the interpolation.
STEPS: the number of interpolation steps, i.e. how many SMILES strings the model decodes from latent points between source and dest.
* The following code runs only on a GPU machine, because I hard-coded ".cuda()" calls instead of checking whether CUDA is available.
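If you want the script to run on CPU-only machines as well, the usual PyTorch pattern is to choose a `torch.device` once and move modules and tensors with `.to(device)` instead of calling `.cuda()`. A minimal sketch, assuming PyTorch 0.4 or later (where `torch.device`, `.to()` and `torch.no_grad()` exist); the small `nn.Linear` model is a placeholder standing in for `MolEncoder`/`MolDecoder`:

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model standing in for MolEncoder / MolDecoder.
model = nn.Linear(4, 2).to(device)
model.eval()

x = torch.zeros(3, 4).to(device)  # inputs must live on the same device as the model
with torch.no_grad():             # no gradients are needed for inference
    y = model(x)

print(device.type, tuple(y.shape))
```

When a checkpoint was saved on a GPU, `torch.load(path, map_location=device)` remaps its tensors to the chosen device so it can also be loaded on a CPU.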

import argparse

import numpy as np
import torch
from torch.autograd import Variable
from molencoder.models import MolEncoder, MolDecoder
from molencoder.featurizers import OneHotFeaturizer

SOURCE = 'c1ccccn1'
DEST = 'c1ccccc1'
STEPS = 200
WIDTH = 120  # fixed SMILES length expected by MolEncoder (default i=120)

# charset from ChEMBL
charset = [' ', '#', '%', '(', ')', '+', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '@', 'A', 'B', 'C', 'F', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'S', 'T', 'V', 'X', 'Z', '[', '\\', ']', 'a', 'c', 'e', 'g', 'i', 'l', 'n', 'o', 'p', 'r', 's', 't']


def decode_smiles_from_index(vec, charset):
    # map each arg-max index back to its character and strip the space padding
    return "".join(charset[x] for x in vec).strip()


def get_arguments():
    parser = argparse.ArgumentParser(description="Interpolate from source to dest in steps")
    parser.add_argument("--source", type=str, default=SOURCE)
    parser.add_argument("--dest", type=str, default=DEST)
    parser.add_argument("--steps", type=int, default=STEPS)
    return parser.parse_args()


def interpolate(source, dest, steps, charset, encoder, decoder, width=WIDTH):
    # pad both SMILES with spaces to the fixed width and one-hot encode them
    onehot = OneHotFeaturizer(charset=charset)
    sourcevec = onehot.featurize(smiles=[source.ljust(width)])
    destvec = onehot.featurize(smiles=[dest.ljust(width)])
    source_encoded = Variable(torch.from_numpy(sourcevec).float()).cuda()
    dest_encoded = Variable(torch.from_numpy(destvec).float()).cuda()
    # encode both molecules into the latent space
    source_x_latent = encoder(source_encoded)
    dest_x_latent = encoder(dest_encoded)
    step = (dest_x_latent - source_x_latent) / float(steps)
    results = []
    # walk along the straight line between the two latent points (both ends included)
    for i in range(steps + 1):
        item = source_x_latent + (step * i)
        # decoder output is (batch, width, len(charset)); take the most likely character
        sampled = np.argmax(decoder(item).cpu().data.numpy(), axis=2)
        decode_smiles = decode_smiles_from_index(sampled[0], charset)
        results.append((i, item, decode_smiles))
    return results


def main():
    args = get_arguments()
    print(torch.cuda.is_available())
    encoder = MolEncoder(c=len(charset)).cuda()
    decoder = MolDecoder(c=len(charset)).cuda()
    bestmodel = torch.load("model_best.pth.tar")
    # bestmodel = torch.load("tempcheckpoint.pth.tar")
    # load the trained weights; assumes the checkpoint stores the state dicts
    # under the "encoder" and "decoder" keys
    encoder.load_state_dict(bestmodel["encoder"])
    decoder.load_state_dict(bestmodel["decoder"])
    encoder.eval()
    decoder.eval()

    results = interpolate(args.source, args.dest, args.steps, charset, encoder, decoder)
    for result in results:
        print(result[0], result[2])


if __name__ == "__main__":
    main()

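The interpolation in `interpolate()` is a straight line in latent space: with step = (dest - source) / steps, the i-th point is source + i * step. The same arithmetic on plain Python lists makes it easy to check that letting i run from 0 to steps inclusive lands exactly on both endpoints, whereas `range(steps)` stops one step short of the destination. The 3-dimensional latent vectors below are hypothetical values, not real encoder outputs.

```python
def lerp_path(source, dest, steps):
    """Return steps + 1 evenly spaced points from source to dest (both included)."""
    step = [(d - s) / float(steps) for s, d in zip(source, dest)]
    return [[s + i * st for s, st in zip(source, step)] for i in range(steps + 1)]


# Hypothetical 3-dimensional latent vectors.
z_source = [0.0, 1.0, -2.0]
z_dest = [1.0, 3.0, 2.0]

path = lerp_path(z_source, z_dest, steps=4)
for i, point in enumerate(path):
    print(i, point)
```

Each point would then be decoded and read back as a SMILES string. Nothing guarantees that the intermediate strings are valid molecules, so filtering them with RDKit's `Chem.MolFromSmiles` (which returns `None` for unparsable input) is a sensible check.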
Just a note for me.


Published by iwatobipen

I'm a medicinal chemist at a mid-size pharmaceutical company. I love chemoinformatics, coding, organic synthesis, and my family.

6 thoughts on “mol encoder with Pytorch”

  1. Thanks for this great work! I'm eager to try interpolation with this same PyTorch port of the chemical VAE, and was very happy to find your script.

    Were you able to use the pre-trained weights that come with the repo? For some reason I'm getting a dimension mismatch in the weights when I try to load the model. Would you mind sharing the weights you used, e.g. via a Google Drive link? Thank you!

      1. Thank you! Yes, it's:

        RuntimeError: Error(s) in loading state_dict for MolEncoder:
        size mismatch for dense_1.0.weight: copying a param with shape torch.Size([435, 290]) from checkpoint, the shape in current model is torch.Size([435, 270]).

      2. Hi, I confirmed the error. It depends on the length of the charset.
        You can see where 270 comes from in the following code.

        class MolEncoder(nn.Module):

            def __init__(self, i=120, o=292, c=35):
                super(MolEncoder, self).__init__()
                self.i = i
                self.conv_1 = ConvSELU(i, 9, kernel_size=9)
                self.conv_2 = ConvSELU(9, 9, kernel_size=9)
                self.conv_3 = ConvSELU(9, 10, kernel_size=11)
                self.dense_1 = nn.Sequential(nn.Linear((c - 29 + 3) * 10, 435),
                                             ...)  # rest of this layer omitted in the excerpt
                self.lmbd = Lambda(435, o)

        c is the length of the charset. My charset is shorter than the predefined one that was used for the pretrained model. Unfortunately, I don't know which charset was used to build that model, so I can't reuse it. If you would like to confirm, please open an issue on the original repo.
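The numbers in the error message follow from that formula. Assuming the `ConvSELU` layers are unpadded, stride-1 1D convolutions over the charset axis (which is what `c - 29 + 3` implies: 8 + 8 + 10 = 26 positions lost), each layer shrinks that axis by `kernel_size - 1`, and the last layer has 10 channels, so `dense_1` expects `(c - 26) * 10` inputs. The helper below is a hypothetical illustration, not part of molencoder:

```python
def dense_1_in_features(charset_len, kernel_sizes=(9, 9, 11), out_channels=10):
    """Input size of MolEncoder.dense_1, assuming unpadded, stride-1 Conv1d layers."""
    length = charset_len
    for k in kernel_sizes:
        length = length - k + 1  # valid convolution: L_out = L_in - k + 1
    return length * out_channels


charset = [' ', '#', '%', '(', ')', '+', '-', '.', '/', '0', '1', '2', '3', '4',
           '5', '6', '7', '8', '9', '=', '@', 'A', 'B', 'C', 'F', 'H', 'I', 'K',
           'L', 'M', 'N', 'O', 'P', 'S', 'T', 'V', 'X', 'Z', '[', '\\', ']', 'a',
           'c', 'e', 'g', 'i', 'l', 'n', 'o', 'p', 'r', 's', 't']

print(len(charset), dense_1_in_features(len(charset)))  # ChEMBL charset from the post
print(dense_1_in_features(55))                          # size stored in the checkpoint
```

So the 53-character ChEMBL charset gives 270, while the checkpoint's 290 corresponds to a 55-character charset. Matching the size would fix the loading error, but meaningful results also need the same characters in the same order, which is still unknown here.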
