ydshieh
		
	committed on
		
		
					Commit 
							
							·
						
						afddfdc
	
1
								Parent(s):
							
							9a97c24
								
improve doc
Browse files- run_image_captioning_flax.py +26 -16
    	
        run_image_captioning_flax.py
    CHANGED
    
    | @@ -755,6 +755,9 @@ def main(): | |
| 755 |  | 
| 756 | 
             
                # Setting padding="max_length" as we need fixed length inputs for jitted functions
         | 
| 757 | 
             
                def tokenization_fn(examples, max_target_length):
         | 
|  | |
|  | |
|  | |
| 758 |  | 
| 759 | 
             
                    captions = []
         | 
| 760 | 
             
                    for caption in examples[caption_column]:
         | 
| @@ -784,6 +787,9 @@ def main(): | |
| 784 | 
             
                    return model_inputs
         | 
| 785 |  | 
| 786 | 
             
                def feature_extraction_fn(examples):
         | 
|  | |
|  | |
|  | |
| 787 |  | 
| 788 | 
             
                    images = [Image.open(image_file) for image_file in examples[image_column]]
         | 
| 789 | 
             
                    encoder_inputs = feature_extractor(images=images, return_tensors="np")
         | 
| @@ -792,6 +798,9 @@ def main(): | |
| 792 | 
             
                    return model_inputs
         | 
| 793 |  | 
| 794 | 
             
                def preprocess_fn(examples, max_target_length):
         | 
|  | |
|  | |
|  | |
| 795 |  | 
| 796 | 
             
                    model_inputs = {}
         | 
| 797 | 
             
                    model_inputs.update(tokenization_fn(examples, max_target_length))
         | 
| @@ -817,10 +826,15 @@ def main(): | |
| 817 | 
             
                    }
         | 
| 818 | 
             
                )
         | 
| 819 |  | 
| 820 | 
            -
                 | 
| 821 | 
            -
                 | 
| 822 | 
            -
                 | 
| 823 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 824 |  | 
| 825 | 
             
                if training_args.do_train:
         | 
| 826 | 
             
                    if "train" not in dataset:
         | 
| @@ -837,11 +851,11 @@ def main(): | |
| 837 | 
             
                        # kept image paths
         | 
| 838 | 
             
                        remove_columns=remove_columns_kwarg,
         | 
| 839 | 
             
                        load_from_cache_file=not data_args.overwrite_cache,
         | 
| 840 | 
            -
                        desc=f"Running { | 
| 841 | 
             
                        fn_kwargs={"max_target_length": data_args.max_target_length},
         | 
| 842 | 
             
                        features=features_kwarg,
         | 
| 843 | 
             
                    )
         | 
| 844 | 
            -
                    if  | 
| 845 | 
             
                        train_dataset = train_dataset.with_format("numpy")
         | 
| 846 |  | 
| 847 | 
             
                if training_args.do_eval:
         | 
| @@ -859,11 +873,11 @@ def main(): | |
| 859 | 
             
                        # kept image paths
         | 
| 860 | 
             
                        remove_columns=remove_columns_kwarg,
         | 
| 861 | 
             
                        load_from_cache_file=not data_args.overwrite_cache,
         | 
| 862 | 
            -
                        desc=f"Running { | 
| 863 | 
             
                        fn_kwargs={"max_target_length": data_args.val_max_target_length},
         | 
| 864 | 
             
                        features=features_kwarg,
         | 
| 865 | 
             
                    )
         | 
| 866 | 
            -
                    if  | 
| 867 | 
             
                        eval_dataset = eval_dataset.with_format("numpy")
         | 
| 868 |  | 
| 869 | 
             
                if training_args.do_predict:
         | 
| @@ -881,17 +895,13 @@ def main(): | |
| 881 | 
             
                        # kept image paths
         | 
| 882 | 
             
                        remove_columns=remove_columns_kwarg,
         | 
| 883 | 
             
                        load_from_cache_file=not data_args.overwrite_cache,
         | 
