Mol encoder with PyTorch

A Variational Autoencoder (VAE) is a method for learning latent representations. A VAE encodes a discrete input vector into a continuous vector in latent space. There are lots of examples on GitHub.
In 2016, Alán Aspuru-Guzik's group reported a new de novo design method using a VAE. The approach represents molecules as SMILES strings, and the SMILES strings are converted to one-hot vectors. These vectors are then used for deep learning.
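As a rough illustration of the SMILES to one-hot step, here is a minimal NumPy sketch. The helper names `one_hot_smiles` and `decode_one_hot`, the toy charset, and the width of 12 are all assumptions made for this example; they are not taken from the paper or from any library.

```python
import numpy as np

def one_hot_smiles(smiles, charset, width):
    """Pad a SMILES string with spaces to `width`, then one-hot encode each character."""
    if len(smiles) > width:
        raise ValueError("SMILES string is longer than the padded width")
    padded = smiles.ljust(width)
    index = {ch: i for i, ch in enumerate(charset)}
    onehot = np.zeros((width, len(charset)), dtype=np.float32)
    for pos, ch in enumerate(padded):
        onehot[pos, index[ch]] = 1.0
    return onehot

def decode_one_hot(onehot, charset):
    """Take the most likely character at each position and strip the padding."""
    return "".join(charset[i] for i in onehot.argmax(axis=1)).strip()

# Toy charset for illustration only; a real charset covers every symbol in the dataset.
toy_charset = [" ", "1", "c", "n"]
encoded = one_hot_smiles("c1ccccn1", toy_charset, width=12)
print(encoded.shape)
print(decode_one_hot(encoded, toy_charset))
```

Each row of the matrix is one character position, and each column is one symbol of the charset, so a decoder output of the same shape can be turned back into a string with an argmax per row.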
Maxhodak then implemented a molecular autoencoder with Keras (keras-molecules)! It is nice work, and I think it has had a strong impact on the chemoinformatics area.
The repository is below.

The code depends on Keras 1.1.1, Python 2.7, and some other packages. Keras is now at version 2.x, and most major code supports Python 3.x, so I wanted to test the mol VAE in a Python 3.x environment. I am also learning PyTorch now, so I wanted to convert the code from Keras to PyTorch.
I spent several days coding, but then I found that someone had already done it!
A researcher in Pande's group has published nice work.

The molencoder works with Python 3.5 and 3.6 in a CUDA 8.0 environment with PyTorch. The code is almost the same as keras-molecules, but there are some differences.
The main difference is the activation function: the original code uses ReLU, but molencoder uses SELU.
I have not checked what effect this difference has.
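To make the difference between the two activations concrete, here is a small NumPy sketch of both functions. It is only an illustration of the standard definitions, with the usual SELU constants; it is not code from molencoder or keras-molecules.

```python
import numpy as np

# Standard SELU constants from the self-normalizing networks formulation.
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

def relu(x):
    """ReLU: zero for negative inputs, identity for positive inputs."""
    return np.maximum(x, 0.0)

def selu(x):
    """SELU: scaled identity for positive inputs, scaled exponential curve for negative ones."""
    x = np.asarray(x, dtype=np.float64)
    # expm1 is only evaluated on the non-positive part to avoid overflow warnings.
    negative_part = SELU_ALPHA * np.expm1(np.minimum(x, 0.0))
    return SELU_SCALE * np.where(x > 0, x, negative_part)

xs = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(relu(xs))
print(selu(xs))
```

The visible difference is on the negative side: ReLU clips everything to zero, while SELU keeps a bounded negative output.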
I wrote an interpolation script using molencoder.
The code is below.

Description of the arguments:
SOURCE: the starting point of the interpolation.
DEST: the end point of the interpolation.
STEPS: how many steps are taken between source and dest; the model generates one SMILES string from latent space at each step.
* The following code runs only on a GPU machine because I hard-coded “someprop.cuda()” instead of checking with an if statement.
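For reference, one common way to avoid the hard-coded `.cuda()` call is to pick a device once and move tensors to it. This is only a sketch, assuming PyTorch 0.4 or later (where `torch.device` and `.to()` are available); the helper name `to_device` is made up for this example and is not part of molencoder.

```python
import numpy as np
import torch

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_device(array):
    """Convert a NumPy array to a float tensor on the selected device (hypothetical helper)."""
    return torch.from_numpy(array).float().to(device)

example = to_device(np.zeros((1, 3), dtype=np.float32))
print(example.device)
```

Models can be moved the same way with `model.to(device)`, so the same script would run on machines with or without a GPU.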

import numpy as np
import torch
import argparse
from torch.autograd import Variable
from rdkit import Chem
from molencoder.models import MolEncoder, MolDecoder
from molencoder.utils import (load_dataset, initialize_weights, ReduceLROnPlateau, save_checkpoint, validate_model)
from molencoder.featurizers import Featurizer, OneHotFeaturizer

SOURCE = 'c1ccccn1'
DEST = 'c1ccccc1'
STEPS = 200
# charset from ChEMBL

charset = [' ', '#', '%', '(', ')', '+', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '@', 'A', 'B', 'C', 'F', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'S', 'T', 'V', 'X', 'Z', '[', '\\', ']', 'a', 'c', 'e', 'g', 'i', 'l', 'n', 'o', 'p', 'r', 's', 't']

def decode_smiles_from_index(vec, charset):
    # Map each character index back to its symbol and drop the space padding.
    return "".join(charset[x] for x in vec).strip()

def get_arguments():
    parser = argparse.ArgumentParser(description="Interpolate from source to dest in steps")
    parser.add_argument("--source", type=str, default=SOURCE)
    parser.add_argument("--dest", type=str, default=DEST)
    parser.add_argument("--steps", type=int, default=STEPS)
    return parser.parse_args()

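def interpolate(source, dest, steps, charset, encoder, decoder, width=120):
    # width is the padded SMILES length; 120 is assumed here (the length used by keras-molecules).
    source_just = source.ljust(width)
    dest_just = dest.ljust(width)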
    onehot = OneHotFeaturizer(charset=charset)
    sourcevec = onehot.featurize(smiles=[source_just])
    destvec = onehot.featurize(smiles=[dest_just])
    source_encoded = Variable(torch.from_numpy(sourcevec).float()).cuda()
    dest_encoded = Variable(torch.from_numpy(destvec).float()).cuda()
    source_x_latent = encoder(source_encoded)
    dest_x_latent = encoder(dest_encoded)
    step = (dest_x_latent-source_x_latent)/float(steps)
    results = []
    for i in range(steps):
        item = source_x_latent + (step*i)
        sampled = np.argmax(decoder(item).cpu().data.numpy(), axis=2)
        decode_smiles = decode_smiles_from_index(sampled[0], charset)
        results.append((i, item, decode_smiles))
    return results

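def main():
    args = get_arguments()
    print(torch.cuda.is_available())
    encoder = MolEncoder(c=len(charset)).cuda()
    decoder = MolDecoder(c=len(charset)).cuda()
    bestmodel = torch.load("model_best.pth.tar")
    #bestmodel = torch.load("tempcheckpoint.pth.tar")
    # Restore the trained weights. This assumes the checkpoint stores the
    # state dicts under the keys "encoder" and "decoder".
    encoder.load_state_dict(bestmodel["encoder"])
    decoder.load_state_dict(bestmodel["decoder"])
    # Switch to evaluation mode for inference.
    encoder.eval()
    decoder.eval()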

    results = interpolate( args.source, args.dest, args.steps, charset, encoder, decoder )
    for result in results:
        print(result[0], result[2])

if __name__ == "__main__":
    main()
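
The interpolation loop in the script above can be tried without a trained model. The sketch below reproduces only the arithmetic with NumPy vectors standing in for the encoder output; the function name `linear_path` and the 3-dimensional toy vectors are assumptions for illustration.

```python
import numpy as np

def linear_path(source_z, dest_z, steps):
    """Return `steps` points starting at source_z and moving toward dest_z in equal increments."""
    step = (dest_z - source_z) / float(steps)
    # Like the script, this uses range(steps), so the last point stops one step short of dest_z.
    return [source_z + step * i for i in range(steps)]

# Toy latent vectors; a real encoder would produce much higher-dimensional ones.
source_z = np.zeros(3)
dest_z = np.ones(3)
path = linear_path(source_z, dest_z, steps=4)
for i, point in enumerate(path):
    print(i, point)
```

In the real script each of these points is passed to the decoder, and the argmax over the charset axis gives one SMILES string per step.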

Just a note to myself.

