Holy-fox commited on
Commit
5c50f51
·
verified ·
1 Parent(s): 5995b5f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +192 -75
README.md CHANGED
@@ -3,50 +3,51 @@ license: gemma
3
  ---
4
  # DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1
5
 
 
 
6
  ## Overview
7
 
8
- このモデルは、Googleの `google/gemma-3-4b-it` をベースモデルとしてファインチューニングされた日本語大規模言語モデルです。
9
- [Unsloth](https://github.com/unslothai/unsloth) を使用して効率的にトレーニングを行い、特別に作成された合成データセットを用いることで、特にユーザーの指示やプロンプトに対する追従能力の向上を目指しました。
10
 
11
- * **ベースモデル:** google/gemma-3-4b-it
12
- * **トレーニングフレームワーク:** Unsloth
13
- * **データセット:** 合成データセット(プロンプト追従能力向上目的)
14
- * **主な改善点:** プロンプトへの忠実性、指示実行能力
15
 
16
  ## How to use
17
 
18
- このモデルは、以下の方法で使用できます。
 
 
 
 
 
 
 
 
19
 
20
- ### 1. vLLMによる推論
21
 
22
- [vLLM](https://github.com/vllm-project/vllm) を使用すると、高速な推論が可能です。
23
 
24
  ```python
25
  from vllm import LLM, SamplingParams
26
 
27
- # モデル名を指定
28
  model_name = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
29
- # もしくはローカルパスを指定
30
- # model_name = "/path/to/your/local/model"
31
 
32
- llm = LLM(model=model_name, trust_remote_code=True) # Gemma-3では trust_remote_code=True が必要になる場合があります
 
 
33
 
34
- # プロンプトの準備 (Gemma-3のチャットテンプレート形式を推奨)
35
- # 例: <start_of_turn>user\n日本の首都はどこですか?<end_of_turn>\n<start_of_turn>model\n
36
- prompt = "<start_of_turn>user\nあなたの得意なことは何ですか?<end_of_turn>\n<start_of_turn>model\n"
37
 
38
- # サンプリングパラメータの設定
39
- sampling_params = SamplingParams(
40
- temperature=0.1,
41
- top_p=0.9,
42
- max_tokens=100,
43
- stop=["<end_of_turn>"] # 必要に応じて停止トークンを設定
44
- )
45
 
46
- # 推論の実行
47
  outputs = llm.generate(prompt, sampling_params)
48
 
49
- # 結果の表示
50
  for output in outputs:
51
  prompt = output.prompt
52
  generated_text = output.outputs[0].text
@@ -54,74 +55,190 @@ for output in outputs:
54
 
55
  ```
56
 
57
- ### 2. Transformersによる推論 (テキストのみ)
58
 
59
- Hugging Faceの `transformers` ライブラリを使用して推論を行う基本的なコードです。System Promptを設定することも可能です。
60
 
61
  ```python
 
62
  import torch
63
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
64
-
65
- # モデル名を指定
66
- model_name = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
67
 
68
- model = AutoModelForCausalLM.from_pretrained(
69
- model_name,
70
- torch_dtype=torch.bfloat16, # または torch.float16
71
- device_map="auto",
72
- trust_remote_code=True # Gemma-3モデルによっては必要
 
 
 
 
 
 
73
  )
74
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
75
 
76
- # プロンプトの準備 (System PromptとUser Prompt)
77
- system_prompt = "あなたは親切で正直なアシスタントです。"
78
- user_prompt = "自己紹介をしてください。"
79
-
80
- # Gemma-3のチャットテンプレートを適用
81
  messages = [
82
- {"role": "system", "content": system_prompt},
83
- {"role": "user", "content": user_prompt},
 
 
 
 
 
 
84
  ]
85
 
86
- # apply_chat_template を使用 (推奨)
87
- input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
88
 
89
- input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
 
 
 
90
 
91
- # ストリーミング出力用の設定
92
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
93
 
94
- # 推論の実行
95
- output_ids = model.generate(
96
- **input_ids,
97
- max_new_tokens=256,
98
- temperature=0.7,
99
- top_p=0.9,
100
- do_sample=True,
101
- streamer=streamer,
102
- pad_token_id=tokenizer.eos_token_id, # pad_token_idを設定
103
- eos_token_id=tokenizer.eos_token_id # eos_token_idを明示的に設定
104
- )
105
 
106
- # ストリーミングしない場合
107
- # output_ids = model.generate(
108
- # **input_ids,
109
- # max_new_tokens=256,
110
- # temperature=0.7,
111
- # top_p=0.9,
112
- # do_sample=True,
113
- # pad_token_id=tokenizer.eos_token_id,
114
- # eos_token_id=tokenizer.eos_token_id
115
  # )
116
- # generated_text = tokenizer.decode(output_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True)
117
- # print(generated_text)
118
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  ```
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  ## License
123
 
124
- このモデルは、ベースモデルである `google/gemma-3-4b-it` のライセンス条件に基づいて提供されます。
125
- `google/gemma-3-4b-it` のライセンスは [Gemma Terms of Use](https://ai.google.dev/gemma/terms) に従います。
 
126
 
127
- このモデルの使用にあたっては、ベースモデルのライセンスおよび利用規約を遵守してください。
 
 
3
  ---
4
  # DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1
5
 
6
+ このモデルは、Google の [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it) をベースモデルとしています。
7
+
8
  ## Overview
9
 
10
+ このモデルは、[Unsloth](https://github.com/unslothai/unsloth) フレームワークと合成データセットを用いてファインチューニングされました。主な目的は、**プロンプト追従能力の向上**です。ベースモデルである `google/gemma-3-4b-it` の持つマルチモーダル機能(テキストと画像の入力、テキストの出力)と多言語対応能力を継承しています。
 
11
 
12
+ このモデルは、特定の指示や文脈に対する応答精度を高めるようにトレーニングされており、対話システム、コンテンツ生成、タスク実行支援など、より的確な応答が求められる場面での利用に適しています。
 
 
 
13
 
14
  ## How to use
15
 
16
+ **注意:** 以下のコードを実行する前に、必要なライブラリをインストールしてください。特に `transformers` ライブラリは Gemma 3 をサポートするバージョン (4.50.0 以降) が必要です。また、Unsloth を使用してファインチューニングされたモデルの場合、推論時にも Unsloth が必要になる場合があります。
17
+
18
+ ```sh
19
+ pip install -U transformers accelerate torch
20
+ # vLLM を使用する場合
21
+ pip install vllm
22
+ # Unsloth が推論に必要となる場合
23
+ pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" # 環境に合わせて調整
24
+ ```
25
 
26
+ ### vLLM での推論 (テキスト生成)
27
 
28
+ vLLM を使用すると、高速なテキスト生成推論が可能です。(2025年3月現在、vLLMのGemma 3マルチモーダル対応は進行中の可能性があります。最新情報はvLLMのドキュメントをご確認ください。)
29
 
30
  ```python
31
  from vllm import LLM, SamplingParams
32
 
33
+ # モデル名を指定
34
  model_name = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
35
+ # またはローカルパスを指定
36
+ # model_name = "/path/to/your/model"
37
 
38
+ # LLMインスタンスを作成
39
+ # tensor_parallel_size は利用可能なGPU数に合わせて調整してください
40
+ llm = LLM(model=model_name, trust_remote_code=True) # Unslothモデルの場合など必要に応じて trust_remote_code=True
41
 
42
+ # サンプリングパラメータを設定
43
+ sampling_params = SamplingParams(temperature=0.1, top_p=0.95, max_tokens=200)
 
44
 
45
+ prompt = "<start_of_turn>user\n日本の首都はどこですか?<end_of_turn>\n<start_of_turn>model\n"
 
 
 
 
 
 
46
 
47
+ # 推論を実行
48
  outputs = llm.generate(prompt, sampling_params)
49
 
50
+ # 結果を表示
51
  for output in outputs:
52
  prompt = output.prompt
53
  generated_text = output.outputs[0].text
 
55
 
56
  ```
57
 
58
+ ### Transformers での推論 (テキストのみ)
59
 
60
+ `transformers` ライブラリを使用して、テキストプロンプト(システムプロンプトとユーザープロンプトを含む)に基づいてテキストを生成し��す。
61
 
62
  ```python
63
+ from transformers import pipeline, AutoTokenizer
64
  import torch
 
 
 
 
65
 
66
+ # モデル名とトークナイザーを指定
67
+ model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
68
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
69
+
70
+ # パイプラインを作成
71
+ pipe = pipeline(
72
+ "text-generation", # Gemma 3 のテキスト生成には text-generation が適切
73
+ model=model_id,
74
+ tokenizer=tokenizer, # 明示的にトークナイザーを渡す
75
+ device="cuda", # GPUが利用可能な場合
76
+ torch_dtype=torch.bfloat16 # Gemma 3 推奨のデータ型
77
  )
 
78
 
79
+ # チャット形式のメッセージを作成
 
 
 
 
80
  messages = [
81
+ {
82
+ "role": "system",
83
+ "content": "あなたは親切なアシスタントです。" # システムプロンプト
84
+ },
85
+ {
86
+ "role": "user",
87
+ "content": "Unslothとは何ですか?簡単に説明してください。" # ユーザープロンプト
88
+ }
89
  ]
90
 
91
+ # チャットテンプレートを適用
92
+ # apply_chat_template は内部で <start_of_turn>などを付与します
93
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
94
 
95
+ # 推論を実行
96
+ # max_new_tokens は生成する最大トークン数
97
+ # do_sample=True にすると、多様な応答が生成されやすくなります
98
+ outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.2, top_p=0.95)
99
 
100
+ # 生成されたテキストのみを表示 (入力プロンプト部分を除く)
101
+ generated_text = outputs[0]['generated_text'][len(prompt):]
102
+ print(generated_text)
103
 
104
+ # --- AutoModelForCausalLM を使う場合 ---
105
+ # from transformers import AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
106
 
107
+ # model = AutoModelForCausalLM.from_pretrained(
108
+ # model_id,
109
+ # torch_dtype=torch.bfloat16,
110
+ # device_map="auto", # GPUに自動で配置
111
+ # # Unslothモデルの場合、追加の引数が必要な場合があります
 
 
 
 
112
  # )
113
+ # model.eval()
114
+
115
+ # inputs = tokenizer.apply_chat_template(
116
+ # messages,
117
+ # add_generation_prompt=True,
118
+ # return_tensors="pt"
119
+ # ).to(model.device)
120
+
121
+ # input_len = inputs.shape[-1]
122
+
123
+ # with torch.inference_mode():
124
+ # generation_output = model.generate(
125
+ # inputs,
126
+ # max_new_tokens=256,
127
+ # do_sample=True,
128
+ # temperature=0.7,
129
+ # top_p=0.95,
130
+ # )
131
+ # # 入力部分を除いた生成トークンを取得
132
+ # generated_tokens = generation_output[0][input_len:]
133
+ # decoded = tokenizer.decode(generated_tokens, skip_special_tokens=True)
134
+ # print(decoded)
135
  ```
136
 
137
+ ### Transformers での推論 (画像とテキスト)
138
+
139
+ `transformers` ライブラリを使用して、画像とテキストプロンプトに基づいてテキストを生成します。
140
+
141
+ ```python
142
+ from transformers import pipeline, AutoProcessor
143
+ import torch
144
+ from PIL import Image
145
+ import requests
146
+
147
+ # モデル名、プロセッサーを指定
148
+ model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
149
+ processor = AutoProcessor.from_pretrained(model_id)
150
+
151
+ # パイプラインを作成 (image-text-to-textタスク)
152
+ pipe = pipeline(
153
+ "image-text-to-text",
154
+ model=model_id,
155
+ processor=processor, # 明示的にプロセッサーを渡す
156
+ device="cuda", # GPUが利用可能な場合
157
+ torch_dtype=torch.bfloat16 # Gemma 3 推奨のデータ型
158
+ # Unslothモデルの場合、追加の引数が必要な場合があります
159
+ )
160
+
161
+ # 画像のURL
162
+ image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
163
+ # 画像を読み込む (ローカルファイルの場合は Image.open("path/to/image.jpg") )
164
+ image = Image.open(requests.get(image_url, stream=True).raw)
165
+
166
+ # チャット形式のメッセージを作成 (画像とテキストを含む)
167
+ messages = [
168
+ {
169
+ "role": "system",
170
+ "content": [{"type": "text", "text": "You are a helpful assistant."}]
171
+ },
172
+ {
173
+ "role": "user",
174
+ "content": [
175
+ {"type": "image"}, # 画像のプレースホルダー
176
+ {"type": "text", "text": "この画像について詳しく説明してください。"} # テキストプロンプト
177
+ ]
178
+ }
179
+ ]
180
+
181
+ # 推論を実行 (images引数で画像を渡す)
182
+ # max_new_tokens は生成する最大トークン数
183
+ outputs = pipe(messages, images=image, max_new_tokens=200)
184
+
185
+ # 生成されたテキストを表示
186
+ # パイプラインの出力形式に合わせて調整が必要な場合があります
187
+ # Gemma 3の場合、最後のメッセージのcontentを取り出すことが多いです
188
+ print(outputs[0]["generated_text"][-1]["content"])
189
+
190
+ # --- Gemma3ForConditionalGeneration を使う場合 ---
191
+ # from transformers import Gemma3ForConditionalGeneration
192
+
193
+ # model = Gemma3ForConditionalGeneration.from_pretrained(
194
+ # model_id,
195
+ # torch_dtype=torch.bfloat16,
196
+ # device_map="auto" # GPUに自動で配置
197
+ # # Unslothモデルの場合、追加の引数が必要な場合があります
198
+ # ).eval()
199
+
200
+ # # 画像を含むメッセージを作成 (Imageオブジェクトを直接渡す)
201
+ # messages_for_processor = [
202
+ # {
203
+ # "role": "system",
204
+ # "content": [{"type": "text", "text": "You are a helpful assistant."}]
205
+ # },
206
+ # {
207
+ # "role": "user",
208
+ # "content": [
209
+ # {"type": "image", "image": image}, # PIL Image オブジェクト
210
+ # {"type": "text", "text": "この画像について詳しく説明してください。"}
211
+ # ]
212
+ # }
213
+ # ]
214
+
215
+ # # プロセッサーで入力を作成
216
+ # inputs = processor.apply_chat_template(
217
+ # messages_for_processor,
218
+ # add_generation_prompt=True,
219
+ # tokenize=True, # トークン化を有効に
220
+ # return_dict=True,
221
+ # return_tensors="pt"
222
+ # ).to(model.device) # モデルと同じデバイスに移動
223
+
224
+ # input_len = inputs["input_ids"].shape[-1]
225
+
226
+ # # 推論実行
227
+ # with torch.inference_mode():
228
+ # generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)
229
+ # # 入力部分を除いた生成トークンを取得
230
+ # generation = generation[0][input_len:]
231
+
232
+ # # デコードして表示
233
+ # decoded = processor.decode(generation, skip_special_tokens=True)
234
+ # print(decoded)
235
+ ```
236
 
237
  ## License
238
 
239
+ このモデルは、ベースモデルである `google/gemma-3-4b-it` のライセンス条件に従います。詳細については、以下のリンクをご参照ください。
240
+
241
+ * **Gemma Terms of Use:** [https://ai.google.dev/gemma/terms](https://ai.google.dev/gemma/terms)
242
 
243
+ このモデルを利用する際は、ライセンス条件を遵守してください。
244
+ ```