Update README.md
Browse files
README.md
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Model Card for Model ID
|
| 2 |
|
| 3 |
Intent classification is the act of classifying customer's in to different pre defined categories.
|
|
@@ -30,41 +34,41 @@ This is the model card of a 🤗 transformers model that has been pushed on the
|
|
| 30 |
|
| 31 |
## How to Get Started with the Model
|
| 32 |
|
| 33 |
-
class IntentClassifier:
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
|
| 39 |
|
| 40 |
-
def build_prompt(text, prompt="", company_name="", company_specific=""):
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
def predict(self, text, prompt_options, company_name, company_portion) -> str:
|
| 50 |
-
input_text = build_prompt(text, prompt_options, company_name, company_portion)
|
| 51 |
-
# print(input_text)
|
| 52 |
-
# Tokenize the concatenated inp_ut text
|
| 53 |
-
input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
| 60 |
-
|
| 61 |
-
return decoded_output
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
m = IntentClassifier("serj/intent-classifier")
|
| 65 |
-
print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
|
| 66 |
-
"OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
|
| 67 |
-
"Products and subscriptions"))
|
| 68 |
|
| 69 |
|
| 70 |
[More Information Needed]
|
|
@@ -122,5 +126,4 @@ F1 AVG on the train set is 0.69
|
|
| 122 |
|
| 123 |
#### Hardware
|
| 124 |
|
| 125 |
-
Nvidia RTX3060 12Gb
|
| 126 |
-
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- intent, topic-discovery
|
| 4 |
+
---
|
| 5 |
# Model Card for Model ID
|
| 6 |
|
| 7 |
Intent classification is the act of classifying customer's in to different pre defined categories.
|
|
|
|
| 34 |
|
| 35 |
## How to Get Started with the Model
|
| 36 |
|
| 37 |
+
class IntentClassifier:
|
| 38 |
+
def __init__(self, model_name="serj/intent-classifier", device="cuda"):
|
| 39 |
+
self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
|
| 40 |
+
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
|
| 41 |
+
self.device = device
|
| 42 |
|
| 43 |
|
| 44 |
+
def build_prompt(text, prompt="", company_name="", company_specific=""):
|
| 45 |
+
if company_name == "Pizza Mia":
|
| 46 |
+
company_specific = "This company is a pizzeria place."
|
| 47 |
+
if company_name == "Online Banking":
|
| 48 |
+
company_specific = "This company is an online banking."
|
| 49 |
+
|
| 50 |
+
return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "
|
| 51 |
|
|
|
|
| 52 |
|
| 53 |
+
def predict(self, text, prompt_options, company_name, company_portion) -> str:
|
| 54 |
+
input_text = build_prompt(text, prompt_options, company_name, company_portion)
|
| 55 |
+
# print(input_text)
|
| 56 |
+
# Tokenize the concatenated inp_ut text
|
| 57 |
+
input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
|
| 58 |
+
|
| 59 |
+
# Generate the output
|
| 60 |
+
output = self.model.generate(input_ids)
|
| 61 |
+
|
| 62 |
+
# Decode the output tokens
|
| 63 |
+
decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
| 64 |
+
|
| 65 |
+
return decoded_output
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
+
m = IntentClassifier("serj/intent-classifier")
|
| 69 |
+
print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
|
| 70 |
+
"OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
|
| 71 |
+
"Products and subscriptions"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
[More Information Needed]
|
|
|
|
| 126 |
|
| 127 |
#### Hardware
|
| 128 |
|
| 129 |
+
Nvidia RTX3060 12Gb
|
|
|