Meta AI presentó a finales de 2021. XLS-R es un modelo de aprendizaje automático ("ML") para el aprendizaje de representaciones de voz en varios idiomas; y fue entrenado en más de 400.000 horas de audio de voz disponible públicamente en 128 idiomas. Tras su lanzamiento, el modelo representó un salto con respecto al modelo multilingüe de Meta AI, que se entrenó en aproximadamente 50.000 horas de audio de voz en 53 idiomas.
Esta guía explica los pasos para ajustar XLS-R para el reconocimiento automático de voz ("ASR") usando un Kaggle Notebook . El modelo estará afinado en español de Chile, pero se pueden seguir los pasos generales para afinar XLS-R en los diferentes idiomas que desee.
El modelo wav2vec2 original introducido en 2020 se entrenó previamente con 960 horas de audio de voz del conjunto de datos de Librispeech y ~53 200 horas de audio de voz del conjunto de datos de LibriVox . Tras su lanzamiento, había dos tamaños de modelo disponibles: el modelo BASE con 95 millones de parámetros y el modelo GRANDE con 317 millones de parámetros.
Hay 3 modelos XLS-R: XLS-R (0.3B) con 300 millones de parámetros, XLS-R (1B) con mil millones de parámetros y XLS-R (2B) con 2 mil millones de parámetros. Esta guía utilizará el modelo XLS-R (0.3B).
Hay algunos artículos excelentes sobre cómo ajustar los modelos wav2vev2 , y quizás sea una especie de "estándar de oro". Por supuesto, el enfoque general aquí imita lo que encontrará en otras guías. Vas a:
Como se mencionó en la Introducción , el modelo XLS-R estará afinado en español de Chile. El conjunto de datos específico es el conjunto de datos del habla del español chileno desarrollado por Guevara-Rukoz et al. Está disponible para descargar en . El conjunto de datos consta de dos subconjuntos de datos: (1) 2.636 grabaciones de audio de hablantes chilenos y (2) 1.738 grabaciones de audio de hablantes chilenas.
Cada subconjunto de datos incluye un archivo de índice line_index.tsv
. Cada línea de cada archivo de índice contiene un par de nombres de archivo de audio y una transcripción del audio en el archivo asociado, por ejemplo:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
sustituciones ( S
): Se registra un error de sustitución cuando la predicción contiene una palabra que es diferente de la palabra análoga en la referencia. Por ejemplo, esto ocurre cuando la predicción escribe mal una palabra en la referencia.
eliminaciones ( D
): Se registra un error de eliminación cuando la predicción contiene una palabra que no está presente en la referencia.
inserciones ( I
): Se registra un error de inserción cuando la predicción no contiene una palabra que esté presente en la referencia.
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
prediction: "Él está saliendo." reference: "Él está saltando."
TEXTO | PALABRA 1 | PALABRA 2 | PALABRA 3 |
---|---|---|---|
predicción | Él | esta | saliendo |
referencia | Él | esta | saltando |
| correcto | correcto | sustitución |
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
Debería ser obvio que la tasa de errores de palabras no necesariamente nos dice qué errores específicos existen. En el ejemplo anterior, WER identifica que la PALABRA 3 contiene un error en el texto predicho, pero no nos dice que los caracteres i y e están equivocados en la predicción. Se pueden utilizar otras métricas, como la tasa de error de caracteres ("CER"), para un análisis de errores más preciso.
www.wandb.com
.www.wandb.ai/authorize
.
xls-r-300m-chilean-spanish-asr
.Se utilizará un Kaggle Secret para almacenar de forma segura su clave API de WandB.
WANDB_API_KEY
en el campo Etiqueta e ingrese su clave API de WandB para el valor.WANDB_API_KEY
esté marcada.El se cargó en Kaggle como 2 conjuntos de datos distintos:
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
torchaudio
a la última versión. torchaudio
se utilizará para cargar archivos de audio y volver a muestrear datos de audio.jiwer
que se requiere para usar el método load_metric
de la biblioteca HuggingFace Datasets
que se usará más adelante.
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
transformers
HuggingFace y las clases Wav2Vec2*
asociadas proporcionan la columna vertebral de la funcionalidad utilizada para el ajuste fino.
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
La cuarta celda recupera su secreto WANDB_API_KEY
que se configuró en el Paso 2.2 . Establezca la cuarta celda en:
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
La sexta celda define métodos de utilidad para leer los archivos de índice del conjunto de datos (consulte la subsección Conjunto de datos de entrenamiento anterior), así como para limpiar el texto de transcripción y crear el vocabulario. Establezca la sexta celda en:
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
El método read_index_file_data
lee un archivo de índice del conjunto de datos line_index.tsv
y produce una lista de listas con nombres de archivos de audio y datos de transcripción, por ejemplo:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
truncate_training_dataset
trunca los datos de un archivo de índice de lista utilizando la constante NUM_LOAD_FROM_EACH_SET
establecida en el Paso 3.5 . Específicamente, la constante NUM_LOAD_FROM_EACH_SET
se usa para especificar la cantidad de muestras de audio que deben cargarse desde cada conjunto de datos. Para los fines de esta guía, el número se establece en 1600
lo que significa que eventualmente se cargarán un total de 3200
muestras de audio. Para cargar todas las muestras, establezca NUM_LOAD_FROM_EACH_SET
en el valor de cadena all
.clean_text
se utiliza para eliminar de cada transcripción de texto los caracteres especificados por la expresión regular asignada a SPECIAL_CHARS
en el Paso 3.5 . Estos caracteres, incluida la puntuación, se pueden eliminar ya que no proporcionan ningún valor semántico al entrenar el modelo para aprender asignaciones entre funciones de audio y transcripciones de texto.create_vocab
crea un vocabulario a partir de transcripciones de texto limpio. Simplemente, extrae todos los caracteres únicos del conjunto de transcripciones de texto limpias. Verá un ejemplo del vocabulario generado en el Paso 3.14 . La séptima celda define métodos de utilidad que utilizan torchaudio
para cargar y volver a muestrear datos de audio. Establezca la séptima celda en:
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
carga un archivo de audio específico y devuelve una matriz multidimensional torch.Tensor
de los datos de audio junto con la frecuencia de muestreo del audio. Todos los archivos de audio de los datos de entrenamiento tienen una frecuencia de muestreo de 48000
Hz. Esta frecuencia de muestreo "original" es capturada por la constante ORIG_SAMPLING_RATE
en el Paso 3.5 .resample
se utiliza para reducir la resolución de datos de audio desde una frecuencia de muestreo de 48000
a 16000
. wav2vec2 está previamente entrenado en audio muestreado a 16000
Hz. En consecuencia, cualquier audio utilizado para el ajuste debe tener la misma frecuencia de muestreo. En este caso, los ejemplos de audio deben reducirse de 48000
Hz a 16000
Hz. 16000
Hz son capturados por la constante TGT_SAMPLING_RATE
en el Paso 3.5 .
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
process_speech_audio
devuelve los valores de entrada de una muestra de entrenamiento proporcionada.process_target_text
codifica cada transcripción de texto como una lista de etiquetas, es decir, una lista de índices que hacen referencia a caracteres del vocabulario. Verá una codificación de muestra en el Paso 3.15 .
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
La décima celda lee los archivos de índice de datos de entrenamiento para las grabaciones de hablantes masculinos y las grabaciones de hablantes femeninas utilizando el método read_index_file_data
definido en el Paso 3.6 . Establezca la décima celda en:
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
La undécima celda trunca las listas de datos de entrenamiento utilizando el método truncate_training_dataset
definido en el Paso 3.6 . Establezca la undécima celda en:
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
NUM_LOAD_FROM_EACH_SET
establecida en el Paso 3.5 define la cantidad de muestras que se conservarán de cada conjunto de datos. La constante se establece en 1600
en esta guía para un total de 3200
muestras.
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
La decimotercera celda itera sobre cada muestra de datos de entrenamiento y limpia el texto de transcripción asociado utilizando el método clean_text
definido en el Paso 3.6 . Establezca la decimotercera celda en:
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
La decimocuarta celda crea un vocabulario utilizando las transcripciones limpias del paso anterior y el método create_vocab
definido en el Paso 3.6 . Establezca la decimocuarta celda en:
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
Puede imprimir vocab_dict
que debería producir el siguiente resultado:
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
La decimoquinta celda agrega el carácter delimitador de palabras |
al vocabulario. Establezca la decimoquinta celda en:
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
El carácter delimitador de palabras se utiliza al tokenizar transcripciones de texto como una lista de etiquetas. Específicamente, se usa para definir el final de una palabra y se usa al inicializar la clase Wav2Vec2CTCTokenizer
, como se verá en el Paso 3.17 .
Por ejemplo, la siguiente lista codifica no te entiendo nada
usando el vocabulario del Paso 3.14 :
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
Wav2Vec2CTCTokenizer
. La decimoséptima celda inicializa una instancia de Wav2Vec2CTCTokenizer
. Establezca la decimoséptima celda en:
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
Tenga en cuenta que el tokenizer
se inicializa con [UNK]
asignado a unk_token
y [PAD]
asignado a pad_token
; el primero se usa para representar tokens desconocidos en transcripciones de texto y el segundo se usa para rellenar transcripciones al crear lotes de transcripciones con diferentes longitudes. El tokenizador agregará estos dos valores al vocabulario.
La inicialización del tokenizador en este paso también agregará dos tokens adicionales al vocabulario, a saber, <s>
y /</s>
, que se utilizan para demarcar el principio y el final de las oraciones, respectivamente.
|
se asigna explícitamente a word_delimiter_token
en este paso para reflejar que el símbolo de canalización se usará para demarcar el final de las palabras de acuerdo con nuestra adición del carácter al vocabulario en el Paso 3.15 . El |
El símbolo es el valor predeterminado para word_delimiter_token
. Por lo tanto, no era necesario establecerlo explícitamente, pero se hizo en aras de la claridad.
De manera similar a word_delimiter_token
, se asigna explícitamente un único espacio a replace_word_delimiter_char
lo que refleja que el símbolo de canalización |
se utilizará para reemplazar caracteres de espacio en blanco en las transcripciones de texto. El espacio en blanco es el valor predeterminado para replace_word_delimiter_char
. Por lo tanto, tampoco era necesario establecerlo explícitamente, pero se hizo en aras de la claridad.
Puede imprimir el vocabulario completo del tokenizador llamando al método get_vocab()
en tokenizer
.
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
La decimoctava celda inicializa una instancia de Wav2Vec2FeatureExtractor
. Establezca la decimoctava celda en:
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
Wav2Vec2FeatureExtractor
son todos valores predeterminados, con la excepción de return_attention_mask
que por defecto es False
. Los valores predeterminados se muestran/aprueban para mayor claridad.feature_size
especifica el tamaño de dimensión de las funciones de entrada (es decir, funciones de datos de audio). El valor predeterminado de este parámetro es 1
.sampling_rate
le dice al extractor de funciones la frecuencia de muestreo a la que se deben digitalizar los datos de audio. Como se analizó en el Paso 3.7 , wav2vec2 está preentrenado en audio muestreado a 16000
Hz y, por lo tanto, 16000
es el valor predeterminado para este parámetro.padding_value
especifica el valor que se utiliza al rellenar datos de audio, según sea necesario al agrupar muestras de audio de diferentes longitudes. El valor predeterminado es 0.0
.do_normalize
se utiliza para especificar si los datos de entrada deben transformarse a una distribución normal estándar. El valor por defecto es True
. La documentación de la clase Wav2Vec2FeatureExtractor
señala que "[la normalización] puede ayudar a mejorar significativamente el rendimiento de algunos modelos".return_attention_mask
especifican si se debe pasar la máscara de atención o no. El valor se establece en True
para este caso de uso. La decimonovena celda inicializa una instancia de Wav2Vec2Processor
. Establezca la decimonovena celda en:
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
La clase Wav2Vec2Processor
combina tokenizer
y feature_extractor
del paso 3.17 y 3.18 respectivamente en un solo procesador.
Tenga en cuenta que la configuración del procesador se puede guardar llamando al método save_pretrained
en la instancia de la clase Wav2Vec2Processor
.
processor.save_pretrained(OUTPUT_DIR_PATH)
La vigésima celda carga cada archivo de audio especificado en la lista all_training_samples
. Establezca la vigésima celda en:
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
torch.Tensor
y se almacenan en all_input_data
como una lista de diccionarios. Cada diccionario contiene los datos de audio de una muestra particular, junto con la transcripción del texto del audio.read_audio_data
también devuelve la frecuencia de muestreo de los datos de audio. Como sabemos que la frecuencia de muestreo es 48000
Hz para todos los archivos de audio en este caso de uso, la frecuencia de muestreo se ignora en este paso.all_input_data
en un Pandas DataFrame La vigésima primera celda convierte la lista all_input_data
en un Pandas DataFrame para facilitar la manipulación de los datos. Establezca la vigésima primera celda en:
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
La vigésima segunda celda utiliza el processor
inicializado en el Paso 3.19 para extraer características de cada muestra de datos de audio y codificar cada transcripción de texto como una lista de etiquetas. Establezca la vigésima segunda celda en:
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
La vigésima tercera celda divide el marco de datos all_input_data_df
en conjuntos de datos de entrenamiento y evaluación (validación) utilizando la constante SPLIT_PCT
del paso 3.5 . Establezca la vigésima tercera celda en:
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
SPLIT_PCT
es 0.10
en esta guía, lo que significa que el 10 % de todos los datos de entrada se reservarán para evaluación y el 90 % de los datos se utilizarán para capacitación/ajuste.Dataset
La vigésima cuarta celda convierte los DataFrames train_data_df
y valid_data_df
en objetos Dataset
. Establezca la vigésima cuarta celda en:
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
Los objetos Dataset
son consumidos por instancias de la clase HuggingFace Trainer
, como verá en el Paso 3.30 .
Puede imprimir train_data
y valid_data
para ver los metadatos de ambos objetos Dataset
.
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
from_pretrained
llamado en Wav2Vec2ForCTC
especifica que queremos cargar los pesos previamente entrenados para el modelo especificado.MODEL
se especificó en el Paso 3.5 y se configuró en facebook/wav2vec2-xls-r-300m
reflejando el modelo XLS-R (0.3).ctc_loss_reduction
especifica el tipo de reducción que se aplicará a la salida de la función de pérdida de Clasificación Temporal Conexionista ("CTC"). La pérdida de CTC se utiliza para calcular la pérdida entre una entrada continua, en este caso datos de audio, y una secuencia de destino, en este caso transcripciones de texto. Al establecer el valor en mean
, las pérdidas de producción de un lote de insumos se dividirán por las longitudes objetivo. Luego se calcula la media del lote y la reducción se aplica a los valores de pérdida.pad_token_id
especifica el token que se utilizará para el relleno al realizar el procesamiento por lotes. Se establece en el ID [PAD]
establecido al inicializar el tokenizador en el Paso 3.17 .vocab_size
define el tamaño del vocabulario del modelo. Es el tamaño del vocabulario después de la inicialización del tokenizador en el Paso 3.17 y refleja el número de nodos de la capa de salida de la parte directa de la red.
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
La vigésima séptima celda inicializa los argumentos de entrenamiento que se pasarán a una instancia Trainer
. Establezca la vigésima séptima celda en:
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
TrainingArguments
acepta más de .save_safetensors
cuando False
especifica que el modelo ajustado debe guardarse en un archivo pickle
en lugar de usar el formato safetensors
.group_by_length
cuando es True
indica que se deben agrupar muestras de aproximadamente la misma longitud. Esto minimiza el acolchado y mejora la eficiencia del entrenamiento.per_device_train_batch_size
establece el número de muestras por minilote de entrenamiento. Este parámetro se establece en 18
mediante la constante TRAIN_BATCH_SIZE
asignada en el Paso 3.5 . Esto implica 160 pasos por época.per_device_eval_batch_size
establece el número de muestras por minilote de evaluación (reserva). Este parámetro se establece en 10
mediante la constante EVAL_BATCH_SIZE
asignada en el Paso 3.5 .num_train_epochs
establece el número de épocas de entrenamiento. Este parámetro se establece en 30
mediante la constante TRAIN_EPOCHS
asignada en el Paso 3.5 . Esto implica 4.800 pasos totales durante el entrenamiento.gradient_checkpointing
cuando True
ayuda a ahorrar memoria al controlar los cálculos de gradiente, pero da como resultado pasos hacia atrás más lentos.evaluation_strategy
cuando se establece en steps
significa que la evaluación se realizará y registrará durante el entrenamiento en un intervalo especificado por el parámetro eval_steps
.logging_strategy
cuando se establece en steps
significa que las estadísticas de ejecución de entrenamiento se registrarán en un intervalo especificado por el parámetro logging_steps
.save_strategy
cuando se establece en steps
significa que se guardará un punto de control del modelo ajustado en un intervalo especificado por el parámetro save_steps
.eval_steps
establece el número de pasos entre evaluaciones de datos reservados. Este parámetro se establece en 100
mediante la constante EVAL_STEPS
asignada en el Paso 3.5 .save_steps
establece el número de pasos después de los cuales se guarda un punto de control del modelo ajustado. Este parámetro se establece en 3200
mediante la constante SAVE_STEPS
asignada en el Paso 3.5 .logging_steps
establece el número de pasos entre registros de estadísticas de ejecución de entrenamiento. Este parámetro se establece en 100
mediante la constante LOGGING_STEPS
asignada en el Paso 3.5 .learning_rate
establece la tasa de aprendizaje inicial. Este parámetro se establece en 1e-4
mediante la constante LEARNING_RATE
asignada en el Paso 3.5 .warmup_steps
establece el número de pasos para calentar linealmente la tasa de aprendizaje desde 0 hasta el valor establecido por learning_rate
. Este parámetro se establece en 800
mediante la constante WARMUP_STEPS
asignada en el Paso 3.5 .
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
Trainer
que se inicializará momentáneamente en el Paso 3.30 . Dado que las secuencias de entrada y las secuencias de etiquetas varían en longitud en cada minilote, algunas secuencias deben rellenarse para que todas tengan la misma longitud.DataCollatorCTCWithPadding
rellena dinámicamente datos de mini lotes. El parámetro padding
cuando se establece en True
especifica que las secuencias de funciones de entrada de audio y las secuencias de etiquetas más cortas deben tener la misma longitud que la secuencia más larga en un mini lote.0.0
establecido al inicializar el extractor de funciones en el Paso 3.18 .-100
para que estas etiquetas se ignoren al calcular la métrica WER.
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
La trigésima celda inicializa una instancia de la clase Trainer
. Establezca la trigésima celda en:
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
Trainer
se inicializa con:model
previamente entrenado se inicializó en el Paso 3.25 .train_data
Dataset
del paso 3.24 .valid_data
Dataset
del paso 3.24 .tokenizer
se asigna a processor.feature_extractor
y funciona con data_collator
para rellenar automáticamente las entradas hasta la entrada de longitud máxima de cada mini lote. La trigésima primera celda llama al método train
en la instancia de la clase Trainer
para ajustar el modelo. Establezca la trigésima primera celda en:
### CELL 31: Finetune the model ### trainer.train()
La celda trigésimo segunda es la última celda del cuaderno. Guarda el modelo ajustado llamando al método save_model
en la instancia Trainer
. Establezca la celda de treinta segundos en:
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
Configure el portátil Kaggle para que se ejecute con el acelerador NVIDIA GPU P100 .
El modelo ajustado se enviará al directorio de Kaggle especificado por la constante OUTPUT_DIR_PATH
especificada en el Paso 3.5 . La salida del modelo debe incluir los siguientes archivos:
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
Estos archivos se pueden descargar localmente. Además, puede crear un nuevo modelo Kaggle utilizando los archivos del modelo. El modelo Kaggle se utilizará con la guía de inferencia complementaria para ejecutar inferencias en el modelo ajustado.