minimal-text-diffusion

Steps to train a model on the simple greetings dataset.

This dataset is deliberately small to allow faster training, and the test data is simply a fraction of the training data. The dataset was generated using few-shot prompting.
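For context, few-shot prompting here simply means showing a large language model a handful of example greetings and asking it to produce more. An illustrative prompt (not the exact one used to build the dataset) might look like:

Write a short greeting.
Examples:
hi, nice to see you!
hello everyone, glad you could make it.
hey there, how are you?
New greeting: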

Step 1: Tokenization

To prepare the data, we need to (i) convert the text into a sequence of tokens, i.e., integers or IDs (tokenization), and then (ii) map each token to a continuous vector (an embedding).

Tokenization is an important design choice when training a language generation model. I found word-level tokenization to be the most effective. Still, the implementation in src/utils/custom_tokenizer.py also includes BPE if you want to experiment (intuitively, BPE splits words into more tokens and thereby increases the dimensionality of the problem, which might be hurting performance).

Since we are creating the vocabulary from scratch, the embeddings for each token are randomly initialized and then learned during training.

python src/utils/custom_tokenizer.py train-word-level data/greetings/greetings.txt
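For intuition, here is a minimal sketch of the two steps in PyTorch. The vocabulary, special tokens, and embedding width below are illustrative; this is not the code in src/utils/custom_tokenizer.py.

import torch
import torch.nn as nn

# (i) Tokenization: build a word-level vocabulary and map text to integer IDs.
corpus = ["hi nice to see you", "i'm glad to see everyone"]
vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2}
for line in corpus:
    for word in line.split():
        vocab.setdefault(word, len(vocab))

def encode(text):
    return [vocab["[CLS]"]] + [vocab[w] for w in text.split()] + [vocab["[SEP]"]]

ids = torch.tensor(encode("hi nice to see you"))

# (ii) Embedding: each ID indexes a row of a randomly initialized matrix
# that is then updated by gradient descent during training.
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=128)
vectors = embedding(ids)  # shape: (sequence_length, 128)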

Step 2: Training

bash scripts/run_train.sh greetings 1 False False False 5000

Here, greetings is the dataset name, 1 selects the GPU, the three boolean flags control the use of pre-trained embeddings and a pre-trained model and whether the embeddings are frozen (see scripts/run_train.sh for the exact order), and 5000 is the number of training steps.

Some boolean options may appear redundant, but they allow for interesting ablations (e.g., using pre-trained embeddings but not a pre-trained model, or freezing the pre-trained embeddings), as sketched below.
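For example, the "freeze pre-trained embeddings" ablation boils down to a one-liner in PyTorch. A generic sketch with illustrative shapes, not the repo's exact code:

import torch
import torch.nn as nn

vocab_size, dim = 1000, 128
pretrained = torch.randn(vocab_size, dim)  # stand-in for real pre-trained vectors

embedding = nn.Embedding(vocab_size, dim)
embedding.weight.data.copy_(pretrained)    # initialize from pre-trained vectors
embedding.weight.requires_grad = False     # freeze: excluded from gradient updates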

Step 3: Evaluation

CUDA_VISIBLE_DEVICES=9 bash scripts/text_sample.sh ckpts/greetings/ema_0.9999_005000.pt 2000 50

Here, 2000 is the number of diffusion steps and 50 is the number of samples, as the output filename reflects. Let’s look at some random samples from the model (the command below extracts the text between the bracketed special tokens and trims leading whitespace):

shuf ckpts/greetings/ema_0.9999_005000.pt.samples_50.steps-2000.clamp-no_clamp.txt | head -n 10 | cut -f 2 -d '[' | cut -f 2 -d ']' | sed 's/^\s*//g'

i's that to see 
i's nice are see you right 
i'm let everyone you 
i's stopped a you you! 
i's glad hi see you right 
i's glad to you 
i's that to you you! 
i's glad to you you right 
i've chance to you you! 
i's greetings a you 

Controllable generation

Step 1: Train a classifier on the latents

python -u src/controllable/classifier.py --model_name_or_path ckpts/greetings/ema_0.9999_005000.pt --classifier_num_epochs 50 
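Conceptually, the classifier is a small network that predicts a control label from the diffusion latents. A minimal sketch (the latent width, label count, and architecture are assumptions, not the repo's exact classifier):

import torch
import torch.nn as nn

num_labels, latent_dim = 2, 128
classifier = nn.Sequential(
    nn.Linear(latent_dim, 256),
    nn.ReLU(),
    nn.Linear(256, num_labels),
)

latents = torch.randn(8, latent_dim)         # a batch of latents from the diffusion model
labels = torch.randint(0, num_labels, (8,))  # control labels for each example
loss = nn.functional.cross_entropy(classifier(latents), labels)
loss.backward()                              # standard supervised training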

Step 2: Run generation!

bash scripts/ctrl_text_sample.sh ckpts/greetings/ema_0.9999_005000.pt 300 50
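Under the hood, this style of controllable generation typically relies on classifier guidance: at each reverse-diffusion step, the latent is nudged along the gradient of the classifier's log-probability for the target label. A generic sketch of one guidance step, not the repo's exact code:

import torch

def guided_step(x_t, classifier, target_label, step_size=1.0):
    # Nudge the current latent x_t toward regions where the classifier
    # assigns high probability to the target label.
    x = x_t.detach().requires_grad_(True)
    log_probs = torch.log_softmax(classifier(x), dim=-1)
    score = log_probs[:, target_label].sum()
    grad, = torch.autograd.grad(score, x)  # d log p(y | x_t) / d x_t
    return x_t + step_size * grad          # gradient ascent on log p(y | x_t)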

TODO