ydshieh committed · Commit 7245cb4 · Parent(s): 9ca46fa
update debug.py

debug.py CHANGED
@@ -298,46 +298,22 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
 
     if shuffle:
         batch_idx = jax.random.permutation(rng, len(dataset))
+        batch_idx = np.asarray(batch_idx)
     else:
-        s = time.time()
-        # batch_idx = jnp.arange(len(dataset))
         batch_idx = np.arange(len(dataset))
-        e = time.time()
-        print(f'get permutation indices for the block with jax - time: {e-s}')
 
-    s = time.time()
     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-    e = time.time()
-    print(f'skip incomplete batch with jax - time: {e-s}')
-
-    s = time.time()
     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
-    e = time.time()
-    print(f'reshape block indices with np - time: {e-s}')
 
     for idx in batch_idx:
-
-        print(f'type idx: {type(idx)}')
-
-        print(f'pixel values type: {type(dataset["pixel_values"])}')
-        print(f'pixel values shape: {dataset["pixel_values"].shape}')
-
         s = time.time()
         batch = dataset[idx]
         e = time.time()
-        print(f'
-
-        exit(0)
-
-        s = time.time()
+        print(f'fetch batch time: {e-s}')
         batch = {k: jnp.array(v) for k, v in batch.items()}
-        e = time.time()
-        print(f'convert one batch from np to jax - time: {e-s}')
 
-        s = time.time()
         batch = shard(batch)
-
-        print(f'shard one batch with jax - time: {e-s}')
+
         yield batch
 
 
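Note on the data_loader hunk: the permutation returned by jax.random.permutation is now moved to host memory with np.asarray before it is used to index the Arrow-backed Dataset, and the per-step debug prints are reduced to a single fetch timing. Below is a minimal runnable sketch of the resulting loader; the import of shard from flax.training.common_utils and the local computation of steps_per_epoch are assumptions made for this example, not code from the commit.

    import time

    import jax
    import jax.numpy as jnp
    import numpy as np
    from datasets import Dataset
    from flax.training.common_utils import shard


    def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
        """Sketch of the updated loader: yield pmap-ready batches from an Arrow-backed dataset."""
        steps_per_epoch = len(dataset) // batch_size

        if shuffle:
            batch_idx = jax.random.permutation(rng, len(dataset))
            # Move the permutation to host memory once; indexing the dataset with a
            # NumPy array avoids a device round-trip per lookup.
            batch_idx = np.asarray(batch_idx)
        else:
            batch_idx = np.arange(len(dataset))

        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

        for idx in batch_idx:
            s = time.time()
            batch = dataset[idx]  # dict of NumPy arrays when the dataset is in "numpy" format
            e = time.time()
            print(f"fetch batch time: {e - s}")
            batch = {k: jnp.array(v) for k, v in batch.items()}
            batch = shard(batch)  # split the leading axis across local devices for pmap
            yield batch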
@@ -781,9 +757,9 @@ def main():
         if "train" not in dataset:
             raise ValueError("--do_train requires a train dataset")
         train_dataset = dataset["train"]
+        train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
         # remove problematic examples
         train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
-        train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
         train_dataset = train_dataset.map(
@@ -803,6 +779,7 @@ def main():
         eval_dataset = dataset["validation"]
         # remove problematic examples
         eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
+        eval_dataset = datasets.concatenate_datasets([eval_dataset] * 205)
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
@@ -820,6 +797,7 @@ def main():
         if "test" not in dataset:
             raise ValueError("--do_predict requires a test dataset")
         predict_dataset = dataset["test"]
+        predict_dataset = datasets.concatenate_datasets([predict_dataset] * 1024)
         # remove problematic examples
         predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
         if data_args.max_predict_samples is not None:
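Note on the three hunks above: each split is blown up by concatenating it with itself (205 copies for train/eval, 1024 for predict) so the data-loading timings run over a much larger dataset, and for the train split the concatenation now happens before the problematic-example filter rather than after it. datasets.concatenate_datasets reuses the existing Arrow tables, so the enlargement is cheap; a tiny self-contained illustration with a made-up split:

    import datasets

    # Made-up stand-in for dataset["train"] in debug.py.
    train_dataset = datasets.Dataset.from_dict({"image_id": [0, 1, 2], "caption": ["a", "b", "c"]})

    # Repeat the split 205 times, as the commit does for the train/eval splits
    # (1024 times for predict). The Arrow data is reused rather than copied row by row.
    train_dataset = datasets.concatenate_datasets([train_dataset] * 205)

    print(len(train_dataset))  # 615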
@@ -840,7 +818,7 @@ def main():
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
     block_size = training_args.block_size
-
+
     # Store some constant
 
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
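For context on the constant kept at the end of the hunk above: the effective train_batch_size is the per-device batch size multiplied by the number of JAX devices visible to the process, since each pmap step feeds one slice of the batch to every device. A quick check (the per-device value of 8 is an assumed example, not taken from the commit):

    import jax

    per_device_train_batch_size = 8  # assumed example value
    train_batch_size = int(per_device_train_batch_size) * jax.device_count()

    # 64 on a host with 8 local devices, 8 on a single CPU or GPU.
    print(jax.device_count(), train_batch_size)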
@@ -874,28 +852,22 @@ def main():
         num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
 
         if shuffle:
-
-
-            indices = np.random.permutation(len(train_dataset))
-            e = time.time()
-            print(f'get permutation indices for the whole dataset with jax - time: {e-s}')
+            indices = jax.random.permutation(rng, len(train_dataset))
+            indices = np.asarray(indices)
         else:
-            indices =
+            indices = np.arange(len(ds))
 
         for idx in range(num_splits):
 
             start_idx = block_size * idx
             end_idx = block_size * (idx + 1)
 
-            s = time.time()
             selected_indices = indices[start_idx:end_idx]
-            e = time.time()
-            print(f'get block indices with jax - time: {e-s}')
 
             s = time.time()
             _ds = ds.select(selected_indices)
             e = time.time()
-            print(f'select block
+            print(f'select block time: {e-s}')
 
             names = {
                 "train": "train",
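The hunk above makes the whole-dataset shuffle mirror the data_loader change (jax.random.permutation followed by np.asarray) and keeps only the ds.select timing. A minimal sketch of that block-splitting loop; deriving num_splits directly from block_size and yielding the selected block are simplifications assumed for this example (the script computes num_splits from steps_per_split and goes on to map each block):

    import time

    import jax
    import numpy as np
    from datasets import Dataset


    def iter_blocks(rng: jax.random.PRNGKey, ds: Dataset, block_size: int, shuffle: bool = True):
        """Yield the dataset one contiguous block of indices at a time."""
        num_splits = len(ds) // block_size + int(len(ds) % block_size > 0)

        if shuffle:
            indices = jax.random.permutation(rng, len(ds))
            indices = np.asarray(indices)  # host-side indices for Dataset.select
        else:
            indices = np.arange(len(ds))

        for idx in range(num_splits):
            start_idx = block_size * idx
            end_idx = block_size * (idx + 1)
            selected_indices = indices[start_idx:end_idx]

            s = time.time()
            # select() only builds an indices mapping over the Arrow table, so it is cheap.
            _ds = ds.select(selected_indices)
            e = time.time()
            print(f"select block time: {e - s}")
            yield _ds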
@@ -904,20 +876,19 @@ def main():
             }
 
             s = time.time()
-            _ds =_ds.map(
+            _ds = _ds.map(
                 feature_extraction_fn,
                 batched=True,
                 num_proc=data_args.preprocessing_num_workers,
                 remove_columns=[image_column],
                 load_from_cache_file=not data_args.overwrite_cache,
                 features=features,
-
-                keep_in_memory=False,
+                keep_in_memory=keep_in_memory,
                 desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
             )
             _ds = _ds.with_format("numpy")
             e = time.time()
-            print(f'map
+            print(f'map time: {e-s}')
 
             # No need to shuffle here
             loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
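The final hunk fixes the `_ds =_ds.map(` spacing, makes keep_in_memory configurable instead of hard-coded False, and restores a timing print after the map. Below is a self-contained toy version of the per-block feature extraction; the fake feature_extraction_fn, the pixel_values feature and the small in-memory dataset are assumptions for illustration only:

    import time

    import numpy as np
    from datasets import Dataset, Features, Sequence, Value

    # Toy stand-ins for the script's image column and feature extractor.
    ds = Dataset.from_dict({"image_path": [f"img_{i}.jpg" for i in range(8)]})
    features = Features({"pixel_values": Sequence(Value("float32"))})


    def feature_extraction_fn(examples):
        # Fake "feature extraction": one zero vector per image path.
        return {"pixel_values": [np.zeros(4, dtype=np.float32) for _ in examples["image_path"]]}


    s = time.time()
    ds = ds.map(
        feature_extraction_fn,
        batched=True,
        remove_columns=["image_path"],
        features=features,
        keep_in_memory=True,  # analogous to keep_in_memory in the commit, instead of a hard-coded False
        load_from_cache_file=False,
        desc="Running feature extraction on train dataset",
    )
    ds = ds.with_format("numpy")  # __getitem__ now returns NumPy arrays, as data_loader expects
    e = time.time()
    print(f"map time: {e - s}")
    print(type(ds[0]["pixel_values"]))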