Meta AI introduced XLS-R at the end of 2021. XLS-R is a machine learning ("ML") model for cross-lingual speech representation learning, and it was trained on over 400,000 hours of publicly available speech audio across 128 languages. Upon its release, the model represented a leap over Meta AI's earlier cross-lingual model, XLSR-53, which was trained on approximately 50,000 hours of speech audio across 53 languages.
This guide explains the steps to finetune XLS-R for automatic speech recognition ("ASR") using a Kaggle Notebook. The model will be finetuned on Chilean Spanish, but the same general steps can be followed to finetune XLS-R on other languages of your choosing.
The original wav2vec2 model introduced in 2020 was pretrained on 960 hours of LibriSpeech speech audio and roughly 53,200 hours of LibriVox speech audio. Upon its release, two model sizes were available: the BASE model with 95 million parameters and the LARGE model with 317 million parameters.
There are 3 XLS-R models: XLS-R (0.3B) with 300 million parameters, XLS-R (1B) with 1 billion parameters, and XLS-R (2B) with 2 billion parameters. This guide will use the XLS-R (0.3B) model.
There are some great write-ups on how to finetune wav2vec2 models, with one in particular perhaps serving as a "gold standard" of sorts. Of course, the general approach here mimics what you will find in other guides.
As mentioned in the Introduction, the XLS-R model will be finetuned on Chilean Spanish. The specific dataset is the Chilean Spanish Speech Data Set developed by Guevara-Rukoz et al., which is available for download online. The dataset consists of two sub-datasets: (1) 2,636 audio recordings of Chilean male speakers and (2) 1,738 audio recordings of Chilean female speakers.
Each sub-dataset includes a line_index.tsv index file. Each line of an index file contains a pair consisting of an audio filename and the transcription of the audio in the associated file, e.g.:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche
clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
substitutions (S): A substitution error is recorded when the prediction contains a word that is different from the analogous word in the reference. For example, this occurs when the prediction misspells a word in the reference.
deletions (D): A deletion error is recorded when the reference contains a word that is not present in the prediction, i.e. the prediction is missing a word.
insertions (I): An insertion error is recorded when the prediction contains a word that is not present in the reference, i.e. the prediction adds an extra word.
WER = (S + D + I)/N
where:
S = number of substitution errors
D = number of deletion errors
I = number of insertion errors
N = number of words in the reference
prediction: "Él está saliendo."
reference: "Él está saltando."
TEXT | WORD 1 | WORD 2 | WORD 3 |
---|---|---|---|
prediction | Él | está | saliendo |
reference | Él | está | saltando |
error type | correct | correct | substitution |
WER = (1 + 0 + 0)/3 = 1/3 = 0.33
It should be obvious that the Word Error Rate does not necessarily tell us what specific errors exist. In the example above, WER identifies that WORD 3 contains an error in the predicted text, but it doesn't tell us that the characters i and e are wrong in the prediction. Other metrics, such as the Character Error Rate ("CER"), can be used for more precise error analysis.
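As a quick illustration (not part of the notebook itself), the jiwer package installed in the first notebook cell can reproduce the WER arithmetic for this example:
# Hypothetical check of the example above: 1 substitution over 3 reference words.
import jiwer

reference = "Él está saltando."
prediction = "Él está saliendo."
print(jiwer.wer(reference, prediction))   # ~0.333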
Training runs will be tracked with Weights & Biases ("WandB"). Create a free WandB account at www.wandb.com if you do not already have one, then copy your WandB API key from www.wandb.ai/authorize.
Create a new Kaggle Notebook named xls-r-300m-chilean-spanish-asr. A Kaggle Secret will be used to securely store your WandB API key: add a new secret to the notebook, enter WANDB_API_KEY in the Label field, and enter your WandB API key for the value. Ensure that the checkbox next to the WANDB_API_KEY label field is checked.
The Chilean Spanish Speech Data Set has been uploaded to Kaggle as 2 distinct datasets: one containing the recordings of male speakers and one containing the recordings of female speakers. Attach both datasets to your Kaggle Notebook; their input paths are referenced by the constants in Step 3.5.
The first cell installs the required packages. Set the first cell to:
### CELL 1: Install Packages ###
!pip install --upgrade torchaudio
!pip install jiwer
The first command upgrades the torchaudio package to the latest version. torchaudio will be used to load audio files and resample audio data. The second command installs the jiwer package, which is required to use the HuggingFace Datasets library's load_metric method used later.
### CELL 2: Import Python packages ###
import wandb
from kaggle_secrets import UserSecretsClient
import math
import re
import numpy as np
import pandas as pd
import torch
import torchaudio
import json
from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass
from datasets import Dataset, load_metric, load_dataset, Audio
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments
from transformers import Trainer
The transformers library and the associated Wav2Vec2* classes provide the backbone of the functionality used for finetuning.
### CELL 3: Load WER metric ###
wer_metric = load_metric("wer")
The fourth cell retrieves your WANDB_API_KEY secret that was set in Step 2.2. Set the fourth cell to:
### CELL 4: Login to WandB ###
user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key = wandb_api_key)
### CELL 5: Constants ###
# Training data
TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/"
TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/"
EXT = ".wav"
NUM_LOAD_FROM_EACH_SET = 1600
# Vocabulary
VOCAB_FILE_PATH = "/kaggle/working/"
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"
# Sampling rates
ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000
# Training/validation data split
SPLIT_PCT = 0.10
# Model parameters
MODEL = "facebook/wav2vec2-xls-r-300m"
USE_SAFETENSORS = False
# Training arguments
OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr"
TRAIN_BATCH_SIZE = 18
EVAL_BATCH_SIZE = 10
TRAIN_EPOCHS = 30
SAVE_STEPS = 3200
EVAL_STEPS = 100
LOGGING_STEPS = 100
LEARNING_RATE = 1e-4
WARMUP_STEPS = 800
The sixth cell defines utility methods for reading the dataset index files (see the Training Dataset sub-section above), as well as for cleaning transcription text and creating the vocabulary. Set the sixth cell to:
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ###
def read_index_file_data(path: str, filename: str):
    data = []
    with open(path + filename, "r", encoding = "utf8") as f:
        lines = f.readlines()
        for line in lines:
            file_and_text = line.split("\t")
            data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")])
    return data

def truncate_training_dataset(dataset: list) -> list:
    if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower():
        return dataset
    else:
        return dataset[:NUM_LOAD_FROM_EACH_SET]

def clean_text(text: str) -> str:
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    cleaned_text = cleaned_text.lower()
    return cleaned_text

def create_vocab(data):
    vocab_list = []
    for index in range(len(data)):
        text = data[index][1]
        words = text.split(" ")
        for word in words:
            chars = list(word)
            for char in chars:
                if char not in vocab_list:
                    vocab_list.append(char)
    return vocab_list
The read_index_file_data method reads a line_index.tsv dataset index file and produces a list of lists with audio filename and transcription data, e.g.:
[
["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739.wav", "Es un viaje de negocios solamente voy por una noche"]
...
]
The truncate_training_dataset method truncates the list of index file data using the NUM_LOAD_FROM_EACH_SET constant set in Step 3.5. Specifically, the NUM_LOAD_FROM_EACH_SET constant specifies the number of audio samples that should be loaded from each dataset. For the purposes of this guide, the number is set at 1600, which means a total of 3200 audio samples will eventually be loaded. To load all samples, set NUM_LOAD_FROM_EACH_SET to the string value all.
The clean_text method strips each text transcription of the characters specified by the regular expression assigned to SPECIAL_CHARS in Step 3.5. These characters, inclusive of punctuation, can be eliminated as they don't provide any semantic value when training the model to learn mappings between audio features and text transcriptions.
The create_vocab method creates a vocabulary from the cleaned text transcriptions. Simply put, it extracts all unique characters from the set of cleaned text transcriptions. You will see an example of the generated vocabulary in Step 3.14.
The seventh cell defines utility methods using torchaudio to load and resample audio data. Set the seventh cell to:
### CELL 7: Utility methods for loading and resampling audio data ###
def read_audio_data(file):
    speech_array, sampling_rate = torchaudio.load(file, normalize = True)
    return speech_array, sampling_rate

def resample(waveform):
    transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    return waveform[0]
The read_audio_data method loads a specified audio file and returns a torch.Tensor multi-dimensional matrix of the audio data along with the sampling rate of the audio. All the audio files in the training data have a sampling rate of 48000 Hz. This "original" sampling rate is captured by the constant ORIG_SAMPLING_RATE in Step 3.5.
The resample method is used to downsample audio data from a sampling rate of 48000 Hz to 16000 Hz. wav2vec2 is pretrained on audio sampled at 16000 Hz. Accordingly, any audio used for finetuning must have the same sampling rate. In this case, the audio examples must be downsampled from 48000 Hz to 16000 Hz. The target rate of 16000 Hz is captured by the constant TGT_SAMPLING_RATE in Step 3.5.
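As a quick sanity check (not part of the original notebook), you can verify the effect of resample on a single file once the training samples have been read in Step 3.10; the index 0 used here is an arbitrary example:
# Hypothetical spot check: downsampling 48000 Hz audio to 16000 Hz should
# reduce the number of samples by roughly a factor of 3.
speech_array, sampling_rate = read_audio_data(all_training_samples[0][0])
print(sampling_rate, speech_array.shape)   # e.g. 48000, torch.Size([1, n])
resampled = resample(speech_array)
print(resampled.shape)                     # roughly torch.Size([n // 3])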
### CELL 8: Utility methods to prepare input data for training ###
def process_speech_audio(speech_array, sampling_rate):
    input_values = processor(speech_array, sampling_rate = sampling_rate).input_values
    return input_values[0]

def process_target_text(target_text):
    with processor.as_target_processor():
        encoding = processor(target_text).input_ids
    return encoding
The process_speech_audio method returns the input values from a supplied training sample. The process_target_text method encodes each text transcription as a list of labels - i.e. a list of indices referring to characters in the vocabulary. You will see a sample encoding in Step 3.15.
### CELL 9: Utility method to calculate Word Error Rate ###
def compute_wer(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis = -1)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens = False)
    wer = wer_metric.compute(predictions = pred_str, references = label_str)
    return {"wer": wer}
The tenth cell reads the training data index files for the recordings of male speakers and the recordings of female speakers using the read_index_file_data method defined in Step 3.6. Set the tenth cell to:
### CELL 10: Read training data ###
training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv")
training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
The eleventh cell truncates the training data lists using the truncate_training_dataset method defined in Step 3.6. Set the eleventh cell to:
### CELL 11: Truncate training data ###
training_samples_male_cl = truncate_training_dataset(training_samples_male_cl)
training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
The NUM_LOAD_FROM_EACH_SET constant set in Step 3.5 defines the quantity of samples to keep from each dataset. The constant is set to 1600 in this guide for a total of 3200 samples.
The twelfth cell combines the male-speaker and female-speaker training samples into a single list. Set the twelfth cell to:
### CELL 12: Combine training samples data ###
all_training_samples = training_samples_male_cl + training_samples_female_cl
The thirteenth cell iterates over each training data sample and cleans the associated transcription text using the clean_text method defined in Step 3.6. Set the thirteenth cell to:
### CELL 13: Clean text transcriptions ###
for index in range(len(all_training_samples)):
    all_training_samples[index][1] = clean_text(all_training_samples[index][1])
The fourteenth cell creates a vocabulary using the cleaned transcriptions from the previous step and the create_vocab method defined in Step 3.6. Set the fourteenth cell to:
### CELL 14: Create vocabulary ###
vocab_list = create_vocab(all_training_samples)
vocab_dict = {v: i for i, v in enumerate(vocab_list)}
You can print vocab_dict, which should produce the following output:
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
The fifteenth cell adds the word delimiter character | to the vocabulary. Set the fifteenth cell to:
### CELL 15: Add word delimiter to vocabulary ###
vocab_dict["|"] = len(vocab_dict)
The word delimiter character is used when tokenizing text transcriptions as a list of labels. Specifically, it is used to define the end of a word, and it is used when initializing the Wav2Vec2CTCTokenizer class, as will be seen in Step 3.17.
For example, the following list encodes no te entiendo nada using the vocabulary from Step 3.14:
# Encoded text
[6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1]
# Vocabulary
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
### CELL 16: Export vocabulary ###
with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file:
    json.dump(vocab_dict, vocab_file)
The vocabulary is written to a JSON file, which will be used when initializing an instance of the Wav2Vec2CTCTokenizer class. The seventeenth cell initializes an instance of Wav2Vec2CTCTokenizer. Set the seventeenth cell to:
### CELL 17: Initialize tokenizer ###
tokenizer = Wav2Vec2CTCTokenizer(
VOCAB_FILE_PATH + "vocab.json",
unk_token = "[UNK]",
pad_token = "[PAD]",
word_delimiter_token = "|",
replace_word_delimiter_char = " "
)
Note that the tokenizer is initialized with [UNK] assigned to unk_token and [PAD] assigned to pad_token, with the former used to represent unknown tokens in text transcriptions and the latter used to pad transcriptions when creating batches of transcriptions with different lengths. These two values will be added to the vocabulary by the tokenizer.
Initialization of the tokenizer in this step will also add two additional tokens to the vocabulary, namely <s> and </s>, which are used to demarcate the beginning and end of sentences respectively.
The | character is assigned to word_delimiter_token explicitly in this step to reflect that the pipe symbol will be used to demarcate the end of words, in accordance with our addition of the character to the vocabulary in Step 3.15. The | symbol is the default value for word_delimiter_token, so it did not need to be explicitly set, but was set here for the sake of clarity.
As with word_delimiter_token, a single space is explicitly assigned to replace_word_delimiter_char, reflecting that the pipe symbol | will be used to replace blank space characters in text transcriptions. Blank space is the default value for replace_word_delimiter_char, so it also did not need to be explicitly set, but was set here for the sake of clarity.
You can print the full tokenizer vocabulary by calling the get_vocab() method on tokenizer.
vocab = tokenizer.get_vocab()
print(vocab)
# Output:
{'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
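As an optional check (not part of the original notebook), the tokenizer can round-trip a cleaned transcription; the exact label ids depend on the vocabulary generated from your data:
# Hypothetical round-trip through the CTC tokenizer. Spaces are mapped to the
# word delimiter token | during encoding and back to blank spaces on decode.
ids = tokenizer("no te entiendo nada").input_ids
print(ids)                                     # one label id per character
print(tokenizer.decode(ids, group_tokens = False))   # "no te entiendo nada"
# group_tokens is disabled because these are character labels, not CTC model
# outputs that need repeated tokens collapsed.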
The eighteenth cell initializes an instance of Wav2Vec2FeatureExtractor. Set the eighteenth cell to:
### CELL 18: Initialize feature extractor ###
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size = 1,
sampling_rate = 16000,
padding_value = 0.0,
do_normalize = True,
return_attention_mask = True
)
The parameters passed to the Wav2Vec2FeatureExtractor initializer are all default values, with the exception of return_attention_mask, which defaults to False. The default values are shown/passed for the sake of clarity.
The feature_size parameter specifies the dimension size of input features (i.e. audio data features). The default value of this parameter is 1.
sampling_rate tells the feature extractor the sampling rate at which the audio data should be digitalized. As discussed in Step 3.7, wav2vec2 is pretrained on audio sampled at 16000 Hz, and hence 16000 is the default value for this parameter.
The padding_value parameter specifies the value that is used when padding audio data, as required when batching audio samples of different lengths. The default value is 0.0.
do_normalize is used to specify if input data should be transformed to a standard normal distribution. The default value is True. The Wav2Vec2FeatureExtractor class documentation notes that "[normalizing] can help to significantly improve the performance for some models."
The return_attention_mask parameter specifies if the attention mask should be passed or not. The value is set to True for this use case.
The nineteenth cell initializes an instance of Wav2Vec2Processor. Set the nineteenth cell to:
### CELL 19: Initialize processor ###
processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
The Wav2Vec2Processor class combines tokenizer and feature_extractor from Step 3.17 and Step 3.18 respectively into a single processor.
Note that the processor configuration can be saved by calling the save_pretrained method on the Wav2Vec2Processor class instance:
processor.save_pretrained(OUTPUT_DIR_PATH)
The twentieth cell loads each audio file specified in the all_training_samples list. Set the twentieth cell to:
### CELL 20: Load audio data ###
all_input_data = []
for index in range(len(all_training_samples)):
    speech_array, sampling_rate = read_audio_data(all_training_samples[index][0])
    all_input_data.append({
        "input_values": speech_array,
        "labels": all_training_samples[index][1]
    })
The audio data is returned as a torch.Tensor and stored in all_input_data as a list of dictionaries. Each dictionary contains the audio data for a particular sample, along with the text transcription of the audio. Note that the read_audio_data method returns the sampling rate of the audio data as well. Since we know that the sampling rate is 48000 Hz for all audio files in this use case, the sampling rate is ignored in this step.
The twenty-first cell converts the all_input_data list to a Pandas DataFrame to make it easier to manipulate the data. Set the twenty-first cell to:
### CELL 21: Convert audio training data list to Pandas DataFrame ###
all_input_data_df = pd.DataFrame(data = all_input_data)
The twenty-second cell uses the processor initialized in Step 3.19 to extract features from each audio data sample and to encode each text transcription as a list of labels. Set the twenty-second cell to:
### CELL 22: Process audio data and text transcriptions ###
all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000))
all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
The twenty-third cell splits the all_input_data_df DataFrame into training and evaluation (validation) datasets using the SPLIT_PCT constant from Step 3.5. Set the twenty-third cell to:
### CELL 23: Split input data into training and validation datasets ###
split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT)
valid_data_df = all_input_data_df.iloc[-split:]
train_data_df = all_input_data_df.iloc[:-split]
The SPLIT_PCT value is 0.10 in this guide, meaning 10% of all input data will be held out for evaluation and 90% of the data will be used for training/finetuning. Since 3200 samples are loaded in total, 320 samples will be held out for evaluation.
The twenty-fourth cell converts the train_data_df and valid_data_df DataFrames to Dataset objects. Set the twenty-fourth cell to:
### CELL 24: Convert training and validation datasets to Dataset objects ###
train_data = Dataset.from_pandas(train_data_df)
valid_data = Dataset.from_pandas(valid_data_df)
Dataset objects are consumed by HuggingFace Trainer class instances, as you will see in Step 3.30.
You can print train_data and valid_data to view the metadata for both Dataset objects.
print(train_data)
print(valid_data)
# Output:
Dataset({
features: ['input_values', 'labels'],
num_rows: 2880
})
Dataset({
features: ['input_values', 'labels'],
num_rows: 320
})
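If you want to sanity-check an individual example (an optional step not in the original notebook), each row holds the processed audio features and the encoded label sequence:
# Hypothetical inspection of one training example: input_values is a list of
# normalized audio feature values and labels is a list of vocabulary indices.
example = train_data[0]
print(len(example["input_values"]))   # number of audio samples at 16000 Hz
print(example["labels"])              # encoded transcription (vocabulary indices)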
### CELL 25: Initialize pretrained model ###
model = Wav2Vec2ForCTC.from_pretrained(
MODEL,
ctc_loss_reduction = "mean",
pad_token_id = processor.tokenizer.pad_token_id,
vocab_size = len(processor.tokenizer)
)
The from_pretrained method called on Wav2Vec2ForCTC specifies that we want to load the pretrained weights for the specified model.
The MODEL constant was specified in Step 3.5 and was set to facebook/wav2vec2-xls-r-300m, reflecting the XLS-R (0.3B) model.
The ctc_loss_reduction parameter specifies the type of reduction to apply to the output of the Connectionist Temporal Classification ("CTC") loss function. CTC loss is used to calculate the loss between a continuous input, in this case audio data, and a target sequence, in this case text transcriptions. By setting the value to mean, the output losses for a batch are divided by the target lengths, and the mean over the batch is then taken.
pad_token_id specifies the token to be used for padding when batching. It is set to the [PAD] id set when initializing the tokenizer in Step 3.17.
The vocab_size parameter defines the vocabulary size of the model. It is the vocabulary size after initialization of the tokenizer in Step 3.17 and reflects the number of output-layer nodes of the forward portion of the network.
### CELL 26: Freeze feature extractor ###
model.freeze_feature_extractor()
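The twenty-sixth cell freezes the pretrained convolutional feature encoder so that its weights are not updated during finetuning. If you want to confirm the effect, the following optional check counts its trainable parameters; it assumes the encoder is exposed as model.wav2vec2.feature_extractor, which is the case for current Wav2Vec2ForCTC implementations:
# Hypothetical check: after freezing, no feature-encoder parameter should
# require gradients.
frozen = sum(1 for p in model.wav2vec2.feature_extractor.parameters() if not p.requires_grad)
trainable = sum(1 for p in model.wav2vec2.feature_extractor.parameters() if p.requires_grad)
print(frozen, trainable)   # trainable should be 0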
The twenty-seventh cell initializes the training arguments that will be passed to a Trainer instance. Set the twenty-seventh cell to:
### CELL 27: Set training arguments ###
training_args = TrainingArguments(
output_dir = OUTPUT_DIR_PATH,
save_safetensors = False,
group_by_length = True,
per_device_train_batch_size = TRAIN_BATCH_SIZE,
per_device_eval_batch_size = EVAL_BATCH_SIZE,
num_train_epochs = TRAIN_EPOCHS,
gradient_checkpointing = True,
evaluation_strategy = "steps",
save_strategy = "steps",
logging_strategy = "steps",
eval_steps = EVAL_STEPS,
save_steps = SAVE_STEPS,
logging_steps = LOGGING_STEPS,
learning_rate = LEARNING_RATE,
warmup_steps = WARMUP_STEPS
)
The TrainingArguments class accepts many more parameters than the ones set here.
The save_safetensors parameter, when False, specifies that the finetuned model should be saved to a pickle file instead of using the safetensors format.
The group_by_length parameter, when True, indicates that samples of approximately the same length should be grouped together. This minimizes padding and improves training efficiency.
per_device_train_batch_size sets the number of samples per training mini-batch. This parameter is set to 18 via the TRAIN_BATCH_SIZE constant assigned in Step 3.5. This implies 160 steps per epoch (see the sketch after this list).
per_device_eval_batch_size sets the number of samples per evaluation (holdout) mini-batch. This parameter is set to 10 via the EVAL_BATCH_SIZE constant assigned in Step 3.5.
num_train_epochs sets the number of training epochs. This parameter is set to 30 via the TRAIN_EPOCHS constant assigned in Step 3.5. This implies 4,800 total steps during training.
The gradient_checkpointing parameter, when True, helps to save memory by checkpointing gradient calculations, but results in slower backward passes.
The evaluation_strategy parameter, when set to steps, means that evaluation will be performed and logged during training at an interval specified by the parameter eval_steps.
The logging_strategy parameter, when set to steps, means that training run statistics will be logged at an interval specified by the parameter logging_steps.
The save_strategy parameter, when set to steps, means that a checkpoint of the finetuned model will be saved at an interval specified by the parameter save_steps.
eval_steps sets the number of steps between evaluations of holdout data. This parameter is set to 100 via the EVAL_STEPS constant assigned in Step 3.5.
save_steps sets the number of steps after which a checkpoint of the finetuned model is saved. This parameter is set to 3200 via the SAVE_STEPS constant assigned in Step 3.5.
logging_steps sets the number of steps between logs of training run statistics. This parameter is set to 100 via the LOGGING_STEPS constant assigned in Step 3.5.
The learning_rate parameter sets the initial learning rate. This parameter is set to 1e-4 via the LEARNING_RATE constant assigned in Step 3.5.
The warmup_steps parameter sets the number of steps to linearly warm up the learning rate from 0 to the value set by learning_rate. This parameter is set to 800 via the WARMUP_STEPS constant assigned in Step 3.5.
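The following quick calculation (not part of the notebook) shows how the batch size, epoch count, and step-based intervals above relate to one another, assuming the 2,880-sample training split produced in Step 3.23:
# Hypothetical arithmetic for the training schedule implied by the constants above.
train_samples = 2880                                  # 3200 total samples minus the 10% holdout
steps_per_epoch = train_samples // TRAIN_BATCH_SIZE   # 2880 / 18 = 160
total_steps = steps_per_epoch * TRAIN_EPOCHS          # 160 * 30 = 4800
evals = total_steps // EVAL_STEPS                     # evaluation every 100 steps -> 48 evaluations
checkpoints = total_steps // SAVE_STEPS               # checkpoint every 3200 steps -> 1 checkpoint
print(steps_per_epoch, total_steps, evals, checkpoints)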
### CELL 28: Define data collator logic ###
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(
            input_features,
            padding = self.padding,
            max_length = self.max_length,
            pad_to_multiple_of = self.pad_to_multiple_of,
            return_tensors = "pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding = self.padding,
                max_length = self.max_length_labels,
                pad_to_multiple_of = self.pad_to_multiple_of_labels,
                return_tensors = "pt",
            )
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch
The data collator prepares the mini-batches that are fed to the Trainer instance that will be initialized momentarily in Step 3.30. Since the input sequences and label sequences vary in length in each mini-batch, some sequences must be padded so that they are all of the same length.
The DataCollatorCTCWithPadding class dynamically pads mini-batch data. The padding parameter, when set to True, specifies that shorter audio input feature sequences and label sequences should be padded to the same length as the longest sequence in a mini-batch.
Audio input features are padded with the value 0.0 set when initializing the feature extractor in Step 3.18. Label padding values are replaced with -100 so that these labels are ignored when calculating the WER metric.
### CELL 29: Initialize instance of data collator ###
data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
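To see the collator in action (an optional check, not part of the original notebook), you can pad a couple of training examples by hand and inspect the resulting tensors:
# Hypothetical inspection: both tensors share the batch dimension, and the
# second dimension equals the longest sequence in the mini-batch.
sample_batch = data_collator([train_data[0], train_data[1]])
print(sample_batch["input_values"].shape)   # e.g. torch.Size([2, max_audio_len])
print(sample_batch["labels"].shape)         # e.g. torch.Size([2, max_label_len])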
The thirtieth cell initializes an instance of the Trainer class. Set the thirtieth cell to:
### CELL 30: Initialize trainer ###
trainer = Trainer(
model = model,
data_collator = data_collator,
args = training_args,
compute_metrics = compute_wer,
train_dataset = train_data,
eval_dataset = valid_data,
tokenizer = processor.feature_extractor
)
The Trainer class is initialized with the model initialized in Step 3.25, the train_data Dataset object from Step 3.24, and the valid_data Dataset object from Step 3.24, along with the training arguments, data collator, and WER metric method from the preceding steps. The tokenizer parameter is assigned processor.feature_extractor and works with data_collator to automatically pad the inputs to the maximum-length input of each mini-batch.
The thirty-first cell calls the train method on the Trainer class instance to finetune the model. Set the thirty-first cell to:
### CELL 31: Finetune the model ###
trainer.train()
The thirty-second cell is the last notebook cell. It saves the finetuned model by calling the save_model method on the Trainer instance. Set the thirty-second cell to:
### CELL 32: Save the finetuned model ###
trainer.save_model(OUTPUT_DIR_PATH)
Set the Kaggle Notebook to run with the NVIDIA GPU P100 accelerator.
The finetuned model will be output to the Kaggle directory specified by the constant OUTPUT_DIR_PATH set in Step 3.5. The model output should include the following files:
pytorch_model.bin
config.json
preprocessor_config.json
vocab.json
training_args.bin
These files can be downloaded locally. Additionally, you can create a new Kaggle Model using the model files. The Kaggle Model will be used with the companion inference guide to run inference on the finetuned model.
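As a final optional check before moving on to the inference guide, you can reload the saved files and transcribe a single audio sample. This is only a minimal sketch, and it assumes the processor configuration was saved to OUTPUT_DIR_PATH with processor.save_pretrained as noted in Step 3.19:
# Hypothetical smoke test of the saved model on one training audio file.
model = Wav2Vec2ForCTC.from_pretrained(OUTPUT_DIR_PATH)
processor = Wav2Vec2Processor.from_pretrained(OUTPUT_DIR_PATH)
speech_array, _ = read_audio_data(all_training_samples[0][0])
speech = resample(speech_array).numpy()
inputs = processor(speech, sampling_rate = 16000, return_tensors = "pt")
with torch.no_grad():
    logits = model(inputs.input_values).logits
pred_ids = torch.argmax(logits, dim = -1)
print(processor.batch_decode(pred_ids)[0])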