Update convert.py
Browse files- convert.py +3 -13
convert.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
# Copied from https://github.com/nghuyong/ERNIE-Pytorch/blob/master/convert.py
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
#!/usr/bin/env python
|
|
@@ -99,19 +100,8 @@ def extract_and_convert(input_dir, output_dir):
|
|
| 99 |
paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
|
| 100 |
for weight_name, weight_value in paddle_paddle_params.items():
|
| 101 |
if 'weight' in weight_name:
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
# if 'encoder' in weight_name or 'pooler' in weight_name or 'cls.' in weight_name and \
|
| 106 |
-
# "k_proj" not in weight_name and "v_proj" not in weight_name and \
|
| 107 |
-
# "out_proj" not in weight_name and "linear1" not in weight_name and \
|
| 108 |
-
# "linear2" not in weight_name:
|
| 109 |
-
# weight_value = weight_value.transpose()
|
| 110 |
-
if "encoder" in weight_name:
|
| 111 |
-
if "linear1" in weight_name or "linear2" in weight_name:
|
| 112 |
-
weight_value = weight_value.transpose()
|
| 113 |
-
else:
|
| 114 |
-
weight_value = weight_value.transpose()
|
| 115 |
|
| 116 |
if weight_name not in weight_map:
|
| 117 |
print('=' * 20, '[SKIP]', weight_name, '=' * 20)
|
|
|
|
| 1 |
# Copied from https://github.com/nghuyong/ERNIE-Pytorch/blob/master/convert.py
|
| 2 |
+
# with some modifications for ernie-m
|
| 3 |
|
| 4 |
|
| 5 |
#!/usr/bin/env python
|
|
|
|
| 100 |
paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
|
| 101 |
for weight_name, weight_value in paddle_paddle_params.items():
|
| 102 |
if 'weight' in weight_name:
|
| 103 |
+
if 'encoder' in weight_name or 'pooler' in weight_name or 'cls.' in weight_name:
|
| 104 |
+
weight_value = weight_value.transpose()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
if weight_name not in weight_map:
|
| 107 |
print('=' * 20, '[SKIP]', weight_name, '=' * 20)
|