adricl committed
Commit 2a062a8 · verified · 1 Parent(s): 18dc639

Upload HuggingFace_Mistral_Transformer_Single_Instrument.ipynb

HuggingFace_Mistral_Transformer_Single_Instrument.ipynb ADDED
@@ -0,0 +1,660 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {
7
+ "id": "SiTIpPjArIyr"
8
+ },
9
+ "source": [
10
+ "# Using Midi traning data and MidiTok Remi to generate music with Mistral model \n",
11
+ "# split music into Single Instrument and split into 1024\n"
12
+ ]
13
+ },
14
+ {
15
+ "attachments": {},
16
+ "cell_type": "markdown",
17
+ "metadata": {
18
+ "id": "gOd93yV0sGd2"
19
+ },
20
+ "source": [
21
+ "## Setup Environment"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "To compile Symusic \n",
31
+ "\n",
32
+ "Get g++11 or higher\n",
33
+ "\n",
34
+ "git clone --recursive https://github.com/Yikai-Liao/symusic\n",
35
+ "CXX=/usr/bin/g++-11 pip install ./symusic\n"
36
+ ]
37
+ },
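+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick import check after building (a minimal sketch; assumes the symusic package exposes __version__):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Verify the build; __version__ is assumed to be available\n",
+ "import symusic\n",
+ "print(symusic.__version__)"
+ ]
+ },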
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {
42
+ "cellView": "form",
43
+ "id": "fX12Yquyuihc"
44
+ },
45
+ "outputs": [],
46
+ "source": [
47
+ "\n",
48
+ "\n",
49
+ "from copy import deepcopy\n",
50
+ "from pathlib import Path\n",
51
+ "from random import shuffle\n",
52
+ "\n",
53
+ "from evaluate import load as load_metric\n",
54
+ "from miditok import REMI, TokenizerConfig, TokTrainingIterator\n",
55
+ "from miditok.pytorch_data import DatasetMIDI, DataCollator\n",
56
+ "from miditok.utils import split_files_for_training\n",
57
+ "\n",
58
+ "from miditok.data_augmentation import augment_dataset\n",
59
+ "from torch import Tensor, argmax\n",
60
+ "from torch.utils.data import DataLoader\n",
61
+ "from torch.cuda import is_available as cuda_available, is_bf16_supported\n",
62
+ "from torch.backends.mps import is_available as mps_available\n",
63
+ "from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig\n",
64
+ "from transformers.trainer_utils import set_seed\n",
65
+ "from tqdm import tqdm"
66
+ ]
67
+ },
68
+ {
69
+ "attachments": {},
70
+ "cell_type": "markdown",
71
+ "metadata": {},
72
+ "source": [
73
+ "## Setup Tokenizer"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "# Seed\n",
83
+ "set_seed(777)\n",
84
+ "\n",
85
+ "# Our tokenizer's configuration\n",
86
+ "BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}\n",
87
+ "TOKENIZER_PARAMS = {\n",
88
+ " \"pitch_range\": (21, 109),\n",
89
+ " \"beat_res\": BEAT_RES,\n",
90
+ " \"num_velocities\": 24,\n",
91
+ " \"special_tokens\": [\"PAD\", \"BOS\", \"EOS\"],\n",
92
+ " \"use_chords\": True,\n",
93
+ " \"use_rests\": True,\n",
94
+ " \"use_tempos\": True,\n",
95
+ " \"use_time_signatures\": True,\n",
96
+ " \"use_programs\": False, # We want single track \n",
97
+ " \"one_token_stream_for_programs\": True,\n",
98
+ " \"programs\": list(range(0, 128)), #-1 drums, skip drums\n",
99
+ " \"num_tempos\": 32,\n",
100
+ " \"tempo_range\": (50, 200), # (min_tempo, max_tempo)\n",
101
+ "}\n",
102
+ "config = TokenizerConfig(**TOKENIZER_PARAMS)\n",
103
+ "\n",
104
+ "# Creates the tokenizer REMI PLUS\n",
105
+ "tokenizer = REMI(config)"
106
+ ]
107
+ },
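+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Optional sanity check of the tokenizer before training it: tokenize one file and decode the tokens back to a score. A minimal sketch; 'example.mid' is a placeholder path, not part of the datasets below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "# With use_programs=False the tokenizer returns one TokSequence per track\n",
+ "tok_seqs = tokenizer(Path('example.mid')) # 'example.mid' is a placeholder\n",
+ "print(tok_seqs[0].tokens[:20]) # first 20 token strings of the first track\n",
+ "decoded = tokenizer.decode(tok_seqs) # back to a symusic Score\n",
+ "print(decoded)"
+ ]
+ },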
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "# Load Midi filed and train the the tokenizer on the midi files"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "root_data_dir = Path('/home/wombat/Documents/projects/music/midiTok/data/')\n",
122
+ "root_save = Path(root_data_dir / 'HuggingFace_Mistral_Transformer_Single_Instrument')\n",
123
+ "\n",
124
+ "tokenizer_name = \"HuggingFace_Mistral_Transformer_Single_Instrument.json\""
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "\n",
134
+ "# Trains the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 30k tokens\n",
135
+ "data_dirs = [\"adl-piano-midi\", \"maestro-v3.0.0\", \"musicnet_midis\" ] # for single \n",
136
+ "midi_paths = []\n",
137
+ "for data_dir in data_dirs:\n",
138
+ " path = Path(root_data_dir / 'Traning Data' / data_dir)\n",
139
+ " midi_paths.extend(list(path.resolve().glob(\"**/*.mid\")) + list(path.resolve().glob(\"**/*.midi\")))\n",
140
+ "\n",
141
+ "print(f\"Found {len(midi_paths)} MIDI files\")"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "#Note the size of the dataset is quite large, so it requires a huge amount of memory to train the tokenizer for 61749 files it took 64gb of memory\n",
151
+ "tokenizer.train(\n",
152
+ " vocab_size=32000,\n",
153
+ " files_paths=midi_paths,\n",
154
+ ")\n",
155
+ "tokenizer.save(root_save / tokenizer_name)\n",
156
+ "\n"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "tokenizer = REMI(params=Path(root_save / tokenizer_name))"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "metadata": {},
171
+ "source": [
172
+ "## Prepare MIDIs for training\n",
173
+ "\n",
174
+ "Here we split the files in three subsets: train, validation and test.\n",
175
+ "Then data augmentation is performed on each subset independently, and the MIDIs are split into smaller chunks that make approximately the desired token sequence length for training."
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "# Split MIDI paths in train/valid/test sets\n",
185
+ "total_num_files = len(midi_paths)\n",
186
+ "num_files_valid = round(total_num_files * 0.15)\n",
187
+ "num_files_test = round(total_num_files * 0.15)\n",
188
+ "shuffle(midi_paths)\n",
189
+ "midi_paths_valid = midi_paths[:num_files_valid]\n",
190
+ "midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]\n",
191
+ "midi_paths_train = midi_paths[num_files_valid + num_files_test:]\n",
192
+ "\n",
193
+ "\n",
194
+ "\n",
195
+ "# Chunk MIDIs and perform data augmentation on each subset independently\n",
196
+ "for files_paths, subset_name in (\n",
197
+ " (midi_paths_train, \"train\"), (midi_paths_valid, \"valid\"), (midi_paths_test, \"test\")\n",
198
+ "):\n",
199
+ "\n",
200
+ " # Split the MIDIs into chunks of sizes approximately about 1024 tokens\n",
201
+ " subset_chunks_dir = root_save / f\"Maestro_{subset_name}\"\n",
202
+ " print(subset_chunks_dir)\n",
203
+ " split_files_for_training(\n",
204
+ " files_paths=files_paths,\n",
205
+ " tokenizer=tokenizer,\n",
206
+ " save_dir=subset_chunks_dir,\n",
207
+ " max_seq_len=1024,\n",
208
+ " num_overlap_bars=2,\n",
209
+ " )\n",
210
+ "\n",
211
+ " if subset_name == 'train':\n",
212
+ " print(\"Augmentation\")\n",
213
+ " # Perform data augmentation\n",
214
+ " augment_dataset(\n",
215
+ " subset_chunks_dir,\n",
216
+ " pitch_offsets=[-12, 12],\n",
217
+ " velocity_offsets=[-4, 4],\n",
218
+ " duration_offsets=[-0.5, 0.5],\n",
219
+ " )\n"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": null,
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "# Create Dataset and Collator for training\n",
229
+ "midi_paths_train = list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_train\")).glob(\"**/*.midi\"))\n",
230
+ "midi_paths_valid = list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_valid\")).glob(\"**/*.midi\")) \n",
231
+ "midi_paths_test = list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.mid\")) + list(root_save.joinpath(Path(\"Maestro_test\")).glob(\"**/*.midi\"))\n",
232
+ "\n",
233
+ "kwargs_dataset = {\"max_seq_len\": 1024, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"]}\n",
234
+ "\n",
235
+ "dataset_train = DatasetMIDI(midi_paths_train, **kwargs_dataset)\n",
236
+ "dataset_valid = DatasetMIDI(midi_paths_valid, **kwargs_dataset)\n",
237
+ "dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)\n",
238
+ "print (len(midi_paths_train), len(midi_paths_valid), len(midi_paths_test))"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "markdown",
243
+ "metadata": {},
244
+ "source": [
245
+ "# Preview files data load and split"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "metadata": {
252
+ "tags": [
253
+ "Generate Preview Files"
254
+ ]
255
+ },
256
+ "outputs": [],
257
+ "source": [
258
+ "testing_files = \n",
259
+ "preview_files_path = []\n",
260
+ "for testing_file in testing_files:\n",
261
+ " preview_files_path.append(Path(testing_file))\n",
262
+ "\n",
263
+ "preview_dir = Path(root_save / \"preview\")\n",
264
+ "split_files_for_training(\n",
265
+ " files_paths=preview_files_path,\n",
266
+ " tokenizer=tokenizer,\n",
267
+ " save_dir=preview_dir,\n",
268
+ " max_seq_len=1024,\n",
269
+ " num_overlap_bars=2,\n",
270
+ " )\n"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": null,
276
+ "metadata": {},
277
+ "outputs": [],
278
+ "source": [
279
+ "valid_midi_path = root_save / \"Maestro_valid\"\n",
280
+ "midi_split_preview = list(valid_midi_path.resolve().glob(\"**/*.mid\")) + list(valid_midi_path.resolve().glob(\"**/*.midi\"))\n",
281
+ "\n",
282
+ "print(len(midi_split_preview))\n",
283
+ "file_name_lookup = []\n",
284
+ "def func_to_get_labels(p1, p2, p3):\n",
285
+ " if p3.name not in file_name_lookup:\n",
286
+ " file_name_lookup.append(p3.name)\n",
287
+ " return file_name_lookup.index(p3.name)\n",
288
+ " \n",
289
+ "kwargs_dataset = {\"max_seq_len\": 1024, \"tokenizer\": tokenizer, \"bos_token_id\": tokenizer[\"BOS_None\"], \"eos_token_id\": tokenizer[\"EOS_None\"], \"func_to_get_labels\" : func_to_get_labels}\n",
290
+ "dataset_preview = DatasetMIDI(midi_split_preview, **kwargs_dataset)"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "markdown",
295
+ "metadata": {},
296
+ "source": [
297
+ "# Save and Load datasets"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": null,
303
+ "metadata": {},
304
+ "outputs": [],
305
+ "source": [
306
+ "dataset_dir = root_save / \"data\"\n",
307
+ "dataset_dir.mkdir(parents=True, exist_ok=True)"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "import torch\n",
317
+ "torch.save(dataset_train, Path(dataset_dir / \"dataset_train.pt\"))\n",
318
+ "torch.save(dataset_valid, Path(dataset_dir / \"dataset_valid.pt\"))\n",
319
+ "torch.save(dataset_test, Path(dataset_dir / \"dataset_test.pt\"))\n"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": null,
325
+ "metadata": {},
326
+ "outputs": [],
327
+ "source": [
328
+ "import torch\n",
329
+ "dataset_train = torch.load(Path(dataset_dir / \"dataset_train.pt\"))\n",
330
+ "dataset_valid = torch.load(Path(dataset_dir / \"dataset_valid.pt\"))\n",
331
+ "dataset_test = torch.load(Path(dataset_dir / \"dataset_test.pt\"))\n",
332
+ "\n"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "execution_count": null,
338
+ "metadata": {},
339
+ "outputs": [],
340
+ "source": [
341
+ "print(dataset_train[0])\n"
342
+ ]
343
+ },
344
+ {
345
+ "attachments": {},
346
+ "cell_type": "markdown",
347
+ "metadata": {},
348
+ "source": [
349
+ "## Model initialization\n",
350
+ "\n",
351
+ "We will use the [Mistral implementation of Hugging Face](https://huggingface.co/docs/transformers/model_doc/mistral).\n",
352
+ "Feel free to explore the documentation and source code to dig deeper.\n",
353
+ "\n",
354
+ "**You may need to adjust the model's configuration, the training configuration and the maximum input sequence length (cell above) depending on your hardware.**"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "# Creates model\n",
364
+ "model_config = MistralConfig(\n",
365
+ " vocab_size=len(tokenizer),\n",
366
+ " hidden_size=512,\n",
367
+ " intermediate_size=2048,\n",
368
+ " num_hidden_layers=8,\n",
369
+ " num_attention_heads=8,\n",
370
+ " num_key_value_heads=4,\n",
371
+ " sliding_window=256,\n",
372
+ " max_position_embeddings=8192,\n",
373
+ " pad_token_id=tokenizer['PAD_None'],\n",
374
+ " bos_token_id=tokenizer['BOS_None'],\n",
375
+ " eos_token_id=tokenizer['EOS_None'],\n",
376
+ ")\n",
377
+ "model = AutoModelForCausalLM.from_config(model_config)"
378
+ ]
379
+ },
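+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick parameter count of the freshly initialized model (plain PyTorch, no additional assumptions):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Count trainable parameters of the freshly initialized model\n",
+ "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "print(f\"Model has {n_params / 1e6:.1f}M trainable parameters\")"
+ ]
+ },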
380
+ {
381
+ "attachments": {},
382
+ "cell_type": "markdown",
383
+ "metadata": {},
384
+ "source": [
385
+ "## Model training"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": null,
391
+ "metadata": {},
392
+ "outputs": [],
393
+ "source": [
394
+ "model_dir = root_save / 'run'\n",
395
+ "model_dir_str = str(model_dir)\n",
396
+ "print(model_dir)"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": null,
402
+ "metadata": {},
403
+ "outputs": [],
404
+ "source": [
405
+ "metrics = {metric: load_metric(metric) for metric in [\"accuracy\"]}\n",
406
+ "\n",
407
+ "def compute_metrics(eval_pred):\n",
408
+ " \"\"\"\n",
409
+ " Compute metrics for pretraining.\n",
410
+ "\n",
411
+ " Must use preprocess_logits function that converts logits to predictions (argmax or sampling).\n",
412
+ "\n",
413
+ " :param eval_pred: EvalPrediction containing predictions and labels\n",
414
+ " :return: metrics\n",
415
+ " \"\"\"\n",
416
+ " predictions, labels = eval_pred\n",
417
+ " not_pad_mask = labels != -100\n",
418
+ " labels, predictions = labels[not_pad_mask], predictions[not_pad_mask]\n",
419
+ " return metrics[\"accuracy\"].compute(predictions=predictions.flatten(), references=labels.flatten())\n",
420
+ "\n",
421
+ "def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:\n",
422
+ " \"\"\"\n",
423
+ " Preprocess the logits before accumulating them during evaluation.\n",
424
+ "\n",
425
+ " This allows to significantly reduce the memory usage and make the training tractable.\n",
426
+ " \"\"\"\n",
427
+ " pred_ids = argmax(logits, dim=-1) # long dtype\n",
428
+ " return pred_ids\n",
429
+ "\n",
430
+ "# Create config for the Trainer\n",
431
+ "USE_CUDA = cuda_available()\n",
432
+ "print(USE_CUDA)\n",
433
+ "if not cuda_available():\n",
434
+ " FP16 = FP16_EVAL = BF16 = BF16_EVAL = False\n",
435
+ "elif is_bf16_supported():\n",
436
+ " BF16 = BF16_EVAL = True\n",
437
+ " FP16 = FP16_EVAL = False\n",
438
+ "else:\n",
439
+ " BF16 = BF16_EVAL = False\n",
440
+ " FP16 = FP16_EVAL = True\n",
441
+ "USE_MPS = not USE_CUDA and mps_available()\n",
442
+ "training_config = TrainingArguments(\n",
443
+ " model_dir_str, False, True, True, False, \"steps\",\n",
444
+ " per_device_train_batch_size=30, #76% @ 24 batch size #76% @ 32 batch size try 64 batch size next time \n",
445
+ " per_device_eval_batch_size=30, #was 24 now 32\n",
446
+ " gradient_accumulation_steps=3, #change this to 4\n",
447
+ " eval_accumulation_steps=None,\n",
448
+ " eval_steps=1000,\n",
449
+ " learning_rate=1e-4,\n",
450
+ " weight_decay=0.01,\n",
451
+ " max_grad_norm=3.0,\n",
452
+ " max_steps=20000,\n",
453
+ " lr_scheduler_type=\"cosine_with_restarts\",\n",
454
+ " warmup_ratio=0.3,\n",
455
+ " log_level=\"debug\",\n",
456
+ " logging_strategy=\"steps\",\n",
457
+ " logging_steps=20,\n",
458
+ " save_strategy=\"steps\",\n",
459
+ " save_steps=1000,\n",
460
+ " save_total_limit=5,\n",
461
+ " no_cuda=not USE_CUDA,\n",
462
+ " seed=444,\n",
463
+ " fp16=FP16,\n",
464
+ " fp16_full_eval=FP16_EVAL,\n",
465
+ " bf16=BF16,\n",
466
+ " bf16_full_eval=BF16_EVAL,\n",
467
+ " load_best_model_at_end=True,\n",
468
+ " label_smoothing_factor=0.,\n",
469
+ " optim=\"adamw_torch\",\n",
470
+ " report_to=[\"tensorboard\"],\n",
471
+ " gradient_checkpointing=True,\n",
472
+ " dataloader_num_workers=8, #added to fix trashing isssue with the gpu not having enough data to process\n",
473
+ " dataloader_pin_memory=True, #we want the dataset in memory\n",
474
+ " torch_compile=True #added to speed up \n",
475
+ " \n",
476
+ ")\n",
477
+ "\n",
478
+ "collator = DataCollator(tokenizer[\"PAD_None\"], copy_inputs_as_labels=True)\n",
479
+ "trainer = Trainer(\n",
480
+ " model=model,\n",
481
+ " args=training_config,\n",
482
+ " data_collator=collator,\n",
483
+ " train_dataset=dataset_train,\n",
484
+ " eval_dataset=dataset_valid,\n",
485
+ " compute_metrics=compute_metrics,\n",
486
+ " callbacks=None,\n",
487
+ " preprocess_logits_for_metrics=preprocess_logits,\n",
488
+ ")\n",
489
+ "\n"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": [
498
+ "# Training\n",
499
+ "train_result = trainer.train()\n",
500
+ "trainer.save_model() # Saves the tokenizer too\n",
501
+ "trainer.log_metrics(\"train\", train_result.metrics)\n",
502
+ "trainer.save_metrics(\"train\", train_result.metrics)\n",
503
+ "trainer.save_state()"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": null,
509
+ "metadata": {},
510
+ "outputs": [],
511
+ "source": [
512
+ "model.create_model_card(tags=[\"mistral\", \"midi\", \"miditok\", \"music\", \"instrument\"],\n",
513
+ " model_name=\"Mistral_MidiTok_Transformer_Single_Instrument_Small\")"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "metadata": {},
520
+ "outputs": [],
521
+ "source": [
522
+ "\n",
523
+ "model.hub_model_id = \"adricl/midi_single_instrument_mistral_transformer\"\n",
524
+ "\n",
525
+ "model.push_to_hub(commit_message=\"Training Basic Model for Mistral MidiTok Transformer Single Instrument Small\", repo_id=\"adricl/midi_single_instrument_mistral_transformer\",\n",
526
+ " token=\"\")\n"
527
+ ]
528
+ },
529
+ {
530
+ "cell_type": "markdown",
531
+ "metadata": {},
532
+ "source": [
533
+ "# For Tensorboard tensorboard --logdir runs/"
534
+ ]
535
+ },
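+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional: inline TensorBoard (assumes the tensorboard notebook extension is installed)\n",
+ "%load_ext tensorboard\n",
+ "%tensorboard --logdir runs/"
+ ]
+ },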
536
+ {
537
+ "cell_type": "code",
538
+ "execution_count": null,
539
+ "metadata": {},
540
+ "outputs": [],
541
+ "source": [
542
+ "config = AutoConfig.from_pretrained(str(model_dir / \"config.json\"))\n",
543
+ "model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=str(model_dir / \"model.safetensors\"), from_tf=False, config=config)"
544
+ ]
545
+ },
546
+ {
547
+ "attachments": {},
548
+ "cell_type": "markdown",
549
+ "metadata": {},
550
+ "source": [
551
+ "## Generate music"
552
+ ]
553
+ },
554
+ {
555
+ "cell_type": "code",
556
+ "execution_count": null,
557
+ "metadata": {
558
+ "cellView": "form",
559
+ "id": "OaNkGcFo9UP_"
560
+ },
561
+ "outputs": [],
562
+ "source": [
563
+ "# for single track midi files splits \n",
564
+ "\n",
565
+ "gen_results_path = root_save / 'gen_res'\n",
566
+ "gen_results_path.mkdir(parents=True, exist_ok=True)\n",
567
+ "generation_config = GenerationConfig(\n",
568
+ " max_new_tokens=200, # extends samples by 200 tokens\n",
569
+ " num_beams=1, # no beam search\n",
570
+ " do_sample=True, # but sample instead\n",
571
+ " temperature=0.9,\n",
572
+ " top_k=15,\n",
573
+ " top_p=0.95,\n",
574
+ " epsilon_cutoff=3e-4,\n",
575
+ " eta_cutoff=1e-3,\n",
576
+ " pad_token_id=tokenizer.pad_token_id,\n",
577
+ ")\n",
578
+ "\n",
579
+ "# Here the sequences are padded to the left, so that the last token along the time dimension\n",
580
+ "# is always the last token of each seq, allowing to efficiently generate by batch\n",
581
+ "collator.pad_on_left = True\n",
582
+ "collator.eos_token = None\n",
583
+ "dataloader_test = DataLoader(dataset_preview, batch_size=24, collate_fn=collator)\n",
584
+ "model.eval()\n",
585
+ "count = 0\n",
586
+ "for batch in tqdm(dataloader_test, desc='Testing model / Generating results'): # (N,T)\n",
587
+ " print(batch)\n",
588
+ " res = model.generate(\n",
589
+ " inputs=batch[\"input_ids\"].to(model.device),\n",
590
+ " attention_mask=batch[\"attention_mask\"].to(model.device),\n",
591
+ " generation_config=generation_config) # (N,T)\n",
592
+ "\n",
593
+ "\n",
594
+ " # Saves the generated music, as MIDI files and tokens (json)\n",
595
+ " for prompt, continuation in zip(batch[\"input_ids\"], res):\n",
596
+ " generated = continuation[len(prompt):]\n",
597
+ " midi = tokenizer.decode([deepcopy(generated.tolist())])\n",
598
+ " tokens = [generated, prompt, continuation] # list compr. as seqs of dif. lengths\n",
599
+ " tokens = [seq.tolist() for seq in tokens]\n",
600
+ " for tok_seq in tokens[1:]:\n",
601
+ " _midi = tokenizer.decode([deepcopy(tok_seq)])\n",
602
+ " midi.tracks.append(_midi.tracks[0])\n",
603
+ " \n",
604
+ " file_name = file_name_lookup[count]\n",
605
+ " print(file_name)\n",
606
+ " midi.tracks[0].name = f'Continuation of original sample ({len(generated)} tokens) Original file {file_name}'\n",
607
+ " midi.tracks[1].name = f'Original sample ({len(prompt)} tokens)'\n",
608
+ " if (len(midi.tracks) > 2):\n",
609
+ " midi.tracks[2].name = f'Original sample and continuation'\n",
610
+ " midi.dump_midi(gen_results_path / f'{count}_{file_name}.mid')\n",
611
+ " tokenizer.save_tokens(tokens, gen_results_path / f'{count}_{file_name}.json') \n",
612
+ "\n",
613
+ " count += 1"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": null,
619
+ "metadata": {},
620
+ "outputs": [],
621
+ "source": [
622
+ "print(file_name_lookup)"
623
+ ]
624
+ }
625
+ ],
626
+ "metadata": {
627
+ "accelerator": "GPU",
628
+ "colab": {
629
+ "collapsed_sections": [],
630
+ "machine_shape": "hm",
631
+ "name": "Optimus_VIRTUOSO_Multi_Instrumental_RGA_Edition.ipynb",
632
+ "private_outputs": true,
633
+ "provenance": []
634
+ },
635
+ "kernelspec": {
636
+ "display_name": "Python 3 (ipykernel)",
637
+ "language": "python",
638
+ "name": "python3"
639
+ },
640
+ "language_info": {
641
+ "codemirror_mode": {
642
+ "name": "ipython",
643
+ "version": 3
644
+ },
645
+ "file_extension": ".py",
646
+ "mimetype": "text/x-python",
647
+ "name": "python",
648
+ "nbconvert_exporter": "python",
649
+ "pygments_lexer": "ipython3",
650
+ "version": "3.9.5"
651
+ },
652
+ "vscode": {
653
+ "interpreter": {
654
+ "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
655
+ }
656
+ }
657
+ },
658
+ "nbformat": 4,
659
+ "nbformat_minor": 4
660
+ }