Update README.md

README.md
At the end of each epoch, the model saves checkpoints of all components, enabling training to be resumed later.
To use this model, make sure the necessary libraries are installed, including `torch`, `transformers`, and `datasets` (`argparse` ships with the Python standard library). The model can be initialized with pre-trained weights for the Transformer, and custom paths for saving checkpoints can be specified.
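If the dependencies are missing, a typical setup looks like this (package names taken from the list above; pin versions to match your environment):

```bash
pip install torch transformers datasets
```

With those in place, here’s how to start training: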
### To Train the Language Model
```bash
python your_script.py --model_name "gpt2" --dataset_name "wikitext" --dataset_config "wikitext-2-raw-v1" --batch_size 2 --num_epochs 3 --transformer_model_path "path/to/transformer/model"
```
### To Train the World Model
```bash
python lightbulb_WM.py --model_name 'gpt2' --dataset_name 'wikitext' --dataset_config 'wikitext-2-raw-v1' --batch_size 2 --num_epochs 3 --max_length 128 --learning_rate 1e-4 --save_dir './models' --transformer_model_path 'path/to/transformer/model'
```
### Language Model Arguments

```python
import argparse

# Parser construction shown for context; the flags below are the ones the script accepts.
parser = argparse.ArgumentParser(description='Train the language model')
parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
```
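Continuing from the parser above, here is a minimal sketch of how these values are typically consumed after parsing (the derived `effective_batch` name is illustrative, not taken from the script):

```python
args = parser.parse_args()

# Core training settings come straight from the CLI
print(f'Training {args.model_name} on {args.dataset_name}/{args.dataset_config}')
print(f'batch_size={args.batch_size}, epochs={args.num_epochs}, lr={args.learning_rate}')

# With gradient accumulation, the effective batch size is larger than --batch_size
effective_batch = args.batch_size * args.accumulation_steps
print(f'Effective batch size: {effective_batch}')
```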
### World Model Arguments

```python
import argparse

# Parser construction shown for context; the flags below are the ones the script accepts.
parser = argparse.ArgumentParser(description='Train the world model')
parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS iterations')
parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='MCTS exploration constant')
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
```
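The default exploration constant of 1.414 (≈ √2) is the standard UCT choice. Assuming the script uses conventional UCT selection (the repository's exact formula may differ), the constant trades off exploitation against exploration like this:

```python
import math

def uct_score(total_value: float, visits: int, parent_visits: int, c: float = 1.414) -> float:
    """Upper Confidence bound for Trees: mean value plus an exploration bonus."""
    if visits == 0:
        return float('inf')  # unvisited children are expanded first
    exploitation = total_value / visits
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration
```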
These commands train the models on the specified dataset for the given number of epochs, using a batch size of 2 and loading a pretrained Transformer model from the specified path.
### Model Hyperparameters