Assign7: A Smiling Variational Autoencoder
Checkpoint: Monday, April 1, 11:59pm (extended from Friday, March 29)
Due: Tuesday, April 9, 11:59pm (extended from Monday, April 8)
For this assignment you will use PyTorch to design and train a Variational Autoencoder (VAE) that generates SMILES strings. This will give you experience training both generative and recurrent models.
Data
The source data comes from PubChem. We have filtered all the SMILES strings in PubChem to include only those that use a reduced alphabet of characters (e.g., no "weird" elements like selenium; molecules are Kekulized, so there are no aromatic characters) and are less than 150 characters long. Filtering PubChem with these criteria leaves 62,129,800 valid SMILES strings. We provide these SMILES both in canonical form (here) and in (somewhat) randomized form (here). The code used to randomize SMILES can be found here, in case you wish to further augment this data.
Training
We will be using PyTorch to define and train our VAE. Your model will be evaluated using PyTorch version 2.0.0.
To get you started, we are providing you with a skeleton VAE model with a working Dataset class:
train_VAE.py.
You will want to rewrite the Dataset class to load preprocessed data (e.g., a saved NumPy array) for faster initialization. You may use whatever architecture you think best for the encoder and decoder (e.g., CNN, GRU, or Transformer; don't forget the positional encoding), but your decoder must take a 1024-dimensional vector sampled from a unit Gaussian and generate a SMILES string with a maximum length of 150 characters.
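One way to do this preprocessing is to encode every SMILES string once as a row of integer token ids, save the array with np.save, and memory-map it at startup. The sketch below assumes a character-level vocabulary built from the data, hypothetical special-token ids PAD=0, START=1, END=2, and uint8 storage (enough for up to 256 token ids). The names build_vocab, encode_smiles, and NpySmilesDataset are illustrative and not part of train_VAE.py; the dataset class only implements __len__ and __getitem__, which is what torch.utils.data.DataLoader needs from a map-style dataset.

```python
import numpy as np

PAD, START, END = 0, 1, 2  # hypothetical special-token ids
MAX_LEN = 150              # maximum SMILES length from the assignment spec


def build_vocab(smiles_list):
    """Assign each character seen in the data an id, after the three special tokens."""
    chars = sorted({c for s in smiles_list for c in s})
    return {c: i + 3 for i, c in enumerate(chars)}


def encode_smiles(smiles_list, vocab, max_len=MAX_LEN):
    """Return a uint8 array whose rows are START, char ids..., END, then PAD."""
    # Drop strings that are too long (a small max_len also yields a quick subset).
    kept = [s for s in smiles_list if len(s) <= max_len]
    arr = np.full((len(kept), max_len + 2), PAD, dtype=np.uint8)
    for row, s in enumerate(kept):
        ids = [START] + [vocab[c] for c in s] + [END]
        arr[row, : len(ids)] = ids
    return arr


class NpySmilesDataset:
    """Map-style dataset (__len__/__getitem__) over a saved .npy token array.

    An object with these two methods can be passed to torch.utils.data.DataLoader.
    """

    def __init__(self, path):
        # mmap_mode="r" maps the file instead of reading it all into memory at startup.
        self.data = np.load(path, mmap_mode="r")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Copy one row out of the memory map as int64 token ids for an embedding layer.
        return np.asarray(self.data[idx], dtype=np.int64)
```

For example, run np.save("train.npy", encode_smiles(smiles, vocab)) once (saving the vocabulary alongside it, e.g. as JSON), then construct NpySmilesDataset("train.npy") in each training run. At 152 bytes per row the full 62,129,800-string set is roughly 9.4 GB, so for that file you would probably encode in chunks into an array created with np.lib.format.open_memmap rather than building one in-memory list.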
Here are some VAE papers that you can use to model your own VAE after: Tutorial on Variational Autoencoders, A Neural Representation of Sketch Drawings, PixelVAE: A Latent Variable Model For Natural Images, and Generating Sentences from a Continuous Space.
Here are repositories/PyTorch tutorials that may be helpful in constructing your own VAE: PyTorch Seq2Seq Tutorial, PyTorch Sentence-VAE, PyTorch Sketch-RNN.
Cluster Etiquette
Please read this document, paying attention to the "Cluster Etiquette" and "Consideration of Others" sections. You may (and should) use both the CSB and CRC clusters to complete this assignment.
Tips
- You will want to modify the provided code to preprocess the training set for faster startup times (e.g., save it as an .npy file).
- You will want to sample output characters in your decoder using torch.multinomial on softmaxed probabilities, rather than simply taking the highest-probability character. This will generate warnings when writing out the trace, but they can be safely ignored.
- For faster initial experimentation, you may want to create a smaller training set (e.g., all strings shorter than 50 characters).
- As part of your initial exploration, you might consider training non-variational autoencoders: they are easier to train, and if a model doesn't have the capacity to function as an autoencoder, it won't perform well as a VAE.
- Check your GPU utilization; it should be above 50%. If it isn't, are you parallelizing your data loading (e.g., with multiple DataLoader workers)? Is your batch size large enough to keep the GPU busy?
- The full dataset is large, so you will want to monitor the performance of your model more often than once per epoch.
- Watch for posterior collapse, where the KL divergence term of the loss overwhelms the reconstruction loss and the decoder learns to ignore the latent vector. This can be addressed by weighting these two terms of the loss appropriately, for example with a KL cost annealing strategy.
- You are definitely going to need to perform hyperparameter sweeps as well as evaluate different kinds of encoder/decoder architectures.
- If you run into incomprehensible CUDA errors, try running on the CPU instead, as this will likely produce more informative error messages.
- Make sure you understand the difference between reshaping a tensor and transposing dimensions.
- For every PyTorch operation, double-check the documentation and make sure your inputs and outputs match what it requires: not just the shape, but the contents.
- A good sanity check is to make sure the same input gives the same output (if the same random seed is used) for different batch sizes. Information should not mix between examples within a batch.
- If you have a low loss, your decoder should reproduce something very similar to the input. Test this! If you have different decoding code paths for training and evaluation (not recommended), make sure the evaluation path produces the same output as the low-loss training path.
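A minimal sketch of the loss terms behind the posterior-collapse tip, assuming the encoder outputs a mean mu and a log-variance logvar per latent dimension (a common parameterization, not necessarily what the skeleton uses). gaussian_kl is the closed form KL(N(mu, sigma^2) || N(0, I)) = -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2), and kl_weight is one simple linear annealing schedule; warmup_steps=10_000 is an arbitrary placeholder to tune in your sweeps. It is written with NumPy for clarity; the same expressions carry over to PyTorch tensors (torch.exp, .sum(dim=-1)).

```python
import numpy as np


def gaussian_kl(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dims, one value per row."""
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=-1)


def kl_weight(step, warmup_steps=10_000, max_weight=1.0):
    """Linear KL cost annealing: the weight ramps from 0 to max_weight over warmup_steps."""
    return max_weight * min(1.0, step / warmup_steps)


def vae_loss(recon_loss, mu, logvar, step):
    """Reconstruction loss plus the annealed KL term, averaged over the batch."""
    return recon_loss + kl_weight(step) * float(np.mean(gaussian_kl(mu, logvar)))
```

Because the KL here is summed over the latent dimensions, the usual ELBO scaling sums the reconstruction cross-entropy over characters as well; averaging it instead effectively up-weights the KL term. Logging the reconstruction loss and the unweighted KL separately makes collapse easy to spot: a KL that falls toward zero while reconstruction stalls suggests the decoder is ignoring z.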
Evaluation
Once you have a trained model you are happy with, submit a link to it to the evaluation server (below). The link must be a world-readable Google Storage location (gs://). You can copy your model to a storage bucket and then make it world readable with gsutil:
    gsutil cp -r model.pth gs://bucket-name/
    gsutil acl ch -r -u AllUsers:R gs://bucket-name/
World-readable http:// URLs should also work (they will be fetched using wget). The model will be evaluated with 1000 random vectors sampled from the latent space using the procedure in eval.py. Note that we sample from a normal distribution with zero mean and unit variance. Make sure you test your model using the same procedure as in this file. Your model will be evaluated with PyTorch 2.0.0 on an NVIDIA TITAN RTX GPU (24 GB).
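eval.py is the authoritative procedure; the sketch below only mirrors its description here (1000 latent vectors of size 1024 drawn from a zero-mean, unit-variance normal) together with the multinomial-sampling tip, so you can exercise your own decoder in a similar way. step_fn is a hypothetical stand-in for one decoder step that returns next-character logits, and start_id/end_id are assumed token ids. Sampling uses NumPy's Generator.choice; in PyTorch the equivalent draw is torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).

```python
import numpy as np

LATENT_DIM = 1024  # latent size required by the assignment
MAX_LEN = 150      # maximum SMILES length


def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtract max for stability
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def sample_next(logits, rng):
    """Draw one token id per row from softmax(logits) instead of taking the argmax."""
    return np.array([rng.choice(len(p), p=p) for p in softmax(logits)])


def generate(step_fn, id_to_char, n_samples, rng, start_id=1, end_id=2, max_len=MAX_LEN):
    """Decode z ~ N(0, I) latent vectors into strings, one sampled character at a time.

    step_fn(z, prev_ids, state) -> (logits, state) is a hypothetical single decoder step.
    """
    z = rng.standard_normal((n_samples, LATENT_DIM))  # zero mean, unit variance
    prev = np.full(n_samples, start_id)
    state = None
    done = np.zeros(n_samples, dtype=bool)
    chars = [[] for _ in range(n_samples)]
    for _ in range(max_len):
        logits, state = step_fn(z, prev, state)
        prev = sample_next(logits, rng)
        for i, tok in enumerate(prev):
            if done[i]:
                continue
            if tok == end_id:
                done[i] = True
            else:
                # Ids without a character (e.g. PAD or START) contribute nothing.
                chars[i].append(id_to_char.get(int(tok), ""))
        if done.all():
            break
    return ["".join(c) for c in chars]
```

Calling generate(your_step, id_to_char, 1000, np.random.default_rng(0)) gives a batch comparable in size to the evaluation; fixing the seed also makes the batch-size sanity check from the tips easy to run.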
We will compute the following metrics (all out of 1000 samples):
Metric | Description |
UniqueSMI | The number of unique strings generated |
ValidSMI | The number of (not necessarily unique) valid SMILES strings generated, i.e., strings RDKit can parse |
AveRings | The average number of rings per unique valid molecule (a simple measure of chemical complexity) |
UniqueValidMols | The number of valid, unique molecules (distinct SMILES strings that represent the same molecule are counted only once) |
NovelMols | The number of valid, unique molecules that are not in the training set |
ZeroSMI | The molecule generated from the zero latent vector |
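To track these metrics locally between submissions, a sketch like the following approximates them with RDKit (Chem.MolFromSmiles returns None for strings it cannot parse). It is an approximation, not the server's code: it assumes training_set contains SMILES canonicalized by the same canonicalize function (with 62 million strings that set is large, so a subset or an on-disk lookup may be more practical), and it leaves out ZeroSMI. canonicalize and count_rings are parameters so the function can be tried without RDKit installed.

```python
def rdkit_canonicalize(smi):
    """Canonical SMILES via RDKit, or None if RDKit cannot parse the string."""
    from rdkit import Chem  # assumes RDKit is installed

    mol = Chem.MolFromSmiles(smi)
    return None if mol is None else Chem.MolToSmiles(mol)


def rdkit_num_rings(smi):
    """Number of rings RDKit finds in a valid SMILES string."""
    from rdkit import Chem

    return Chem.MolFromSmiles(smi).GetRingInfo().NumRings()


def compute_metrics(samples, training_set, canonicalize=rdkit_canonicalize, count_rings=rdkit_num_rings):
    """Approximate UniqueSMI, ValidSMI, AveRings, UniqueValidMols, and NovelMols."""
    canon = [canonicalize(s) for s in samples]
    # Canonical forms collapse distinct SMILES strings for the same molecule.
    unique_valid = {c for c in canon if c is not None}
    ave_rings = sum(count_rings(c) for c in unique_valid) / len(unique_valid) if unique_valid else 0.0
    return {
        "UniqueSMI": len(set(samples)),
        "ValidSMI": sum(c is not None for c in canon),
        "AveRings": ave_rings,
        "UniqueValidMols": len(unique_valid),
        "NovelMols": sum(c not in training_set for c in unique_valid),
    }
```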
Checkpoint
You should have submitted a model that generates at least one valid molecule to the leaderboard by the checkpoint date; otherwise you will receive a 10-point penalty.
Grading
Your decoder must generate non-trivial molecular structures (AveRings > 0). The key metric for grading is NovelMols. Your grade will be NovelMols as a percentage of the 1000 samples, plus 15, with no cap at 100. For example, if you generate 912 NovelMols, your grade will be 91.2 + 15 = 106.2.
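The grading arithmetic can be written out as a small function. This is an illustrative reading of the rules above, not the course's grading script: it assumes the 10-point checkpoint penalty is subtracted from the final grade, and it does not model the AveRings > 0 requirement.

```python
def assignment_grade(novel_mols, missed_checkpoint=False, n_samples=1000):
    """NovelMols as a percentage of n_samples, plus 15, uncapped; minus 10 if the checkpoint was missed."""
    grade = 100.0 * novel_mols / n_samples + 15.0
    if missed_checkpoint:
        grade -= 10.0  # assumed to come off the final grade
    return grade
```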