RobbiePasquale committed
Commit 916381a · verified · 1 Parent(s): 904b97c

Update README.md

Files changed (1): README.md (+47 −0)
README.md CHANGED
@@ -74,10 +74,57 @@ At the end of each epoch, the model saves checkpoints of all components, enabling

To use this model, ensure you have the necessary libraries installed, including `torch`, `transformers`, `datasets`, and `argparse`. The model can be initialized with pre-trained weights for the Transformer, and custom paths for saving checkpoints can be specified. Here’s an example of how to start training:

+ ### To Train the Language Model
```bash
python your_script.py --model_name "gpt2" --dataset_name "wikitext" --dataset_config "wikitext-2-raw-v1" --batch_size 2 --num_epochs 3 --transformer_model_path "path/to/transformer/model"
```

+ ### To Train the World Model
+ ```bash
+ python lightbulb_WM.py --model_name 'gpt2' --dataset_name 'wikitext' --dataset_config 'wikitext-2-raw-v1' --batch_size 2 --num_epochs 3 --max_length 128 --learning_rate 1e-4 --save_dir './models' --transformer_model_path 'path/to/transformer/model'
+ ```
+
+ ### Language Model Arguments
+
+ ```python
+ parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
+ parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
+ parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
+ parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
+ parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
+ parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
+ parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
+ parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
+ parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
+ parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
+ parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
+ parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
+ parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
+ ```
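For orientation, `--accumulation_steps` and `--max_grad_norm` typically combine in the standard gradient-accumulation pattern sketched below. This is a self-contained toy sketch under that assumption, not code from this repository; note that with `--batch_size 2` and the default `accumulation_steps=4`, the effective batch size works out to 8.

```python
import torch
from torch import nn

# Toy model and data, illustrative only; the CLI defaults above are reused.
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
loss_fn = nn.CrossEntropyLoss()
accumulation_steps, max_grad_norm = 4, 1.0

batches = [(torch.randn(2, 4), torch.randint(0, 2, (2,))) for _ in range(8)]

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    # Scale each micro-batch loss so the accumulated gradient is an average.
    loss = loss_fn(model(x), y) / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        # Clip, then apply one optimizer step per accumulated "effective batch".
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
```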
+ ### World Model Arguments
+
+ ```python
+ parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
+ parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
+ parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
+ parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
+ parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
+ parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
+ parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS iterations')
+ parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS (UCT)')
+ parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
+ parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
+ parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
+ parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
+ parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
+ parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
+ parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
+ parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
+ ```
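The default `mcts_exploration_constant` of 1.414 (≈ √2) is the usual UCT constant, which balances exploiting a child's average value against exploring rarely visited children. The function below is a generic reference implementation of the UCT score, not code taken from `lightbulb_WM.py`:

```python
import math

def uct_score(child_value_sum: float, child_visits: int,
              parent_visits: int, c: float = 1.414) -> float:
    """Standard UCT selection score: average value plus an exploration bonus."""
    if child_visits == 0:
        return math.inf  # unvisited children are tried first
    exploitation = child_value_sum / child_visits
    exploration = c * math.sqrt(math.log(parent_visits) / child_visits)
    return exploitation + exploration

# Example under 100 parent visits: a well-visited child scores mostly on value,
# while a rarely visited one gets a large exploration bonus.
print(uct_score(40.0, 80, 100))
print(uct_score(2.0, 4, 100))
```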
This script will train the model on the specified dataset for the defined number of epochs, using a batch size of 2 and loading a pretrained Transformer model from the specified path.
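Checkpoints written to `--save_dir` can be reloaded with the usual PyTorch round trip sketched below; the filename is a placeholder, since the actual checkpoint names are defined by the training script.

```python
import os

import torch
from transformers import AutoModelForCausalLM

os.makedirs('./models', exist_ok=True)  # default --save_dir from the listing above

# Placeholder filename: the training script defines the real checkpoint names.
model = AutoModelForCausalLM.from_pretrained('gpt2')
torch.save(model.state_dict(), './models/transformer_checkpoint.pt')
model.load_state_dict(torch.load('./models/transformer_checkpoint.pt'))
```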
### Model Hyperparameters