aapot
committed on
Commit
·
3657027
1
Parent(s):
6db6916
fix
Browse files
- EasyLM/data.py +4 -8
EasyLM/data.py
CHANGED
|
@@ -175,13 +175,14 @@ class HuggingfaceDataset(object):
|
|
| 175 |
self._index = 0
|
| 176 |
|
| 177 |
def __iter__(self):
|
|
|
|
|
|
|
| 178 |
chunk_size = self.config.batch_size * self.config.seq_length
|
| 179 |
-
total_tokens = 0
|
| 180 |
while True:
|
| 181 |
token_buffer = []
|
| 182 |
loss_mask_buffer = []
|
| 183 |
-
if not self._eval_dataset:
|
| 184 |
-
self.
|
| 185 |
for index, example in enumerate(self._dataset):
|
| 186 |
self._index = index
|
| 187 |
if not self._eval_dataset and self._dataset_loc > index:
|
|
@@ -217,12 +218,7 @@ class HuggingfaceDataset(object):
|
|
| 217 |
break
|
| 218 |
else:
|
| 219 |
self._dataset_loc = 0
|
| 220 |
-
self._shuffle()
|
| 221 |
self._train_epochs += 1
|
| 222 |
-
print(f"TRAIN {self._train_epochs} EPOCH DONE")
|
| 223 |
-
|
| 224 |
-
def _shuffle(self):
|
| 225 |
-
self._dataset = self._dataset.shuffle(buffer_size=100)
|
| 226 |
|
| 227 |
def get_state_dict(self):
|
| 228 |
return dict(
|
|
|
|
| 175 |
self._index = 0
|
| 176 |
|
| 177 |
def __iter__(self):
|
| 178 |
+
if not self._eval_dataset and self._train_epochs > 0:
|
| 179 |
+
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
|
| 180 |
chunk_size = self.config.batch_size * self.config.seq_length
|
|
|
|
| 181 |
while True:
|
| 182 |
token_buffer = []
|
| 183 |
loss_mask_buffer = []
|
| 184 |
+
if not self._eval_dataset and self._train_epochs > 0:
|
| 185 |
+
self._dataset.set_epoch(self._train_epochs)
|
| 186 |
for index, example in enumerate(self._dataset):
|
| 187 |
self._index = index
|
| 188 |
if not self._eval_dataset and self._dataset_loc > index:
|
|
|
|
| 218 |
break
|
| 219 |
else:
|
| 220 |
self._dataset_loc = 0
|
|
|
|
| 221 |
self._train_epochs += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
def get_state_dict(self):
|
| 224 |
return dict(
|