| 884 | 
            -
                        desc=f"Running { | 
| 885 | 
             
                        fn_kwargs={"max_target_length": data_args.val_max_target_length},
         | 
| 886 | 
             
                        features=features_kwarg,
         | 
| 887 | 
             
                    )
         | 
| 888 | 
            -
                    if  | 
| 889 | 
             
                        predict_dataset = predict_dataset.with_format("numpy")
         | 
| 890 |  | 
| 891 | 
            -
                # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
         | 
| 892 | 
            -
                # data loader separately (in a sequential order).
         | 
| 893 | 
            -
                block_size = training_args.block_size
         | 
| 894 | 
            -
                    
         | 
| 895 | 
             
                # Store some constant
         | 
| 896 |  | 
| 897 | 
             
                train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
         | 
| @@ -1187,7 +1197,7 @@ def main(): | |
| 1187 | 
             
                    preds = []
         | 
| 1188 | 
             
                    labels = []
         | 
| 1189 |  | 
| 1190 | 
            -
                    batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
         | 
| 1191 | 
             
                    steps = len(dataset) // eval_batch_size
         | 
| 1192 | 
             
                    for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
         | 
| 1193 | 
             
                        # Model forward
         | 
| @@ -1295,7 +1305,7 @@ def main(): | |
| 1295 |  | 
| 1296 | 
             
                        train_metrics = []
         | 
| 1297 |  | 
| 1298 | 
            -
                        train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, split="train")
         | 
| 1299 |  | 
| 1300 | 
             
                        # train
         | 
| 1301 | 
             
                        for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
         | 
|  | |
| 755 |  | 
| 756 | 
             
                # Setting padding="max_length" as we need fixed length inputs for jitted functions
         | 
| 757 | 
             
                def tokenization_fn(examples, max_target_length):
         | 
| 758 | 
            +
                    """
         | 
| 759 | 
            +
                    Run tokenization on captions.
         | 
| 760 | 
            +
                    """
         | 
| 761 |  | 
| 762 | 
             
                    captions = []
         | 
| 763 | 
             
                    for caption in examples[caption_column]:
         | 
|  | |
| 787 | 
             
                    return model_inputs
         | 
| 788 |  | 
| 789 | 
             
                def feature_extraction_fn(examples):
         | 
| 790 | 
            +
                    """
         | 
| 791 | 
            +
                    Run feature extraction on images
         | 
| 792 | 
            +
                    """
         | 
| 793 |  | 
| 794 | 
             
                    images = [Image.open(image_file) for image_file in examples[image_column]]
         | 
| 795 | 
             
                    encoder_inputs = feature_extractor(images=images, return_tensors="np")
         | 
|  | |
| 798 | 
             
                    return model_inputs
         | 
| 799 |  | 
| 800 | 
             
                def preprocess_fn(examples, max_target_length):
         | 
| 801 | 
            +
                    """
         | 
| 802 | 
            +
                    Run tokenization + image feature extraction
         | 
| 803 | 
            +
                    """
         | 
| 804 |  | 
| 805 | 
             
                    model_inputs = {}
         | 
| 806 | 
             
                    model_inputs.update(tokenization_fn(examples, max_target_length))
         | 
|  | |
| 826 | 
             
                    }
         | 
| 827 | 
             
                )
         | 
| 828 |  | 
| 829 | 
            +
                # If `block_size` is `0`, tokenization & image feature extraction are done before training
         | 
| 830 | 
            +
                run_feat_ext_before_training = training_args.block_size == 0
         | 
| 831 | 
            +
                # Used in .map() below
         | 
| 832 | 
            +
                function_kwarg = preprocess_fn if run_feat_ext_before_training else tokenization_fn
         | 
| 833 | 
            +
                # `features` is used only for the final preprocessed dataset (for performance reasons).
         | 
| 834 | 
            +
                features_kwarg = features if run_feat_ext_before_training else None
         | 
| 835 | 
            +
                # Keep `image_column` if the feature extraction is done during training
         | 
