Update README.md
README.md CHANGED
@@ -91,4 +91,76 @@ template = """{
prediction = predict_NuExtract(model, tokenizer, [text], template)[0]
print(prediction)

```

Sliding window prompting:

```python
import json

MAX_INPUT_SIZE = 20_000
MAX_NEW_TOKENS = 6000

def clean_json_text(text):
    text = text.strip()
    text = text.replace("\#", "#").replace("\&", "&")
    return text

def predict_chunk(text, template, current, model, tokenizer):
    current = clean_json_text(current)

    input_llm = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
    input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE).to("cuda")
    output = tokenizer.decode(model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS)[0], skip_special_tokens=True)

    return clean_json_text(output.split("<|output|>")[1])

def split_document(document, window_size, overlap):
    tokens = tokenizer.tokenize(document)
    print(f"\tLength of document: {len(tokens)} tokens")

    chunks = []
    if len(tokens) > window_size:
        for i in range(0, len(tokens), window_size-overlap):
            print(f"\t{i} to {i + len(tokens[i:i + window_size])}")
            chunk = tokenizer.convert_tokens_to_string(tokens[i:i + window_size])
            chunks.append(chunk)

            if i + len(tokens[i:i + window_size]) >= len(tokens):
                break
    else:
        chunks.append(document)
    print(f"\tSplit into {len(chunks)} chunks")

    return chunks

def handle_broken_output(pred, prev):
    try:
        if all([(v in ["", []]) for v in json.loads(pred).values()]):
            # if empty json, return previous
            pred = prev
    except:
        # if broken json, return previous
        pred = prev

    return pred

def sliding_window_prediction(text, template, model, tokenizer, window_size=4000, overlap=128):
    # split text into chunks of n tokens
    tokens = tokenizer.tokenize(text)
    chunks = split_document(text, window_size, overlap)

    # iterate over text chunks
    prev = template
    for i, chunk in enumerate(chunks):
        print(f"Processing chunk {i}...")
        pred = predict_chunk(chunk, template, prev, model, tokenizer)

        # handle broken output
        pred = handle_broken_output(pred, prev)

        # iterate
        prev = pred

    return pred
```
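
A minimal usage sketch, assuming `model`, `tokenizer`, and `template` are set up as in the example above, and that `long_text` is a placeholder name for a document longer than a single window:

```python
# `long_text` is a placeholder for a long input document (assumed to be defined earlier);
# any string longer than `window_size` tokens is split into overlapping chunks, and the
# JSON predicted for each chunk is carried over to the next one as the "Current" state.
prediction = sliding_window_prediction(
    long_text,
    template,
    model,
    tokenizer,
    window_size=4000,  # tokens per chunk
    overlap=128,       # tokens shared between consecutive chunks
)
print(prediction)
```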