Meta AI a introduit fin 2021. XLS-R est un modèle d'apprentissage automatique (« ML ») pour l'apprentissage des représentations vocales multilingues ; et il a été formé sur plus de 400 000 heures d’audio vocal accessible au public dans 128 langues. Lors de sa sortie, le modèle représentait un progrès par rapport au modèle multilingue de Meta AI, qui a été formé sur environ 50 000 heures d'audio vocal dans 53 langues.
Ce guide explique les étapes pour affiner XLS-R pour la reconnaissance vocale automatique (« ASR ») à l'aide d'un Kaggle Notebook . Le modèle sera affiné sur l'espagnol chilien, mais les étapes générales peuvent être suivies pour affiner XLS-R sur les différentes langues de votre choix.
Le modèle wav2vec2 original introduit en 2020 a été pré-entraîné sur 960 heures d'audio vocal de l'ensemble de données Librispeech et environ 53 200 heures d'audio vocal de l'ensemble de données LibriVox . Dès sa sortie, deux tailles de modèle étaient disponibles : le modèle BASE avec 95 millions de paramètres et le modèle LARGE avec 317 millions de paramètres.
Il existe 3 modèles XLS-R : XLS-R (0,3B) avec 300 millions de paramètres, XLS-R (1B) avec 1 milliard de paramètres et XLS-R (2B) avec 2 milliards de paramètres. Ce guide utilisera le modèle XLS-R (0,3B).
Il existe d'excellents articles sur la façon d'affiner les modèles wav2vev2 , étant peut-être une sorte de « étalon-or ». Bien entendu, l’approche générale ici imite ce que vous trouverez dans d’autres guides. Vous serez:
Comme mentionné dans l' introduction , le modèle XLS-R sera affiné sur l'espagnol chilien. L'ensemble de données spécifique est l' ensemble de données vocales chiliennes et espagnoles développé par Guevara-Rukoz et al. Il est disponible en téléchargement sur . L'ensemble de données se compose de deux sous-ensembles de données : (1) 2 636 enregistrements audio de locuteurs chiliens de sexe masculin et (2) 1 738 enregistrements audio de locuteurs chiliens féminins.
Chaque sous-ensemble de données comprend un fichier d'index line_index.tsv
. Chaque ligne de chaque fichier d'index contient une paire d'un nom de fichier audio et d'une transcription de l'audio dans le fichier associé, par exemple :
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
substitutions ( S
) : Une erreur de substitution est enregistrée lorsque la prédiction contient un mot différent du mot analogue dans la référence. Par exemple, cela se produit lorsque la prédiction orthographie mal un mot dans la référence.
suppressions ( D
) : Une erreur de suppression est enregistrée lorsque la prédiction contient un mot qui n'est pas présent dans la référence.
insertions ( I
) : Une erreur d'insertion est enregistrée lorsque la prédiction ne contient pas de mot présent dans la référence.
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
prediction: "Él está saliendo." reference: "Él está saltando."
TEXTE | MOT 1 | MOT 2 | MOT 3 |
---|---|---|---|
prédiction | Él | c'est | saliendo |
référence | Él | c'est | saltando |
| correct | correct | substitution |
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
Il devrait être évident que le taux d’erreurs sur les mots ne nous indique pas nécessairement quelles erreurs spécifiques existent. Dans l'exemple ci-dessus, WER identifie que le MOT 3 contient une erreur dans le texte prédit, mais il ne nous indique pas que les caractères i et e sont erronés dans la prédiction. D'autres mesures, telles que le taux d'erreur de caractères (« CER »), peuvent être utilisées pour une analyse plus précise des erreurs.
www.wandb.com
.www.wandb.ai/authorize
.
xls-r-300m-chilean-spanish-asr
.Un Kaggle Secret sera utilisé pour stocker en toute sécurité votre clé API WandB.
WANDB_API_KEY
dans le champ Étiquette et entrez votre clé API WandB pour la valeur.WANDB_API_KEY
est cochée.L' a été téléchargé sur Kaggle sous la forme de 2 ensembles de données distincts :
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
torchaudio
vers la dernière version. torchaudio
sera utilisé pour charger des fichiers audio et rééchantillonner les données audio.jiwer
qui est requis pour utiliser la méthode load_metric
de la bibliothèque HuggingFace Datasets
utilisée ultérieurement.
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
transformers
HuggingFace et les classes Wav2Vec2*
associées constituent l'épine dorsale des fonctionnalités utilisées pour le réglage fin.
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
La quatrième cellule récupère votre secret WANDB_API_KEY
qui a été défini à l'étape 2.2 . Définissez la quatrième cellule sur :
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
La sixième cellule définit les méthodes utilitaires pour lire les fichiers d'index de l'ensemble de données (voir la sous-section Ensemble de données de formation ci-dessus), ainsi que pour nettoyer le texte de transcription et créer le vocabulaire. Définissez la sixième cellule sur :
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
La méthode read_index_file_data
lit un fichier d'index de jeu de données line_index.tsv
et produit une liste de listes avec le nom du fichier audio et les données de transcription, par exemple :
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
truncate_training_dataset
tronque les données d'un fichier d'index de liste à l'aide de la constante NUM_LOAD_FROM_EACH_SET
définie à l'étape 3.5 . Plus précisément, la constante NUM_LOAD_FROM_EACH_SET
est utilisée pour spécifier le nombre d'échantillons audio à charger à partir de chaque ensemble de données. Pour les besoins de ce guide, le nombre est fixé à 1600
ce qui signifie qu'un total de 3200
échantillons audio seront finalement chargés. Pour charger tous les échantillons, définissez NUM_LOAD_FROM_EACH_SET
sur la valeur de chaîne all
.clean_text
est utilisée pour supprimer chaque transcription de texte des caractères spécifiés par l'expression régulière attribuée à SPECIAL_CHARS
à l'étape 3.5 . Ces caractères, y compris la ponctuation, peuvent être éliminés car ils ne fournissent aucune valeur sémantique lors de la formation du modèle pour apprendre les mappages entre les fonctionnalités audio et les transcriptions de texte.create_vocab
crée un vocabulaire à partir de transcriptions de texte épurées. Simplement, il extrait tous les caractères uniques de l'ensemble des transcriptions de texte nettoyées. Vous verrez un exemple du vocabulaire généré à l'étape 3.14 . La septième cellule définit les méthodes utilitaires utilisant torchaudio
pour charger et rééchantillonner les données audio. Définissez la septième cellule sur :
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
charge un fichier audio spécifié et renvoie une matrice multidimensionnelle torch.Tensor
des données audio ainsi que la fréquence d'échantillonnage de l'audio. Tous les fichiers audio des données d'entraînement ont un taux d'échantillonnage de 48000
Hz. Ce taux d'échantillonnage "original" est capturé par la constante ORIG_SAMPLING_RATE
à l'étape 3.5 .resample
est utilisée pour sous-échantillonner les données audio à partir d'un taux d'échantillonnage de 48000
à 16000
. wav2vec2 est pré-entraîné sur de l'audio échantillonné à 16000
Hz. Par conséquent, tout audio utilisé pour le réglage fin doit avoir le même taux d’échantillonnage. Dans ce cas, les exemples audio doivent être sous-échantillonnés de 48000
Hz à 16000
Hz. 16000
Hz sont capturés par la constante TGT_SAMPLING_RATE
à l'étape 3.5 .
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
process_speech_audio
renvoie les valeurs d'entrée d'un échantillon de formation fourni.process_target_text
code chaque transcription de texte sous la forme d'une liste d'étiquettes, c'est-à-dire une liste d'indices faisant référence aux caractères du vocabulaire. Vous verrez un exemple de codage à l'étape 3.15 .
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
La dixième cellule lit les fichiers d'index de données de formation pour les enregistrements de locuteurs masculins et les enregistrements de locuteurs féminins à l'aide de la méthode read_index_file_data
définie à l' étape 3.6 . Définissez la dixième cellule sur :
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
La onzième cellule tronque les listes de données d'entraînement à l'aide de la méthode truncate_training_dataset
définie à l' étape 3.6 . Définissez la onzième cellule sur :
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
NUM_LOAD_FROM_EACH_SET
définie à l'étape 3.5 définit la quantité d'échantillons à conserver pour chaque ensemble de données. La constante est fixée à 1600
dans ce guide pour un total de 3200
échantillons.
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
La treizième cellule parcourt chaque échantillon de données d'apprentissage et nettoie le texte de transcription associé à l'aide de la méthode clean_text
définie à l' étape 3.6 . Définissez la treizième cellule sur :
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
La quatorzième cellule crée un vocabulaire en utilisant les transcriptions nettoyées de l'étape précédente et la méthode create_vocab
définie à l'étape 3.6 . Définissez la quatorzième cellule sur :
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
Vous pouvez imprimer vocab_dict
qui devrait produire le résultat suivant :
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
La quinzième cellule ajoute le caractère délimiteur de mot |
au vocabulaire. Définissez la quinzième cellule sur :
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
Le caractère délimiteur de mot est utilisé lors de la tokenisation des transcriptions de texte sous forme de liste d'étiquettes. Plus précisément, il est utilisé pour définir la fin d'un mot et il est utilisé lors de l'initialisation de la classe Wav2Vec2CTCTokenizer
, comme nous le verrons à l'étape 3.17 .
Par exemple, la liste suivante code no te entiendo nada
en utilisant le vocabulaire de l'étape 3.14 :
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
Wav2Vec2CTCTokenizer
. La dix-septième cellule initialise une instance de Wav2Vec2CTCTokenizer
. Définissez la dix-septième cellule sur :
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
Notez que le tokenizer
est initialisé avec [UNK]
attribué à unk_token
et [PAD]
attribué à pad_token
, le premier étant utilisé pour représenter des jetons inconnus dans les transcriptions de texte et le second étant utilisé pour compléter les transcriptions lors de la création de lots de transcriptions de différentes longueurs. Ces deux valeurs seront ajoutées au vocabulaire par le tokenizer.
L'initialisation du tokenizer à cette étape ajoutera également deux jetons supplémentaires au vocabulaire, à savoir <s>
et /</s>
qui sont utilisés pour délimiter respectivement le début et la fin des phrases.
|
est attribué explicitement à word_delimiter_token
dans cette étape pour refléter que le symbole de pipe sera utilisé pour délimiter la fin des mots conformément à notre ajout du caractère au vocabulaire à l'étape 3.15 . Le |
symbol est la valeur par défaut de word_delimiter_token
. Il n’était donc pas nécessaire de le définir explicitement, mais cela a été fait par souci de clarté.
De la même manière qu'avec word_delimiter_token
, un seul espace est explicitement attribué à replace_word_delimiter_char
ce qui reflète que le symbole pipe |
sera utilisé pour remplacer les espaces vides dans les transcriptions de texte. L'espace vide est la valeur par défaut pour replace_word_delimiter_char
. Il n’était donc pas nécessaire non plus de le définir explicitement, mais cela a été fait par souci de clarté.
Vous pouvez imprimer le vocabulaire complet du tokenizer en appelant la méthode get_vocab()
sur tokenizer
.
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
La dix-huitième cellule initialise une instance de Wav2Vec2FeatureExtractor
. Définissez la dix-huitième cellule sur :
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
Wav2Vec2FeatureExtractor
sont toutes des valeurs par défaut, à l'exception de return_attention_mask
qui est par défaut False
. Les valeurs par défaut sont affichées/transmises par souci de clarté.feature_size
spécifie la taille des dimensions des entités d'entrée (c'est-à-dire les entités de données audio). La valeur par défaut de ce paramètre est 1
.sampling_rate
indique à l'extracteur de fonctionnalités la fréquence d'échantillonnage à laquelle les données audio doivent être numérisées. Comme indiqué à l'étape 3.7 , wav2vec2 est pré-entraîné sur de l'audio échantillonné à 16000
Hz et donc 16000
est la valeur par défaut pour ce paramètre.padding_value
spécifie la valeur utilisée lors du remplissage des données audio, comme requis lors du regroupement d'échantillons audio de différentes longueurs. La valeur par défaut est 0.0
.do_normalize
est utilisé pour spécifier si les données d'entrée doivent être transformées en une distribution normale standard. La valeur par défaut est True
. La documentation de la classe Wav2Vec2FeatureExtractor
indique que « [la normalisation] peut aider à améliorer considérablement les performances de certains modèles ».return_attention_mask
spécifient si le masque d'attention doit être transmis ou non. La valeur est définie sur True
pour ce cas d'utilisation. La dix-neuvième cellule initialise une instance de Wav2Vec2Processor
. Définissez la dix-neuvième cellule sur :
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
La classe Wav2Vec2Processor
combine tokenizer
et feature_extractor
de l'étape 3.17 et de l'étape 3.18 respectivement en un seul processeur.
Notez que la configuration du processeur peut être enregistrée en appelant la méthode save_pretrained
sur l'instance de classe Wav2Vec2Processor
.
processor.save_pretrained(OUTPUT_DIR_PATH)
La vingtième cellule charge chaque fichier audio spécifié dans la liste all_training_samples
. Définissez la vingtième cellule sur :
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
torch.Tensor
et stockées dans all_input_data
sous forme de liste de dictionnaires. Chaque dictionnaire contient les données audio d'un échantillon particulier, ainsi que la transcription textuelle de l'audio.read_audio_data
renvoie également la fréquence d'échantillonnage des données audio. Puisque nous savons que la fréquence d'échantillonnage est 48000
Hz pour tous les fichiers audio dans ce cas d'utilisation, la fréquence d'échantillonnage est ignorée dans cette étape.all_input_data
en un Pandas DataFrame La vingt et unième cellule convertit la liste all_input_data
en Pandas DataFrame pour faciliter la manipulation des données. Définissez la vingt et unième cellule sur :
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
La vingt-deuxième cellule utilise le processor
initialisé à l'étape 3.19 pour extraire les caractéristiques de chaque échantillon de données audio et pour coder chaque transcription de texte sous la forme d'une liste d'étiquettes. Définissez la vingt-deuxième cellule sur :
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
La vingt-troisième cellule divise le DataFrame all_input_data_df
en ensembles de données de formation et d'évaluation (validation) à l'aide de la constante SPLIT_PCT
de l'étape 3.5 . Définissez la vingt-troisième cellule sur :
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
SPLIT_PCT
est de 0.10
dans ce guide, ce qui signifie que 10 % de toutes les données d'entrée seront conservées pour évaluation et 90 % des données seront utilisées pour la formation/le réglage fin.Dataset
La vingt-quatrième cellule convertit les DataFrames train_data_df
et valid_data_df
en objets Dataset
. Définissez la vingt-quatrième cellule sur :
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
Les objets Dataset
sont consommés par les instances de classe HuggingFace Trainer
, comme vous le verrez à l'étape 3.30 .
Vous pouvez imprimer train_data
et valid_data
pour afficher les métadonnées des deux objets Dataset
.
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
from_pretrained
appelée sur Wav2Vec2ForCTC
spécifie que nous souhaitons charger les poids pré-entraînés pour le modèle spécifié.MODEL
a été spécifiée à l'étape 3.5 et a été définie sur facebook/wav2vec2-xls-r-300m
reflétant le modèle XLS-R (0,3).ctc_loss_reduction
spécifie le type de réduction à appliquer à la sortie de la fonction de perte Connectionist Temporal Classification (« CTC »). La perte CTC est utilisée pour calculer la perte entre une entrée continue, dans ce cas des données audio, et une séquence cible, dans ce cas des transcriptions de texte. En définissant la valeur sur mean
, les pertes de sortie pour un lot d'entrées seront divisées par les longueurs cibles. La moyenne sur le lot est ensuite calculée et la réduction est appliquée aux valeurs de perte.pad_token_id
spécifie le jeton à utiliser pour le remplissage lors du traitement par lots. Il est défini sur l'identifiant [PAD]
défini lors de l'initialisation du tokenizer à l'étape 3.17 .vocab_size
définit la taille du vocabulaire du modèle. Il s'agit de la taille du vocabulaire après l'initialisation du tokenizer à l'étape 3.17 et reflète le nombre de nœuds de couche de sortie de la partie avant du réseau.
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
La vingt-septième cellule initialise les arguments de formation qui seront transmis à une instance Trainer
. Définissez la vingt-septième cellule sur :
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
TrainingArguments
accepte plus de .save_safetensors
lorsque False
spécifie que le modèle affiné doit être enregistré dans un fichier pickle
au lieu d'utiliser le format safetensors
.group_by_length
lorsque True
indique que les échantillons d'approximativement la même longueur doivent être regroupés. Cela minimise le rembourrage et améliore l'efficacité de l'entraînement.per_device_train_batch_size
définit le nombre d'échantillons par mini-lot d'entraînement. Ce paramètre est défini sur 18
via la constante TRAIN_BATCH_SIZE
attribuée à l' étape 3.5 . Cela implique 160 étapes par époque.per_device_eval_batch_size
définit le nombre d'échantillons par mini-lot d'évaluation (holdout). Ce paramètre est défini sur 10
via la constante EVAL_BATCH_SIZE
attribuée à l'étape 3.5 .num_train_epochs
définit le nombre d'époques d'entraînement. Ce paramètre est défini sur 30
via la constante TRAIN_EPOCHS
attribuée à l'étape 3.5 . Cela implique 4 800 pas au total pendant l'entraînement.gradient_checkpointing
lorsque True
permet d'économiser de la mémoire en vérifiant les calculs de gradient, mais entraîne des passes arrière plus lentes.evaluation_strategy
lorsqu'il est défini sur steps
signifie que l'évaluation sera effectuée et enregistrée pendant l'entraînement à un intervalle spécifié par le paramètre eval_steps
.logging_strategy
lorsqu'il est défini sur steps
signifie que les statistiques d'exécution de formation seront enregistrées à un intervalle spécifié par le paramètre logging_steps
.save_strategy
lorsqu'il est défini sur steps
signifie qu'un point de contrôle du modèle affiné sera enregistré à un intervalle spécifié par le paramètre save_steps
.eval_steps
définit le nombre d'étapes entre les évaluations des données d'exclusion. Ce paramètre est défini sur 100
via la constante EVAL_STEPS
attribuée à l'étape 3.5 .save_steps
définit le nombre d'étapes après lesquelles un point de contrôle du modèle affiné est enregistré. Ce paramètre est défini sur 3200
via la constante SAVE_STEPS
attribuée à l' étape 3.5 .logging_steps
définit le nombre d'étapes entre les journaux de statistiques d'exécution de formation. Ce paramètre est défini sur 100
via la constante LOGGING_STEPS
attribuée à l'étape 3.5 .learning_rate
définit le taux d'apprentissage initial. Ce paramètre est défini sur 1e-4
via la constante LEARNING_RATE
attribuée à l'étape 3.5 .warmup_steps
définit le nombre d'étapes pour réchauffer linéairement le taux d'apprentissage de 0 à la valeur définie par learning_rate
. Ce paramètre est défini sur 800
via la constante WARMUP_STEPS
attribuée à l'étape 3.5 .
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
Trainer
qui sera initialisée momentanément à l' étape 3.30 . Étant donné que la longueur des séquences d'entrée et des séquences d'étiquettes varie dans chaque mini-lot, certaines séquences doivent être complétées afin qu'elles aient toutes la même longueur.DataCollatorCTCWithPadding
complète dynamiquement les données par mini-lots. Le paramètre padding
, lorsqu'il est défini sur True
spécifie que les séquences de fonctionnalités d'entrée audio et les séquences d'étiquettes plus courtes doivent avoir la même longueur que la séquence la plus longue d'un mini-lot.0.0
définie lors de l'initialisation de l'extracteur de fonctionnalités à l'étape 3.18 .-100
afin que ces étiquettes soient ignorées lors du calcul de la métrique WER.
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
La trentième cellule initialise une instance de la classe Trainer
. Définissez la trentième cellule sur :
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
Trainer
est initialisée avec :model
pré-entraîné initialisé à l'étape 3.25 .train_data
Dataset
de l'étape 3.24 .Dataset
valid_data
de l'étape 3.24 .tokenizer
est attribué à processor.feature_extractor
et fonctionne avec data_collator
pour compléter automatiquement les entrées jusqu'à l'entrée de longueur maximale de chaque mini-lot. La trente et unième cellule appelle la méthode train
sur l'instance de classe Trainer
pour affiner le modèle. Définissez la trente et unième cellule sur :
### CELL 31: Finetune the model ### trainer.train()
La trente-deuxième cellule est la dernière cellule du bloc-notes. Il enregistre le modèle affiné en appelant la méthode save_model
sur l'instance Trainer
. Définissez la trente-deuxième cellule sur :
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
Configurez le Kaggle Notebook pour qu'il s'exécute avec l'accélérateur NVIDIA GPU P100 .
Le modèle affiné sera affiché dans le répertoire Kaggle spécifié par la constante OUTPUT_DIR_PATH
spécifiée à l'étape 3.5 . La sortie du modèle doit inclure les fichiers suivants :
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
Ces fichiers peuvent être téléchargés localement. De plus, vous pouvez créer un nouveau modèle Kaggle à l'aide des fichiers de modèle. Le modèle Kaggle sera utilisé avec le guide d'inférence associé pour exécuter l'inférence sur le modèle affiné.