Commit bbc6d7c · Parent(s): 7d2a362

feat: disable flash attn if not supported CUDA version or device capability

Files changed: modeling_clip.py (+19, -0)
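The two runtime signals this commit checks are both exposed by PyTorch: torch.version.cuda (a version string, or None on CPU-only builds) and torch.cuda.get_device_capability() (a (major, minor) tuple). They can be inspected interactively, e.g.:

    >>> import torch
    >>> torch.version.cuda                    # e.g. '11.8' on a CUDA 11.8 build
    >>> torch.cuda.get_device_capability()    # e.g. (8, 6) on an RTX 30-series GPU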
modeling_clip.py  CHANGED

@@ -144,6 +144,25 @@ def _resolve_attention_libs(config: JinaCLIPConfig):
                 'for installation instructions, disabling'
             )
             return False
+        major, minor, *_ = torch.version.cuda.split('.')
+        major, minor = int(major), int(minor)
+        if major < 11 or (major == 11 and minor < 7):
+            warnings.warn(
+                'Flash attention requires CUDA>=11.7. Found version '
+                f'{major}.{minor}, disabling'
+            )
+            return False
+        capability = torch.cuda.get_device_capability()
+        major, *_ = capability
+        major = int(major)
+        if major < 8:
+            device_name = torch.cuda.get_device_properties(0).name
+            warnings.warn(
+                'Flash attention requires device capability>=8.0 (NVIDIA Ampere, '
+                f'Hopper or ADA). Found device {device_name} with capability '
+                f'{capability}, disabling'
+            )
+            return False
         return True
     return False
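For reference, a self-contained sketch of the same guard that can be run outside the model code. The function name flash_attention_supported is illustrative, not part of the repo, and the CUDA-availability check at the top stands in for checks the surrounding file already performs:

    import warnings

    import torch


    def flash_attention_supported() -> bool:
        # Sketch of the guard this commit adds. torch.version.cuda is None
        # on CPU-only builds, so bail out early in that case.
        if not torch.cuda.is_available() or torch.version.cuda is None:
            return False
        # Flash attention needs CUDA >= 11.7.
        major, minor, *_ = (int(v) for v in torch.version.cuda.split('.'))
        if (major, minor) < (11, 7):
            warnings.warn(
                'Flash attention requires CUDA>=11.7. Found version '
                f'{major}.{minor}, disabling'
            )
            return False
        # Flash attention needs compute capability >= 8.0 (Ampere or newer).
        capability = torch.cuda.get_device_capability()
        if capability[0] < 8:
            device_name = torch.cuda.get_device_properties(0).name
            warnings.warn(
                'Flash attention requires device capability>=8.0 (NVIDIA Ampere, '
                f'Hopper or ADA). Found device {device_name} with capability '
                f'{capability}, disabling'
            )
            return False
        return True

Note the tuple comparison (major, minor) < (11, 7): Python compares tuples element-wise, so it is equivalent to the commit's major < 11 or (major == 11 and minor < 7).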