This is a companion guide to Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition (the "Part 1 guide"). I wrote the Part 1 guide to explain how to finetune Meta AI's XLS-R model on Chilean Spanish. This guide assumes that you have completed the Part 1 guide and have generated your own finetuned XLS-R model. It explains the steps to run inference on that model via a Kaggle Notebook.
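This guide uses the Peruvian Spanish speakers dataset (attached in Kaggle as google-spanish-speakers-peru-male and google-spanish-speakers-peru-female) as its source of test data. Like the Chilean Spanish speakers dataset used in the Part 1 guide, the Peruvian speakers dataset consists of two sub-datasets: 2,918 recordings of male Peruvian speakers and 2,529 recordings of female Peruvian speakers.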
Add both of these datasets to your Kaggle Notebook by clicking on Add Input.
In Step 4 of the Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition guide, you should have saved your finetuned model as a Kaggle Model.
Add your finetuned model to your Kaggle Notebook by clicking on Add Input.
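Before running the cells below, it can help to confirm that both test datasets are actually attached. The following is a minimal sketch that is not part of the original notebook: find_missing_inputs is a hypothetical helper, and it assumes each dataset's line_index.tsv sits at the root of its input directory, which is what Cell 7 expects.

```python
import os

# Hypothetical helper: report dataset inputs that are not attached yet.
def find_missing_inputs(paths: list) -> list:
    """Return the paths that lack a line_index.tsv index file."""
    return [p for p in paths if not os.path.isfile(os.path.join(p, "line_index.tsv"))]

# Paths used in Cell 4; adjust them if your Kaggle input names differ.
EXPECTED_INPUTS = [
    "/kaggle/input/google-spanish-speakers-peru-male/",
    "/kaggle/input/google-spanish-speakers-peru-female/",
]

missing = find_missing_inputs(EXPECTED_INPUTS)
if missing:
    print("Missing inputs:", missing)
else:
    print("Both test datasets are attached.")
```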
### CELL 1: Install Packages ###
!pip install --upgrade torchaudio
!pip install jiwer
### CELL 2: Import Python packages ###
import re
import math
import random
import pandas as pd
import torchaudio
from datasets import load_metric
from transformers import pipeline
### CELL 3: Load WER metric ###
wer_metric = load_metric("wer")
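The wer metric loaded in Cell 3 relies on the jiwer package installed in Cell 1. To make the number it reports concrete, here is a minimal pure-Python sketch of word error rate: the word-level Levenshtein (edit) distance between a reference and a prediction, divided by the number of words in the reference. word_edit_distance and simple_wer are hypothetical helpers for illustration only; the notebook itself should keep using wer_metric.

```python
def word_edit_distance(reference: str, hypothesis: str) -> int:
    """Minimum number of word substitutions, deletions and insertions
    needed to turn `reference` into `hypothesis`."""
    ref = reference.split()
    hyp = hypothesis.split()
    # prev[j] is the distance between the first i-1 reference words
    # and the first j hypothesis words (row i-1 of the DP table)
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[len(hyp)]

def simple_wer(reference: str, hypothesis: str) -> float:
    """Edit distance divided by the reference word count (assumes a non-empty reference)."""
    return word_edit_distance(reference, hypothesis) / len(reference.split())

print(simple_wer("el gato negro", "el perro negro"))  # 0.3333333333333333
```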
### CELL 4: Constants ###
# Testing data
TEST_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-peru-male/"
TEST_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-peru-female/"
EXT = ".wav"
NUM_LOAD_FROM_EACH_SET = 3
# Special characters
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"
# Sampling rates
ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000
### CELL 5: Utility methods for reading index files, cleaning text, random indices generator ###
def read_index_file_data(path: str, filename: str):
    # Each line of the index file is "<audio file id>\t<transcription>"
    data = []
    with open(path + filename, "r", encoding = "utf8") as f:
        lines = f.readlines()
        for line in lines:
            file_and_text = line.split("\t")
            data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")])
    return data

def clean_text(text: str) -> str:
    # Strip the characters matched by SPECIAL_CHARS, then lowercase
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    cleaned_text = cleaned_text.lower()
    return cleaned_text

def get_random_samples(dataset: list, num: int) -> list:
    # Draw `num` distinct samples by picking random indices not used yet
    used = []
    samples = []
    for i in range(num):
        a = -1
        while a == -1 or a in used:
            a = math.floor(len(dataset) * random.random())
        samples.append(dataset[a])
        used.append(a)
    return samples
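To illustrate two of the Cell 5 helpers, the sketch below copies clean_text and shows its effect on a couple of made-up sentences. Because SPECIAL_CHARS includes \d, digits are deleted rather than spelled out, which can leave a double space behind. get_random_samples_alt is a hypothetical alternative to get_random_samples: the standard library's random.sample also draws distinct elements without replacement.

```python
import random
import re

# Copied from Cells 4 and 5 so this block runs on its own.
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"

def clean_text(text: str) -> str:
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    return cleaned_text.lower()

print(clean_text("Hola, ¿cómo estás?"))          # hola cómo estás
print(repr(clean_text("Llegó el 19 de mayo.")))  # 'llegó el  de mayo' (digits removed, double space left)

# Hypothetical alternative to get_random_samples: random.sample draws
# `num` distinct elements without replacement.
def get_random_samples_alt(dataset: list, num: int) -> list:
    return random.sample(dataset, num)
```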
The read_index_file_data method reads a line_index.tsv dataset index file and produces a list of lists with audio filename and transcription data, e.g.:
[
    ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739.wav", "Es un viaje de negocios solamente voy por una noche"],
    ...
]
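Because read_index_file_data builds paths with plain string concatenation (path + filename), the directory constants in Cell 4 need their trailing slash. The sketch below repeats the method and runs it against a tiny line_index.tsv written to a temporary directory; the file IDs and sentences in it are made up for illustration.

```python
import os
import tempfile

EXT = ".wav"

# Same logic as read_index_file_data in Cell 5, repeated so this runs alone.
def read_index_file_data(path: str, filename: str):
    data = []
    with open(path + filename, "r", encoding="utf8") as f:
        for line in f:
            file_and_text = line.split("\t")
            data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")])
    return data

# Build a tiny, hypothetical line_index.tsv in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "")  # ensure a trailing separator, like the Cell 4 paths
    with open(path + "line_index.tsv", "w", encoding="utf8") as f:
        f.write("sample_0001\tHola mundo\n")
        f.write("sample_0002\tBuenos días\n")
    rows = read_index_file_data(path, "line_index.tsv")
    print(rows)
```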
The clean_text method strips each text transcription of the characters matched by the regular expression assigned to SPECIAL_CHARS in Step 2.4, and then lowercases it. These characters, including punctuation, can be eliminated as they don't provide any semantic value for learning mappings between audio features and text transcriptions. Since the model was trained on transcriptions without these characters, cleaning the reference transcriptions the same way keeps them comparable to the model's predictions.

The get_random_samples method returns a list of distinct random test samples, with the quantity set by the constant NUM_LOAD_FROM_EACH_SET in Step 2.4.

The sixth cell defines utility methods using torchaudio to load and resample audio data. Set the sixth cell to:
### CELL 6: Utility methods for loading and resampling audio data ###
def read_audio_data(file):
    # Returns a (channels, frames) tensor and the file's sampling rate
    speech_array, sampling_rate = torchaudio.load(file, normalize = True)
    return speech_array, sampling_rate

def resample(waveform):
    # Downsample from ORIG_SAMPLING_RATE to TGT_SAMPLING_RATE
    transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    # Keep only the first channel as a 1-D tensor
    return waveform[0]
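Downsampling from 48000 Hz to 16000 Hz produces one output sample for every three input samples, so a recording keeps its duration while its sample count drops by a factor of 3. The sketch below estimates the output length, assuming the resampler rounds the output length up (my understanding of torchaudio's Resample behavior, not something this guide relies on); resampled_length and duration_seconds are hypothetical helpers.

```python
ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000

def resampled_length(num_frames: int,
                     orig: int = ORIG_SAMPLING_RATE,
                     tgt: int = TGT_SAMPLING_RATE) -> int:
    """Expected number of samples after resampling, rounding up (assumption)."""
    # Integer ceiling of num_frames * tgt / orig, avoiding float rounding
    return (num_frames * tgt + orig - 1) // orig

def duration_seconds(num_frames: int, sampling_rate: int) -> float:
    return num_frames / sampling_rate

frames_48k = 240000  # a hypothetical 5-second recording at 48 kHz
frames_16k = resampled_length(frames_48k)
print(frames_16k)  # 80000
print(duration_seconds(frames_48k, ORIG_SAMPLING_RATE), duration_seconds(frames_16k, TGT_SAMPLING_RATE))  # 5.0 5.0
```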
The read_audio_data method loads a specified audio file and returns a torch.Tensor of the audio data, with shape (channels, frames), along with the sampling rate of the audio. As with the training data in the Part 1 guide, all the audio files in the test data have a sampling rate of 48000 Hz. This "original" sampling rate is captured by the constant ORIG_SAMPLING_RATE in Step 2.4.

The resample method downsamples audio data from a sampling rate of 48000 Hz to the target sampling rate of 16000 Hz, which is the sampling rate XLS-R expects, and returns the first channel of the result.

The seventh cell reads the test data index files for the recordings of male speakers and the recordings of female speakers using the read_index_file_data method defined in Step 2.5. Set the seventh cell to:
### CELL 7: Read test data ###
test_data_male = read_index_file_data(TEST_DATA_PATH_MALE, "line_index.tsv")
test_data_female = read_index_file_data(TEST_DATA_PATH_FEMALE, "line_index.tsv")
The eighth cell generates sets of random test samples using the get_random_samples method defined in Step 2.5. Set the eighth cell to:
### CELL 8: Generate lists of random test samples ###
random_test_samples_male = get_random_samples(test_data_male, NUM_LOAD_FROM_EACH_SET)
random_test_samples_female = get_random_samples(test_data_female, NUM_LOAD_FROM_EACH_SET)
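Because get_random_samples draws from Python's global random generator, every run of the notebook picks different samples. If you want a repeatable evaluation, one option is to seed the generator before calling it, as in this sketch (the seed value and the stand-in dataset are arbitrary). Note also that get_random_samples never terminates if num exceeds the dataset size, since it keeps searching for an unused index.

```python
import math
import random

# get_random_samples from Cell 5, repeated so this block runs on its own.
def get_random_samples(dataset: list, num: int) -> list:
    used = []
    samples = []
    for i in range(num):
        a = -1
        while a == -1 or a in used:
            a = math.floor(len(dataset) * random.random())
        samples.append(dataset[a])
        used.append(a)
    return samples

dataset = [f"sample_{i}" for i in range(100)]  # hypothetical stand-in for test_data_male

random.seed(1234)  # arbitrary seed; any fixed value works
first_run = get_random_samples(dataset, 3)
random.seed(1234)
second_run = get_random_samples(dataset, 3)
print(first_run == second_run)  # True
```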
### CELL 9: Combine test data ###
all_test_samples = random_test_samples_male + random_test_samples_female
The tenth cell iterates over each test data sample and cleans the associated transcription text using the clean_text method defined in Step 2.5. Set the tenth cell to:
### CELL 10: Clean text transcriptions ###
for index in range(len(all_test_samples)):
    all_test_samples[index][1] = clean_text(all_test_samples[index][1])
The eleventh cell loads each audio file specified in the all_test_samples list. Set the eleventh cell to:
### CELL 11: Load audio data ###
all_test_data = []
for index in range(len(all_test_samples)):
    speech_array, sampling_rate = read_audio_data(all_test_samples[index][0])
    all_test_data.append({
        "raw": speech_array,
        "sampling_rate": sampling_rate,
        "target_text": all_test_samples[index][1]
    })
Each audio file is loaded as a torch.Tensor and stored in all_test_data as a list of dictionaries. Each dictionary contains the audio data for a particular sample, the sampling rate, and the text transcription of the audio.

The twelfth cell resamples the audio data to the target sampling rate of 16000 Hz and casts each waveform to a NumPy array. Set the twelfth cell to:
### CELL 12: Resample audio data and cast to NumPy arrays ###
all_test_data = [{"raw": resample(sample["raw"]).numpy(), "sampling_rate": TGT_SAMPLING_RATE, "target_text": sample["target_text"]} for sample in all_test_data]
The thirteenth cell uses the Hugging Face transformers library's pipeline function to initialize an automatic speech recognition pipeline. Set the thirteenth cell to:
### CELL 13: Initialize instance of Automatic Speech Recognition Pipeline ###
transcriber = pipeline("automatic-speech-recognition", model = "YOUR_FINETUNED_MODEL_PATH")
The model parameter must be set to the path to your finetuned model added to the Kaggle Notebook in Step 1.3, e.g.:
transcriber = pipeline("automatic-speech-recognition", model = "/kaggle/input/xls-r-300m-chilean-spanish/transformers/hardy-pine/1")
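If you are unsure of the exact path, one way to find it is to look for directories under /kaggle/input that contain a config.json file, which transformers writes when a model is saved with save_pretrained. This is a heuristic sketch: find_model_dirs is a hypothetical helper, and other attached inputs could also happen to contain a config.json.

```python
import os

def find_model_dirs(root: str = "/kaggle/input") -> list:
    """Return directories under `root` that contain a config.json file."""
    matches = []
    for dirpath, _dirnames, filenames in os.walk(root):
        if "config.json" in filenames:
            matches.append(dirpath)
    return sorted(matches)

# Prints an empty list if /kaggle/input does not exist (e.g. outside Kaggle).
print(find_model_dirs())
```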
The fourteenth cell calls the transcriber initialized in the previous step on the test data to generate text predictions. Set the fourteenth cell to:
### CELL 14: Generate transcriptions ###
transcriptions = transcriber(all_test_data)
### CELL 15: Calculate WER metrics ###
predictions = [transcription["text"] for transcription in transcriptions]
# The pipeline returns extra input keys (here, "target_text") alongside "text", wrapped in a list
references = [transcription["target_text"][0] for transcription in transcriptions]
# Per-sample WER
wers = []
for p in range(len(predictions)):
    wer = wer_metric.compute(predictions = [predictions[p]], references = [references[p]])
    wers.append(wer)
zipped = list(zip(predictions, references, wers))
df = pd.DataFrame(zipped, columns=["Prediction", "Reference", "WER"])
# Overall WER across all samples: total word errors divided by total reference words
wer = wer_metric.compute(predictions = predictions, references = references)
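The overall WER computed at the end of Cell 15 is not the average of the per-sample WERs: the wer metric sums word errors over all samples and divides by the total number of reference words, so longer utterances carry more weight. The sketch below contrasts the two aggregations on made-up error counts; pooled_wer and mean_wer are illustrative helpers, not part of the notebook.

```python
def pooled_wer(errors: list, ref_word_counts: list) -> float:
    """Corpus-level WER: total word errors over total reference words."""
    return sum(errors) / sum(ref_word_counts)

def mean_wer(errors: list, ref_word_counts: list) -> float:
    """Unweighted mean of the per-sample WERs (not what Cell 15 reports)."""
    per_sample = [e / n for e, n in zip(errors, ref_word_counts)]
    return sum(per_sample) / len(per_sample)

# Hypothetical example: one error in a 2-word sample, none in an 18-word one.
errors = [1, 0]
ref_word_counts = [2, 18]
print(pooled_wer(errors, ref_word_counts))  # 0.05
print(mean_wer(errors, ref_word_counts))    # 0.25
```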
### CELL 16: Output WER metrics ###
pd.set_option("display.max_colwidth", None)
print(f"Overall WER: {wer}")
print(df)
Since the notebook generates predictions on random samples of test data, the output will vary each time the notebook is run. The following output was generated on a run of the notebook with NUM_LOAD_FROM_EACH_SET set to 3 for a total of 6 test samples:
Overall WER: 0.0888888
Prediction \
0 quiero que me reserves el mejor asiento del teatro
1 el llano en llamas es un clásico de juan rulfo
2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera
3 hay tres cafés que están abiertos hasta las once de la noche
4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras
5 cuántos albergues se abrieron después del terremoto del diecinueve de setiembre
Reference \
0 quiero que me reserves el mejor asiento del teatro
1 el llano en llamas es un clásico de juan rulfo
2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera
3 hay tres cafés que están abiertos hasta las once de la noche
4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras
5 cuántos albergues se abrieron después del terremoto del diecinueve de septiembre
WER
0 0.000000
1 0.000000
2 0.000000
3 0.000000
4 0.000000
5 0.090909
As can be seen, the model did an excellent job! It made only one error, on the sixth sample (index 5), transcribing the word septiembre as setiembre. Of course, running the notebook again with different test samples and, more importantly, a larger number of test samples will produce different and more informative results. Nonetheless, this limited data suggests the model can perform well on different dialects of Spanish: it was trained on Chilean Spanish, but appears to perform well on Peruvian Spanish.
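The per-sample WER of 0.090909 for index 5 follows directly from the standard definition of WER, where S, D and I count word substitutions, deletions and insertions and N is the number of words in the reference. The reference for index 5 has 11 words, and setiembre in place of septiembre is a single substitution:

```latex
\mathrm{WER} = \frac{S + D + I}{N} = \frac{1 + 0 + 0}{11} \approx 0.0909
```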
If you are just learning how to work with wav2vec2 models, I hope that the Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition guide and this guide were useful for you. As mentioned, the finetuned model generated by the Part 1 guide is not quite state-of-the-art, but should still prove useful for many applications. Happy building!