ykarout committed on
Commit
5ab5bf9
·
verified ·
1 Parent(s): 7e25a53

Delete inference.ipynb

Browse files
Files changed (1) hide show
  1. inference.ipynb +0 -139
inference.ipynb DELETED
@@ -1,139 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "id": "1a848dfb-4083-4d7c-af83-82e663d1f964",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import torch\n",
11
- "from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer\n",
12
- "\n",
13
- "MODEL_ID = \"/workspace/mixtral-reasoning-output/checkpoint-275/\"\n",
14
- "\n",
15
- "model = AutoModelForCausalLM.from_pretrained(\n",
16
- " MODEL_ID,\n",
17
- " device_map=\"auto\",\n",
18
- " torch_dtype=torch.bfloat16,\n",
19
- " attn_implementation=\"flash_attention_2\",\n",
20
- " trust_remote_code=True\n",
21
- " )\n",
22
- " \n",
23
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)"
24
- ]
25
- },
26
- {
27
- "cell_type": "code",
28
- "execution_count": null,
29
- "id": "419c3212-c843-4102-850c-ec2e83e5401a",
30
- "metadata": {
31
- "scrolled": true
32
- },
33
- "outputs": [],
34
- "source": [
35
- "model.eval()"
36
- ]
37
- },
38
- {
39
- "cell_type": "code",
40
- "execution_count": null,
41
- "id": "002d4810-e64d-4553-9256-d5a95bad07da",
42
- "metadata": {},
43
- "outputs": [],
44
- "source": [
45
- "system_prompt = \"detailed thinking on\"\n",
46
- "user_prompt = \"\"\"Triangle $ABC$ has a right angle at $B$. Points $D$ and $E$ are chosen on $\\overline{AC}$ and $\\overline{BC}$, respectively, such that $AB = BE = ED = DC = 2$. Find the area of $\\triangle CDE$.\"\"\""
47
- ]
48
- },
49
- {
50
- "cell_type": "code",
51
- "execution_count": null,
52
- "id": "c6ba5056-67ea-40d7-a6bb-fdb2d750cfc7",
53
- "metadata": {},
54
- "outputs": [],
55
- "source": [
56
- "# Fix the pad token issue\n",
57
- "if tokenizer.pad_token is None or tokenizer.pad_token_id == tokenizer.eos_token_id:\n",
58
- " tokenizer.pad_token = tokenizer.unk_token\n",
59
- " tokenizer.pad_token_id = tokenizer.unk_token_id\n",
60
- "\n",
61
- "# Verify the fix\n",
62
- "print(f\"EOS token ID: {tokenizer.eos_token_id}\")\n",
63
- "print(f\"PAD token ID: {tokenizer.pad_token_id}\")\n",
64
- "print(f\"UNK token ID: {tokenizer.unk_token_id}\")"
65
- ]
66
- },
67
- {
68
- "cell_type": "code",
69
- "execution_count": null,
70
- "id": "3af2af88-aa27-4670-a78b-bea00bc07414",
71
- "metadata": {},
72
- "outputs": [],
73
- "source": [
74
- "messages = [\n",
75
- " {\"role\": \"system\", \"content\": system_prompt},\n",
76
- " {\"role\": \"user\", \"content\": user_prompt}\n",
77
- "]\n",
78
- "\n",
79
- "# Tokenize input\n",
80
- "input_ids = tokenizer.apply_chat_template(\n",
81
- " messages,\n",
82
- " tokenize=True,\n",
83
- " add_generation_prompt=True,\n",
84
- " return_tensors=\"pt\"\n",
85
- ").to(\"cuda\")\n",
86
- "\n",
87
- "# Create streamer - TextStreamer automatically prints to stdout\n",
88
- "streamer = TextStreamer(\n",
89
- " tokenizer, \n",
90
- " skip_special_tokens=False,\n",
91
- " skip_prompt=False,\n",
92
- ")\n",
93
- "\n",
94
- "# Generate with streamer - no threading needed with TextStreamer\n",
95
- "model.generate(\n",
96
- " input_ids=input_ids,\n",
97
- " pad_token_id=tokenizer.eos_token_id,\n",
98
- " streamer=streamer,\n",
99
- " max_new_tokens=16383,\n",
100
- " temperature=0.5,\n",
101
- " top_p=0.95,\n",
102
- " top_k=40,\n",
103
- " repetition_penalty=1.2,\n",
104
- " do_sample=True,\n",
105
- " #use_cache=True\n",
106
- ")"
107
- ]
108
- },
109
- {
110
- "cell_type": "code",
111
- "execution_count": null,
112
- "id": "2f4daeb6-02c6-4376-acc8-2b34fbb9fbd7",
113
- "metadata": {},
114
- "outputs": [],
115
- "source": []
116
- }
117
- ],
118
- "metadata": {
119
- "kernelspec": {
120
- "display_name": "Python 3 (ipykernel)",
121
- "language": "python",
122
- "name": "python3"
123
- },
124
- "language_info": {
125
- "codemirror_mode": {
126
- "name": "ipython",
127
- "version": 3
128
- },
129
- "file_extension": ".py",
130
- "mimetype": "text/x-python",
131
- "name": "python",
132
- "nbconvert_exporter": "python",
133
- "pygments_lexer": "ipython3",
134
- "version": "3.10.12"
135
- }
136
- },
137
- "nbformat": 4,
138
- "nbformat_minor": 5
139
- }