make flash attention optional
custom_st.py CHANGED (+19 -9)

@@ -32,15 +32,25 @@ class Transformer(nn.Module):
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels

-        #
-        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-            model_name_or_path,
-            attn_implementation="flash_attention_2",
-            torch_dtype=torch.bfloat16,
-            device_map=device,
-            cache_dir=cache_dir,
-            **kwargs
-        ).eval()
+        # Try to use flash attention if available, fallback to default attention if not
+        try:
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_name_or_path,
+                attn_implementation="flash_attention_2",
+                torch_dtype=torch.bfloat16,
+                device_map=device,
+                cache_dir=cache_dir,
+                **kwargs
+            ).eval()
+        except (ImportError, ValueError) as e:
+            print(f"Flash attention not available, falling back to default attention: {e}")
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_name_or_path,
+                torch_dtype=torch.bfloat16,
+                device_map=device,
+                cache_dir=cache_dir,
+                **kwargs
+            ).eval()

         # Initialize processor
         self.processor = AutoProcessor.from_pretrained(
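For context, a minimal sketch of an equivalent approach that probes for the flash_attn package up front instead of catching the exception, so the from_pretrained call is not written twice. The load_model helper, its defaults, and the placeholder arguments are illustrative only and not part of custom_st.py:

import importlib.util

import torch
from transformers import Qwen2VLForConditionalGeneration


def load_model(model_name_or_path, device="cuda", cache_dir=None, **kwargs):
    # Use flash attention only when the flash_attn package is importable and a
    # CUDA device is present; otherwise fall back to the transformers default.
    use_flash = importlib.util.find_spec("flash_attn") is not None and torch.cuda.is_available()
    attn_kwargs = {"attn_implementation": "flash_attention_2"} if use_flash else {}
    return Qwen2VLForConditionalGeneration.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
        cache_dir=cache_dir,
        **attn_kwargs,
        **kwargs,
    ).eval()

Note that the try/except in the commit also covers the case where flash-attn is installed but unusable on the current setup (surfaced as a ValueError), which an import check alone would miss.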