| 836 | 
            +
                remove_columns_kwarg = [x for x in column_names if x != image_column or run_feat_ext_before_training]
         | 
| 837 | 
            +
                processor_names = "tokenizer and feature extractor" if run_feat_ext_before_training else "tokenizer"
         | 
| 838 |  | 
| 839 | 
             
                if training_args.do_train:
         | 
| 840 | 
             
                    if "train" not in dataset:
         | 
|  | |
| 851 | 
             
                        # kept image paths
         | 
| 852 | 
             
                        remove_columns=remove_columns_kwarg,
         | 
| 853 | 
             
                        load_from_cache_file=not data_args.overwrite_cache,
         | 
| 854 | 
            +
                        desc=f"Running {processor_names} on train dataset",
         | 
| 855 | 
             
                        fn_kwargs={"max_target_length": data_args.max_target_length},
         | 
| 856 | 
             
                        features=features_kwarg,
         | 
| 857 | 
             
                    )
         | 
| 858 | 
            +
                    if run_feat_ext_before_training:
         | 
| 859 | 
             
                        train_dataset = train_dataset.with_format("numpy")
         | 
| 860 |  | 
| 861 | 
             
                if training_args.do_eval:
         | 
|  | |
| 873 | 
             
                        # kept image paths
         | 
| 874 | 
             
                        remove_columns=remove_columns_kwarg,
         | 
| 875 | 
             
                        load_from_cache_file=not data_args.overwrite_cache,
         | 
| 876 | 
            +
                        desc=f"Running {processor_names} on validation dataset",
         | 
| 877 | 
             
                        fn_kwargs={"max_target_length": data_args.val_max_target_length},
         | 
| 878 | 
             
                        features=features_kwarg,
         | 
| 879 | 
             
                    )
         | 
| 880 | 
            +
                    if run_feat_ext_before_training:
         | 
| 881 | 
             
                        eval_dataset = eval_dataset.with_format("numpy")
         | 
| 882 |  | 
| 883 | 
             
                if training_args.do_predict:
         | 
|  | |
| 895 | 
             
                        # kept image paths
         | 
| 896 | 
             
                        remove_columns=remove_columns_kwarg,
         | 
| 897 | 
             
                        load_from_cache_file=not data_args.overwrite_cache,
         | 
| 898 | 
            +
                        desc=f"Running {processor_names} on prediction dataset",
         | 
| 899 | 
             
                        fn_kwargs={"max_target_length": data_args.val_max_target_length},
         | 
| 900 | 
             
                        features=features_kwarg,
         | 
| 901 | 
             
                    )
         | 
| 902 | 
            +
                    if run_feat_ext_before_training:
         | 
| 903 | 
             
                        predict_dataset = predict_dataset.with_format("numpy")
         | 
| 904 |  | 
|  | |
|  | |
|  | |
|  | |
| 905 | 
             
                # Store some constant
         | 
| 906 |  | 
| 907 | 
             
                train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
         | 
|  | |
| 1197 | 
             
                    preds = []
         | 
| 1198 | 
             
                    labels = []
         | 
| 1199 |  | 
| 1200 | 
            +
                    batches = get_batch_iter(rng, dataset, block_size=training_args.block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=split)
         | 
| 1201 | 
             
                    steps = len(dataset) // eval_batch_size
         | 
| 1202 | 
             
                    for _ in tqdm(range(steps), desc=f"{'Predicting' if split == 'test' else 'Evaluating'}...", position=2, leave=False):
         | 
| 1203 | 
             
                        # Model forward
         | 
|  | |
| 1305 |  | 
| 1306 | 
             
                        train_metrics = []
         | 
| 1307 |  | 
| 1308 | 
            +
                        train_batches = get_batch_iter(input_rng, train_dataset, block_size=training_args.block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, split="train")
         | 
| 1309 |  | 
| 1310 | 
             
                        # train
         | 
| 1311 | 
             
                        for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
         | 
