add sanity checks
Browse files
custom_generate/generate.py
CHANGED
|
@@ -203,7 +203,8 @@ def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
|
| 203 |
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
|
| 204 |
)
|
| 205 |
)
|
| 206 |
-
|
|
|
|
| 207 |
raise ValueError(
|
| 208 |
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
|
| 209 |
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
|
|
|
|
| 203 |
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
|
| 204 |
)
|
| 205 |
)
|
| 206 |
+
kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
|
| 207 |
+
if kwargs_has_arg or has_custom_gen_config_arg:
|
| 208 |
raise ValueError(
|
| 209 |
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
|
| 210 |
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
|