license: apache-2.0

Model Card for World Model with MCTS and Transformer Components

Model Overview

This model is a World Model that combines Transformers, Mixture of Experts (MoE) layers, Monte Carlo Tree Search (MCTS), and Proximal Policy Optimization (PPO) to simulate and optimize a state-based environment. It is designed for tasks involving decision-making and action prediction, and uses these components to encode inputs, predict state transitions, and refine action sequences.

Key Components

  1. Transformer: The model uses a custom Transformer with rotary positional encoding and Mixture of Experts (MoE) layers. It serves as both an encoder and decoder, enabling sequential processing of input and target data.
  2. MCTS: The Monte Carlo Tree Search module iteratively simulates actions and selects the most promising path by balancing exploration and exploitation.
  3. PPO Agent: A Proximal Policy Optimization agent is employed to update the policy and value functions. PPO loss is combined with other regularization losses to improve model performance.
  4. Custom Losses: Several custom loss functions are implemented to help guide the model’s learning, including Covariance Regularization, Dynamics Performance Loss, Thought Consistency Loss, and more.

Intended Use

This model is suitable for tasks that require complex decision-making and optimization based on action-state transitions. It can be applied in fields like game development, reinforcement learning environments, and AI simulation tasks where sequential decision-making and policy optimization are essential.

Model Architecture

The model is constructed with several primary components:

  1. Transformer: The transformer has encoder and decoder layers with rotary positional encoding and Mixture of Experts (MoE) layers. MoE aims to improve generalization and reduce computational cost by routing each token to only a subset of the experts. The experts alternate between GELU and SwiGLU activation functions.
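The routing idea can be illustrated with a small pure-Python sketch. It assumes a linear gate whose top-k scores (here k = 2) are renormalized with a softmax. It alternates expert types by index, with even experts using GELU and odd experts using SwiGLU. `MoELayer`, `GeluExpert` and `SwiGLUExpert` are hypothetical names, and the weights are random, so this shows the mechanism only and does not reproduce the model's own implementation.

```python
import math
import random
from typing import Dict, List, Tuple

Vector = List[float]

def matvec(w: List[Vector], x: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def gelu(v: float) -> float:
    # Exact GELU: v * Phi(v), with Phi the standard normal CDF.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))

def silu(v: float) -> float:
    # SiLU (Swish): v * sigmoid(v); sigmoid written via tanh for numerical stability.
    return v * 0.5 * (1.0 + math.tanh(v / 2.0))

def rand_matrix(rng: random.Random, rows: int, cols: int) -> List[Vector]:
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

class GeluExpert:
    """Feed-forward expert: W2 @ GELU(W1 @ x)."""
    def __init__(self, rng: random.Random, d: int, h: int):
        self.w1, self.w2 = rand_matrix(rng, h, d), rand_matrix(rng, d, h)

    def __call__(self, x: Vector) -> Vector:
        return matvec(self.w2, [gelu(v) for v in matvec(self.w1, x)])

class SwiGLUExpert:
    """Gated expert: W2 @ (SiLU(W1 @ x) * (V @ x))."""
    def __init__(self, rng: random.Random, d: int, h: int):
        self.w1, self.v = rand_matrix(rng, h, d), rand_matrix(rng, h, d)
        self.w2 = rand_matrix(rng, d, h)

    def __call__(self, x: Vector) -> Vector:
        gated = [silu(a) * b for a, b in zip(matvec(self.w1, x), matvec(self.v, x))]
        return matvec(self.w2, gated)

class MoELayer:
    """Hypothetical top-k MoE layer with random weights (illustration only)."""
    def __init__(self, d: int, h: int, num_experts: int = 4, top_k: int = 2, seed: int = 0):
        rng = random.Random(seed)
        # Assumption: even-indexed experts use GELU, odd-indexed experts use SwiGLU.
        self.experts = [GeluExpert(rng, d, h) if i % 2 == 0 else SwiGLUExpert(rng, d, h)
                        for i in range(num_experts)]
        self.gate = rand_matrix(rng, num_experts, d)  # one gating score per expert
        self.top_k = top_k

    def __call__(self, x: Vector) -> Tuple[Vector, Dict[int, float]]:
        logits = matvec(self.gate, x)
        chosen = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[: self.top_k]
        # Softmax over the selected experts only, so their weights sum to 1.
        m = max(logits[i] for i in chosen)
        exps = {i: math.exp(logits[i] - m) for i in chosen}
        total = sum(exps.values())
        weights = {i: e / total for i, e in exps.items()}
        # Only the chosen experts are evaluated; the others are skipped entirely.
        out = [0.0] * len(x)
        for i, w in weights.items():
            out = [o + w * y for o, y in zip(out, self.experts[i](x))]
        return out, weights

layer = MoELayer(d=4, h=8)
output, routing = layer([0.1, -0.2, 0.3, 0.4])
print(routing)  # expert index -> routing weight for this token
```

The saving comes from skipping experts. A token pays for the gate plus k experts, not for every expert in the layer.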

Multi-Token Prediction with Beam Search

Multi-token prediction in a language model means predicting several tokens ahead rather than committing to a single next token at each step. This can improve the fluency and coherence of generated text, because the model can "look ahead" and weigh several possible continuations before choosing one.

Beam search is a widely used decoding algorithm for this purpose. It keeps several candidate sequences in parallel and returns the one with the highest overall (log-)probability among those it explored. It works as follows:

  1. Initialization:

    • Start with a single "beam" (sequence) that contains the initial token, typically the beginning-of-sequence (<sos>) token.
  2. Expansion:

    • At each time step, the model generates a probability distribution over the vocabulary for each sequence in the beam.
    • For each sequence, it expands by predicting the next possible tokens, creating new sequences for each possible token.
  3. Scoring:

    • Calculate the score for each expanded sequence by taking the sum (or average) of log probabilities for all tokens in the sequence. Log probabilities are used to avoid underflow and ensure stable computation.
  4. Selection:

    • Keep only the k highest-scoring sequences, where k is the beam width (also called the beam size), and discard the rest. This bounds the number of sequences kept at each step and focuses the search on the most promising ones.
  5. Repeat:

    • Continue expanding and scoring until reaching the desired sequence length or the end-of-sequence (<eos>) token.
  6. Final Output:

    • After a set number of steps, or if all sequences end with <eos>, select the sequence with the highest score as the final output.

By considering several continuations at each step and selecting the best complete sequence, beam search often finds higher-probability outputs than greedy decoding, which commits to the single most likely token at every step.
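The steps above can be sketched in plain Python. `beam_search` and the toy next-token table are hypothetical. Each sequence is scored by its summed log probability, with no length normalization. A sequence that emits `<eos>` is moved out of the active beam.

```python
import math
from typing import Callable, Dict, List, Tuple

def beam_search(
    next_log_probs: Callable[[List[str]], Dict[str, float]],
    beam_width: int,
    max_len: int,
    sos: str = "<sos>",
    eos: str = "<eos>",
) -> Tuple[List[str], float]:
    """Return the highest-scoring sequence found and its summed log probability."""
    beams: List[Tuple[List[str], float]] = [([sos], 0.0)]  # 1. initialization
    finished: List[Tuple[List[str], float]] = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            # 2-3. Expansion and scoring: one candidate per next token, scored by
            # adding its log probability to the running sum.
            for tok, logp in next_log_probs(tokens).items():
                candidates.append((tokens + [tok], score + logp))
        # 4. Selection: keep the beam_width best candidates (Python's sort is stable).
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            (finished if tokens[-1] == eos else beams).append((tokens, score))
        if not beams:  # 5. stop once every kept sequence has produced <eos>
            break
    finished.extend(beams)  # sequences that reached max_len without <eos>
    return max(finished, key=lambda c: c[1])  # 6. final output

# Hypothetical toy "model": next-token probabilities keyed by the prefix so far.
TABLE = {
    ("<sos>",): {"a": 0.6, "b": 0.4},
    ("<sos>", "a"): {"x": 0.5, "y": 0.5},
    ("<sos>", "b"): {"x": 0.9, "y": 0.1},
}

def toy_model(tokens: List[str]) -> Dict[str, float]:
    probs = TABLE.get(tuple(tokens), {"<eos>": 1.0})  # unknown prefixes end the sequence
    return {tok: math.log(p) for tok, p in probs.items()}

print(beam_search(toy_model, beam_width=2, max_len=5))
print(beam_search(toy_model, beam_width=1, max_len=5))
```

In this toy table, greedy decoding (`beam_width=1`) commits to `a` because it has the highest first-step probability, and it ends with probability 0.3. A beam width of 2 keeps `b` alive and finds `b x`, whose probability is 0.36.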


Brief Overview of the Transformer Architecture

The Transformer architecture, introduced in the paper "Attention Is All You Need" (Vaswani et al., 2017), is a neural network design for sequential data, especially in natural language processing. Transformers process all positions of a sequence in parallel during training and use attention to capture long-range dependencies.

Key Components of the Transformer

  1. Embeddings and Positional Encoding:

    • The input tokens are embedded into dense vectors. Unlike RNNs, Transformers do not inherently encode sequence order, so they require positional encodings. In the original Transformer these encodings are added to the embeddings to indicate each token's position. The model described in this card instead uses rotary positional encoding, which rotates the query and key vectors by position-dependent angles.
  2. Multi-Head Self-Attention:

    • Each token in a sequence attends to every other token, capturing dependencies regardless of distance. Multiple attention heads allow the model to focus on different parts of the sequence, extracting varied features.
    • In self-attention, the model computes query, key, and value vectors for each token. The output is a weighted sum of the value vectors, where the weights are a softmax over the scaled dot products between the query and key vectors.
  3. Feedforward Neural Networks:

    • After self-attention, a position-wise feedforward neural network is applied to each token independently. This network consists of two linear layers with a ReLU or GELU activation function in between.
  4. Layer Normalization and Residual Connections:

    • To improve learning stability, layer normalization is applied. Residual connections help the model to learn effectively by adding the input of a layer to its output, allowing gradients to flow more easily during backpropagation.
  5. Stacking of Layers:

    • The Transformer consists of multiple encoder and decoder layers. Each encoder layer is identical and consists of self-attention and feedforward layers. The decoder layers include an additional cross-attention mechanism to attend to the encoder's output.
  6. Final Linear and Softmax Layer:

    • The final output of the decoder layer is passed through a linear layer, projecting it onto the vocabulary size. A softmax function then converts the output into a probability distribution over the vocabulary, from which the next token is selected or sampled.
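Below is a minimal pure-Python sketch of single-head scaled dot-product self-attention, with rotary positional encoding applied to the queries and keys. It assumes the usual rotary frequency base of 10000 and rotates consecutive pairs of dimensions; some implementations rotate the two halves of the vector instead. Learned Q/K/V projections, multiple heads and batching are omitted, and all function names are illustrative.

```python
import math
from typing import List, Tuple

Vector = List[float]

def rotary(x: Vector, pos: int, base: float = 10000.0) -> Vector:
    """Rotate each pair (x[2j], x[2j+1]) by the angle pos * base**(-2j/d)."""
    d = len(x)  # assumed even
    out: Vector = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out

def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))

def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(q: List[Vector], k: List[Vector], v: List[Vector],
                   causal: bool = False) -> Tuple[List[Vector], List[Vector]]:
    """Single-head scaled dot-product attention; q, k, v hold one vector per position."""
    d = len(q[0])
    qr = [rotary(vec, pos) for pos, vec in enumerate(q)]
    kr = [rotary(vec, pos) for pos, vec in enumerate(k)]
    outputs, all_weights = [], []
    for i, qi in enumerate(qr):
        n = i + 1 if causal else len(kr)  # causal masking: attend to positions <= i
        weights = softmax([dot(qi, kr[j]) / math.sqrt(d) for j in range(n)])
        # Output is the attention-weighted sum of the value vectors.
        outputs.append([sum(w * v[j][c] for j, w in enumerate(weights))
                        for c in range(len(v[0]))])
        all_weights.append(weights)
    return outputs, all_weights

q = [1.0, 0.5, -0.3, 0.8]
k = [0.2, -1.0, 0.7, 0.4]
# Shifting both positions by 3 leaves the rotary query-key score unchanged.
print(dot(rotary(q, 5), rotary(k, 2)), dot(rotary(q, 8), rotary(k, 5)))
```

Each pair is rotated by an angle proportional to its position. As a result, the query-key dot product depends only on the offset between the two positions, so shifting both by the same amount leaves the score unchanged.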

Encoder-Decoder Structure

  • Encoder: The encoder processes the input sequence into a contextualized representation that captures relationships between tokens. It consists of multiple layers of self-attention and feedforward networks.
  • Decoder: The decoder generates the output sequence by attending to both the encoded input representation (using cross-attention) and previously generated tokens (using self-attention). The decoder's output is used to predict the next token in the sequence.

Besides the Transformer, the World Model contains the following components:

  2. Representation Network: This module encodes the Transformer output into a lower-dimensional state representation suitable for further processing.
  3. Dynamics Network: This module predicts the next state given the current state and an action. It uses layer normalization and a GELU activation function.
  4. Prediction Network: This module predicts policy logits and a value estimate for a given state, that is, probabilities over the possible actions and a single scalar value.
  5. MCTS: This module performs Monte Carlo Tree Search to evaluate the quality of actions over multiple iterations. It expands nodes using the policy logits from the Prediction Network and backpropagates the resulting value estimates up the tree.
  6. PPO Agent: This module uses the policy and value estimates to compute the PPO loss. The loss updates the policy while keeping the new policy close to the old one, and the training objective includes a KL-divergence term for this purpose.
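The selection, expansion, evaluation and backpropagation loop can be sketched with plain UCB1 selection, using c = 1.414 to mirror the --mcts_exploration_constant default. This is a toy built on several assumptions. `value_fn` stands in for the Prediction Network's value head, and the three-step binary environment is invented for illustration. The policy-logit priors that the card describes for expansion are not used; a PUCT-style rule would add them to the selection score.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

State = Tuple[int, ...]

@dataclass
class Node:
    state: State
    parent: Optional["Node"] = None
    children: Dict[int, "Node"] = field(default_factory=dict)
    visits: int = 0
    value_sum: float = 0.0

    def q(self) -> float:
        return self.value_sum / self.visits if self.visits else 0.0

def ucb1(child: Node, parent_visits: int, c: float) -> float:
    if child.visits == 0:
        return math.inf  # try every child at least once
    return child.q() + c * math.sqrt(math.log(parent_visits) / child.visits)

def mcts(root_state: State,
         actions_fn: Callable[[State], List[int]],
         step_fn: Callable[[State, int], State],
         value_fn: Callable[[State], float],
         iterations: int, c: float = 1.414) -> Tuple[int, Node]:
    root = Node(root_state)
    for _ in range(iterations):
        node = root
        # 1. Selection: follow the highest-UCB1 child down to a leaf.
        while node.children:
            parent = node
            node = max(parent.children.values(), key=lambda ch: ucb1(ch, parent.visits, c))
        # 2. Expansion: add one child per legal action (terminal states have none).
        actions = actions_fn(node.state)
        if actions:
            for a in actions:
                node.children[a] = Node(step_fn(node.state, a), parent=node)
            node = node.children[actions[0]]
        # 3. Evaluation: a value estimate replaces a random rollout.
        value = value_fn(node.state)
        # 4. Backpropagation: update visit counts and value sums up to the root.
        while node is not None:
            node.visits += 1
            node.value_sum += value
            node = node.parent
    best_action = max(root.children, key=lambda a: root.children[a].visits)
    return best_action, root

# Hypothetical toy environment: choose three bits; more 1s is better.
HORIZON = 3

def legal_actions(s: State) -> List[int]:
    return [0, 1] if len(s) < HORIZON else []

def step(s: State, a: int) -> State:
    return s + (a,)

def value_estimate(s: State) -> float:
    return sum(s) / HORIZON  # stand-in for the Prediction Network's value head

best, root = mcts((), legal_actions, step, value_estimate, iterations=200)
print(best, {a: ch.visits for a, ch in root.children.items()})
```

The final action is chosen by visit count rather than by mean value. This is a common, more robust choice, because rarely visited children can have noisy value averages.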

The Transformer uses beam search together with multi-token prediction to enrich the encoding produced by the representation network.

A generated sequence of tokens is an action. Writing $t_i$ for a token, an action is

$$a_1 = \{t_1, \dots, t_N\}$$

A policy is then a sequence of actions:

$$P_1 = \{a_1, \dots, a_N\}$$

MCTS and OOPS explore what we define as 'thoughts', where a thought is a set of policies:

$$\text{thought}_1 = \{P_1, \dots, P_N\}$$

The model explores and exploits thoughts, policies, actions, and tokens, and learning happens at each of these levels of granularity.
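The nesting of tokens, actions, policies and thoughts can be written down as Python type aliases. Representing tokens as strings and using ordered lists at every level are assumptions made for this illustration. The helper names are hypothetical.

```python
from typing import List

Token = str               # assumption: tokens represented as strings
Action = List[Token]      # a_i = {t_1, ..., t_N}
Policy = List[Action]     # P_i = {a_1, ..., a_N}
Thought = List[Policy]    # thought_i = {P_1, ..., P_N}

def count_tokens(thought: Thought) -> int:
    """Total number of tokens across every action of every policy in a thought."""
    return sum(len(action) for policy in thought for action in policy)

def flatten(thought: Thought) -> List[Token]:
    """Concatenate the tokens of a thought in order, dropping the nesting."""
    return [tok for policy in thought for action in policy for tok in action]

policy_1: Policy = [["the", "cat"], ["sat"]]   # two actions
policy_2: Policy = [["on", "the", "mat"]]      # one action
thought_1: Thought = [policy_1, policy_2]
print(count_tokens(thought_1), flatten(thought_1))
```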

Training Details

The model is trained with the following components and techniques:

Training Procedure

  • Data Loading: Text is tokenized with padding and truncation, then grouped into fixed-length sequences for efficient training.
  • Optimization: Training uses the AdamW optimizer with a CosineAnnealingLR learning-rate scheduler. A gradient scaler (GradScaler) is used for mixed-precision training. It scales the loss so that small half-precision gradients do not underflow to zero, and skips optimizer steps whose gradients contain infs or NaNs.
  • Gradient Accumulation: Because the model is computationally heavy, gradients are accumulated over several steps (--accumulation_steps, default 4). This lets small per-step batches emulate a larger effective batch without increasing memory usage.
  • Loss Functions: The training process leverages a comprehensive set of custom loss functions:
    • InfoNCE Loss: A contrastive loss that pulls the representations of related (positive) pairs together while pushing apart unrelated (negative) pairs.
    • Covariance Regularization: Encourages diverse state representations by minimizing co-linearity in embeddings.
    • Dynamics Performance Loss: Combines MSE and variance losses to penalize incorrect state predictions.
    • Thought Consistency Loss: Encourages the model to output consistent states for similar actions.
    • Policy Value Joint Loss: A weighted combination of policy and value loss for the PPO agent.
    • Action Diversity Reward: Rewards diverse action embeddings to avoid mode collapse.
    • Exploration Regularization: Encourages exploration by penalizing high visitation counts.
    • KL Divergence Loss: Keeps the policy update close to the previous policy to stabilize training.
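Three of the listed terms can be sketched in plain Python: the PPO policy loss, the Policy Value Joint Loss, and the KL Divergence Loss. The sketch assumes the clipped surrogate objective from the original PPO paper, a clip range of 0.2, and a KL divergence over discrete actions. The card does not state the clip range, whether clipping is used alongside the KL term, or the weighting in the joint loss. `value_coef=0.5` and all function names here are therefore assumptions.

```python
import math
from typing import Sequence

def ppo_clip_loss(new_logp: Sequence[float], old_logp: Sequence[float],
                  advantages: Sequence[float], eps: float = 0.2) -> float:
    """Negative PPO clipped surrogate objective, averaged over samples."""
    total = 0.0
    for new, old, adv in zip(new_logp, old_logp, advantages):
        ratio = math.exp(new - old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        # Taking the minimum gives a pessimistic bound that removes the incentive
        # to move the ratio outside [1 - eps, 1 + eps].
        total += min(ratio * adv, clipped * adv)
    return -total / len(advantages)

def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions over the same actions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def policy_value_joint_loss(policy_loss: float, values: Sequence[float],
                            returns: Sequence[float], value_coef: float = 0.5) -> float:
    """Policy loss plus a weighted mean-squared value error (value_coef is an assumption)."""
    value_loss = sum((v - r) ** 2 for v, r in zip(values, returns)) / len(values)
    return policy_loss + value_coef * value_loss

print(ppo_clip_loss([math.log(2.0)], [0.0], [1.0]))  # ratio 2 is clipped to 1.2
print(kl_divergence([0.7, 0.3], [0.5, 0.5]))
```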

Evaluation

After each epoch, the model is evaluated on the validation set, computing the average loss over the dataset. The evaluation function utilizes the same loss functions as training but does not backpropagate, allowing it to be run in inference mode.

Checkpoints

At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.

Usage

To use this model, install the required libraries: torch, transformers, and datasets (argparse is part of the Python standard library). The model can be initialized with pre-trained weights for the Transformer, and custom paths for saving checkpoints can be specified. Here are example commands for starting training:

To Train Language Model


python your_script.py --model_name "gpt2" --dataset_name "wikitext" --dataset_config "wikitext-2-raw-v1" --batch_size 2 --num_epochs 3 --transformer_model_path "path/to/transformer/model"

To Train World Model


python lightbulb_WM.py --model_name 'gpt2' --dataset_name 'wikitext' --dataset_config 'wikitext-2-raw-v1' --batch_size 2 --num_epochs 3 --max_length 128 --learning_rate 1e-4 --save_dir './models'  --transformer_model_path 'path/to/transformer/model'

Language Model Args:

parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')

World Model Args:

parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS Iterations')
parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS')
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
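To check how the example World Model command maps onto these arguments, part of the parser can be rebuilt with the standard-library argparse module and given the same flags. `build_world_model_parser` is a hypothetical name. Only some of the arguments are reproduced, and help strings are omitted except for the exploration constant.

```python
import argparse
import shlex

def build_world_model_parser() -> argparse.ArgumentParser:
    """Hypothetical rebuild of part of the World Model argument parser."""
    parser = argparse.ArgumentParser(description="Train the World Model (sketch)")
    parser.add_argument('--model_name', type=str, default='gpt2')
    parser.add_argument('--dataset_name', type=str, default='wikitext')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--num_epochs', type=int, default=3)
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--mcts_iterations', type=int, default=5)
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414,
                        help='Exploration constant for MCTS')
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--save_dir', type=str, default='./models')
    parser.add_argument('--transformer_model_path', type=str, required=True)
    return parser

# The flags from the example World Model command (the script name is omitted).
command = ("--model_name 'gpt2' --dataset_name 'wikitext' --dataset_config 'wikitext-2-raw-v1' "
           "--batch_size 2 --num_epochs 3 --max_length 128 --learning_rate 1e-4 "
           "--save_dir './models' --transformer_model_path 'path/to/transformer/model'")
args = build_world_model_parser().parse_args(shlex.split(command))
print(args.transformer_model_path, args.mcts_iterations, args.mcts_exploration_constant)
```

Flags that the command omits, such as --mcts_iterations, fall back to their defaults. Leaving out --transformer_model_path makes argparse exit with an error, because the argument is declared required=True.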

The World Model command above trains on the specified dataset for the given number of epochs with a batch size of 2. It loads the pretrained Transformer from the path given by --transformer_model_path.

Model Hyperparameters

Here are the main parameters you can set:

  • --model_name: Name of the pretrained model for tokenization.
  • --dataset_name: Hugging Face dataset name.
  • --batch_size: Batch size for training.
  • --num_epochs: Number of epochs to train.
  • --max_length: Max sequence length.
  • --transformer_model_path: Path to the pretrained Transformer model.
  • --learning_rate: Learning rate for optimizer.
  • --save_dir: Directory to save model checkpoints.
  • --temperature, --alpha (entropy regularization weight), --beta (variance regularization weight), --lambda_reg: Regularization hyperparameters.

Expected Results

As training proceeds, training and evaluation losses should generally decrease. Once trained, the model is intended to generate action sequences for decision-making tasks, using MCTS for search and PPO for policy optimization.

Requirements

This code requires:

  • Python 3.7+
  • torch>=1.7.1
  • transformers
  • datasets
  • argparse (included in the Python standard library)

Limitations

Due to the heavy computational nature of this model, training time may be significant, especially on a CPU. GPU support is recommended for efficient training. Additionally, the MCTS and PPO implementations here are designed for demonstration purposes and may need further tuning for specific use cases.

Citation

If you use this model in your research, please cite the author.


This model card provides an overview for anyone looking to understand, use, or modify the World Model with MCTS and Transformer components.