Thursday, July 18, 2024

Fine-tune Whisper models on Amazon SageMaker with LoRA

Whisper is an Automatic Speech Recognition (ASR) model that has been trained using 680,000 hours of supervised data from the web, encompassing a range of languages and tasks. One of its limitations is low performance on low-resource languages such as Marathi and the Dravidian languages, which can be remedied with fine-tuning. However, fine-tuning a Whisper model has become a considerable challenge, both in terms of computational resources and storage requirements. Five to ten runs of full fine-tuning for Whisper models demand approximately 100 A100 GPU hours (40 GB SXM4), depending on model size and parameters, and each fine-tuned checkpoint requires about 7 GB of storage. This combination of high computational and storage demands can pose significant hurdles, especially in environments with limited resources, often making it exceptionally difficult to achieve meaningful results.

Low-Rank Adaptation, also known as LoRA, takes a different approach to model fine-tuning. It keeps the pre-trained model weights frozen and injects trainable rank decomposition matrices into each layer of the Transformer architecture. Compared with full fine-tuning of GPT-3 175B with Adam, the original LoRA paper reports that this reduces the number of trainable parameters by 10,000 times and the GPU memory requirement by 3 times. In terms of model quality, LoRA has been shown to match or even exceed the performance of traditional fine-tuning methods despite training far fewer parameters (see the results from the original LoRA paper). It also offers higher training throughput. Unlike adapter methods, LoRA doesn't introduce additional inference latency, because the trained low-rank matrices can be merged into the frozen weights before deployment. Fine-tuning Whisper using LoRA has shown promising results. Take Whisper-Large-v2, for instance: running 3 epochs on a 12-hour Common Voice dataset on a GPU with 8 GB of memory takes 6–8 hours, which is 5 times faster than full fine-tuning with comparable performance.
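To make the mechanism concrete, the following is a minimal, dependency-free sketch of a LoRA-adapted linear layer computing h = W0·x + (alpha / r)·B·(A·x). It is an illustration under stated assumptions, not the peft implementation: the class name `LoRALinear` and the toy matrices are hypothetical. It also shows why LoRA adds no inference latency: after training, B·A can be folded into W0 so serving needs a single matrix product.

```python
# Minimal, dependency-free sketch of a LoRA-adapted linear layer:
#   h = W0 x + (alpha / r) * B (A x)
# W0 is frozen; only the low-rank factors A (r x d_in) and B (d_out x r) are trained.


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


class LoRALinear:
    def __init__(self, w0, a, b, alpha):
        self.w0 = w0                 # frozen pre-trained weight, d_out x d_in
        self.a = a                   # trainable down-projection, r x d_in
        self.b = b                   # trainable up-projection, d_out x r
        self.scale = alpha / len(a)  # alpha / r, with r = number of rows of A

    def forward(self, x):
        base = matvec(self.w0, x)
        update = matvec(self.b, matvec(self.a, x))
        return [h + self.scale * u for h, u in zip(base, update)]

    def merged_weight(self):
        """Fold B A into W0 so inference needs only one matrix-vector product."""
        delta = matmul(self.b, self.a)
        return [[w + self.scale * d for w, d in zip(w_row, d_row)]
                for w_row, d_row in zip(self.w0, delta)]

    def trainable_params(self):
        return len(self.a) * len(self.a[0]) + len(self.b) * len(self.b[0])


w0 = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # d_out = 2, d_in = 3
a = [[0.5, -1.0, 0.0]]                     # r = 1
x = [1.0, 2.0, 3.0]

# LoRA initializes B to zero, so training starts from the unchanged pre-trained model
fresh = LoRALinear(w0, a, [[0.0], [0.0]], alpha=2.0)
print(fresh.forward(x) == matvec(w0, x))  # True

# after training, the merged weight reproduces the adapted output
# (exactly here; up to floating-point rounding in general)
trained = LoRALinear(w0, a, [[1.0], [2.0]], alpha=2.0)
print(trained.forward(x))                                        # [4.0, -7.0]
print(matvec(trained.merged_weight(), x) == trained.forward(x))  # True
```

Only A and B are updated during training (5 values here versus 6 in W0; the savings grow dramatically with real layer sizes and small r).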

Amazon SageMaker is an ideal platform to implement LoRA fine-tuning of Whisper. Amazon SageMaker enables you to build, train, and deploy machine learning models for any use case with fully managed infrastructure, tools, and workflows. Additional model training benefits can include lower training costs with Managed Spot Training, distributed training libraries to split models and training datasets across AWS GPU instances, and more.  The trained SageMaker models can be easily deployed for inference directly on SageMaker. In this post, we present a step-by-step guide to implement LoRA fine-tuning in SageMaker. The source code associated with this implementation can be found on GitHub.

Prepare the dataset for fine-tuning

We use the low-resource language Marathi for the fine-tuning task. Using the Hugging Face datasets library, you can download and split the Common Voice dataset into training and testing datasets. See the following code:

```python
from datasets import load_dataset, DatasetDict

language = "Marathi"
language_abbr = "mr"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"

common_voice = DatasetDict()
common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", use_auth_token=True)
```

The Whisper speech recognition model requires 16 kHz mono audio input. Because the Common Voice dataset is sampled at 48 kHz, you first need to downsample the audio. Then you apply Whisper's feature extractor to the audio to compute log-Mel spectrogram features, and apply Whisper's tokenizer to each transcript sentence to convert it into a sequence of token IDs. See the following code:

```python
from datasets import Audio
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer

model_name_or_path = "openai/whisper-large-v2"  # pre-trained checkpoint used in this post

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)

# resample the audio column from 48 kHz to 16 kHz (decoded lazily when a sample is accessed)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

def prepare_dataset(batch):
    # load the (already resampled) audio data
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

# apply the data preparation function to all fine-tuning samples using the dataset's .map method
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)
common_voice.save_to_disk("marathi-common-voice-processed")
# run from a Jupyter notebook: the ! prefix executes a shell command
!aws s3 cp --recursive "marathi-common-voice-processed" s3://<Your-S3-Bucket>
```
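As a sanity check on the preprocessing, the arithmetic below sketches what the resampling and feature extraction produce. It assumes the default settings of the whisper-large-v2 feature extractor (30-second windows, a 160-sample hop at 16 kHz, and 80 Mel bins; whisper-large-v3 uses 128 bins), so treat the constants as assumptions to verify against your checkpoint's configuration.

```python
# Back-of-the-envelope shapes for Whisper's log-Mel input features.
# Assumed defaults of the whisper-large-v2 feature extractor:
SAMPLING_RATE = 16_000   # Hz expected by Whisper
CHUNK_LENGTH_S = 30      # every clip is padded or truncated to 30 seconds
HOP_LENGTH = 160         # 10 ms hop between spectrogram frames at 16 kHz
N_MELS = 80              # Mel bins (whisper-large-v3 uses 128)


def resampled_length(num_samples, orig_sr=48_000, target_sr=SAMPLING_RATE):
    """Number of samples after resampling a clip from orig_sr to target_sr."""
    return num_samples * target_sr // orig_sr


def input_feature_shape():
    """Shape of one example's input_features: (Mel bins, frames)."""
    n_samples = SAMPLING_RATE * CHUNK_LENGTH_S
    return (N_MELS, n_samples // HOP_LENGTH)


clip = 5 * 48_000                # a 5-second Common Voice clip at 48 kHz
print(resampled_length(clip))    # 80000 samples at 16 kHz
print(input_feature_shape())     # (80, 3000), regardless of the clip's length
```

Because every example is padded to the same 30-second window, each `input_features` entry has a fixed shape, which keeps batching in the data collator simple.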

After processing all the training samples and uploading the processed data to Amazon S3 (the last line of the preceding code), you can use the FastFile input mode in the fine-tuning stage to stream the data directly from S3 instead of copying it to local disk first:

```python
from sagemaker.inputs import TrainingInput

training_input_path = "s3://<Your-S3-Bucket>"  # S3 location of the processed dataset uploaded above
training = TrainingInput(
    s3_data_type="S3Prefix",  # Available Options: S3Prefix | ManifestFile | AugmentedManifestFile
    s3_data=training_input_path,
    distribution="FullyReplicated",  # Available Options: FullyReplicated | ShardedByS3Key
    input_mode="FastFile",  # stream S3 objects on demand instead of downloading them first
)
```
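Inside the training container, each input channel appears as a local directory. The following is a minimal sketch of how a training script might locate it, assuming the `SM_CHANNEL_<NAME>` environment variables exported by the sagemaker-training toolkit (falling back to SageMaker's standard `/opt/ml/input/data/<name>` path) and a hypothetical channel named `train`.

```python
import os


def channel_dir(channel_name: str) -> str:
    """Resolve the local directory where SageMaker exposes an input channel.

    The sagemaker-training toolkit exports SM_CHANNEL_<NAME>; if the variable is
    missing, fall back to the standard /opt/ml/input/data/<name> location.
    """
    env_var = "SM_CHANNEL_" + channel_name.upper()
    return os.environ.get(env_var, os.path.join("/opt/ml/input/data", channel_name))


# Hypothetical usage inside the training script:
# from datasets import load_from_disk
# train_dataset = load_from_disk(channel_dir("train"))
```

With FastFile mode, the files under this directory are fetched from S3 as they are read, so the training script can treat them like local files.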

Train the model

For demonstration, we use whisper-large-v2 as the pre-trained model (Whisper large-v3 is now also available), which can be imported through the Hugging Face Transformers library. You can use 8-bit quantization to further improve training efficiency. 8-bit quantization reduces memory by rounding floating-point weights to 8-bit integers that share a scaling factor. It is a commonly used model compression technique that saves memory without sacrificing much precision during inference.
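The idea behind 8-bit quantization can be sketched with symmetric "absmax" quantization: scale a vector so its largest magnitude maps to 127, round to integers, and multiply back by the scale when the values are needed. This is a simplified illustration, not what bitsandbytes does internally (its LLM.int8() method adds per-vector scaling and keeps outlier features in 16-bit); the function names are hypothetical, and the 1.55-billion parameter count is an approximation of Whisper large used only for the memory arithmetic.

```python
# Minimal sketch of symmetric ("absmax") 8-bit quantization of one weight vector.
# Simplification: bitsandbytes' LLM.int8() also uses per-vector scales and keeps
# outlier features in 16-bit, which this sketch does not model.


def quantize_absmax(values):
    """Map floats to integer codes in [-127, 127] with one shared scale factor."""
    scale = max(abs(v) for v in values) / 127 or 1.0  # fall back to 1.0 for all-zero input
    codes = [round(v / scale) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate floats from the integer codes."""
    return [c * scale for c in codes]


def weight_memory_gb(num_params, bytes_per_param):
    """Memory needed to store the weights alone (no activations or optimizer state)."""
    return num_params * bytes_per_param / 1e9


weights = [0.12, -0.5, 0.031, 0.25]
codes, scale = quantize_absmax(weights)
restored = dequantize(codes, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print(codes)                           # every code fits in one signed byte
print(max_error <= scale / 2 + 1e-12)  # True: error is at most half a quantization step

# weights of a model with about 1.55 billion parameters (roughly Whisper large)
print(weight_memory_gb(1.55e9, 2), "GB in fp16 vs", weight_memory_gb(1.55e9, 1), "GB in int8")
```

Halving the bytes per weight halves the memory the frozen base model occupies, which leaves more GPU memory for activations and the LoRA parameters during training.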

To load the pre-trained model in 8-bit quantized format, we simply add the load_in_8bit=True argument when instantiating the model, as shown in the following code. This will load the model weights quantized to 8 bits, reducing the memory footprint.

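```python
from transformers import WhisperForConditionalGeneration

model_name_or_path = "openai/whisper-large-v2"

# requires the bitsandbytes and accelerate packages; newer transformers releases prefer
# quantization_config=BitsAndBytesConfig(load_in_8bit=True) over load_in_8bit=True
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
```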

We use the LoRA implementation from Hugging Face’s peft package. There are four steps to fine-tune a model using LoRA:

Instantiate a base model (as we did in the last step).
Create a configuration (LoraConfig) where LoRA-specific parameters are defined.
Wrap the base model with get_peft_model() to get a trainable PeftModel.
Train the PeftModel as you would train the base model.

See the following code:

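```python
from peft import LoraConfig, get_peft_model
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# args (parsed hyperparameters), train_dataset (the processed DatasetDict loaded from the
# training channel), data_collator, and processor are defined elsewhere in the training script
config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, config)
model.print_trainable_parameters()  # only the LoRA matrices are trainable

training_args = Seq2SeqTrainingArguments(
    output_dir=args.model_dir,
    per_device_train_batch_size=int(args.train_batch_size),
    gradient_accumulation_steps=1,
    learning_rate=float(args.learning_rate),
    warmup_steps=args.warmup_steps,
    num_train_epochs=args.num_train_epochs,
    evaluation_strategy="epoch",  # renamed eval_strategy in recent transformers releases
    fp16=True,
    per_device_eval_batch_size=args.eval_batch_size,
    generation_max_length=128,
    logging_steps=25,
    # keep input_features; the Trainer would otherwise drop columns missing from the PeftModel forward signature
    remove_unused_columns=False,
    label_names=["labels"],
)
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset["train"],
    eval_dataset=train_dataset["test"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
)
trainer.train()
```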
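To get a feel for how small the trainable part is, the following sketch estimates the parameters that the LoraConfig above (r=32, target_modules=["q_proj", "v_proj"], bias="none") adds to whisper-large-v2. The architecture values (d_model=1280, 32 encoder and 32 decoder layers, with decoder layers containing both self-attention and cross-attention) are assumptions taken from the model's published configuration; compare the result with what `model.print_trainable_parameters()` reports in your run.

```python
# Hypothetical estimate of the trainable parameters LoRA adds to whisper-large-v2
# with r=32 on every q_proj and v_proj module.
# Assumed architecture values: d_model=1280, 32 encoder and 32 decoder layers.
D_MODEL = 1280
ENCODER_LAYERS = 32
DECODER_LAYERS = 32


def lora_params_per_module(d_in, d_out, r):
    """A is r x d_in and B is d_out x r; bias="none" adds nothing."""
    return r * d_in + d_out * r


def count_target_modules(encoder_layers, decoder_layers, per_attention=2):
    # Encoder layers have one self-attention block; decoder layers have
    # self-attention plus cross-attention, each with its own q_proj and v_proj.
    return per_attention * (encoder_layers + 2 * decoder_layers)


r = 32
modules = count_target_modules(ENCODER_LAYERS, DECODER_LAYERS)
trainable = modules * lora_params_per_module(D_MODEL, D_MODEL, r)
print(modules, trainable)                                      # 192 15728640
print(f"adapter size in fp32: {trainable * 4 / 1e6:.1f} MB")   # adapter size in fp32: 62.9 MB
```

Under these assumptions, the saved adapter is tens of megabytes rather than the roughly 7 GB of a full fine-tuned checkpoint, which is what makes keeping many fine-tuned variants practical.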

To run a SageMaker training job, we bring our own Docker container. You can download the Docker image from GitHub, where ffmpeg4 and git-lfs are packaged together with other Python requirements. To learn more about how to adapt your own Docker container to work with SageMaker, refer to Adapting your own training container. Then you can use the Hugging Face Estimator and start a SageMaker training job:

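```python
from sagemaker.huggingface import HuggingFace

# BUCKET, PREFIX, TRAINING_JOB_NAME, ROLE, instance_type (for example, "ml.g5.2xlarge"),
# metric_definitions, distribution, and environment are defined earlier in the notebook
OUTPUT_PATH = f"s3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/"

huggingface_estimator = HuggingFace(
    entry_point="train.sh",
    source_dir="./src",
    output_path=OUTPUT_PATH,
    instance_type=instance_type,
    instance_count=1,
    # transformers_version="4.17.0",
    # pytorch_version="1.10.2",
    py_version="py310",
    image_uri="<ECR-PATH>",  # URI of the custom training image pushed to Amazon ECR
    role=ROLE,
    metric_definitions=metric_definitions,
    volume_size=200,
    distribution=distribution,
    keep_alive_period_in_seconds=1800,  # keep the instance in a warm pool for follow-up jobs
    environment=environment,
)

# pass the FastFile TrainingInput defined earlier; the channel name "train" is an assumption
huggingface_estimator.fit(inputs={"train": training}, job_name=TRAINING_JOB_NAME, wait=False)
```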
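The estimator references `metric_definitions`, which is a list of `Name`/`Regex` pairs that SageMaker applies to the training job's log output to extract metrics. The entries below are hypothetical and assume the Hugging Face Trainer's dictionary-style log lines (for example, `{'eval_loss': 0.31, 'epoch': 1.0}`); the helper `extract_metrics` only mimics the matching locally so you can test a regex before launching a job.

```python
import re

# Hypothetical metric definitions; each regex captures the metric value in group 1.
# They assume the Hugging Face Trainer logs dicts such as {'eval_loss': 0.31, 'epoch': 1.0}.
metric_definitions = [
    {"Name": "loss", "Regex": r"'loss': ([0-9]+\.?[0-9]*)"},
    {"Name": "eval_loss", "Regex": r"'eval_loss': ([0-9]+\.?[0-9]*)"},
    {"Name": "epoch", "Regex": r"'epoch': ([0-9]+\.?[0-9]*)"},
]


def extract_metrics(log_line):
    """Apply each regex to one log line locally, mimicking SageMaker's metric parsing."""
    found = {}
    for metric in metric_definitions:
        match = re.search(metric["Regex"], log_line)
        if match:
            found[metric["Name"]] = float(match.group(1))
    return found


print(extract_metrics("{'eval_loss': 0.31, 'eval_runtime': 42.0, 'epoch': 1.0}"))
# {'eval_loss': 0.31, 'epoch': 1.0}
```

The leading quote in `'loss'` keeps that pattern from also matching `'eval_loss'`. Values logged in scientific notation (such as small learning rates) would need a broader regex.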

The implementation of LoRA enabled us to run the Whisper large fine-tuning task on a single GPU instance (for example, ml.g5.2xlarge). In comparison, the Whisper large full fine-tuning task requires multiple GPUs (for example, ml.p4d.24xlarge) and a much longer training time. More specifically, our experiment demonstrated that the full fine-tuning task requires 24 times more GPU hours compared to the LoRA approach.

Evaluate model performance

To evaluate the performance of the fine-tuned Whisper model, we calculate the word error rate (WER) on a held-out test set. WER counts the word-level substitutions (S), deletions (D), and insertions (I) needed to turn the predicted transcript into the ground truth transcript and divides by the number of words N in the ground truth: WER = (S + D + I) / N. A lower WER indicates better performance. You can run the following script against both the pre-trained model and the fine-tuned model and compare their WER:

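```python
import gc

import evaluate
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# model, tokenizer, data_collator, and the processed common_voice dataset come from the earlier steps
metric = evaluate.load("wer")

eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].to("cuda"),
                    # prompt the decoder with Whisper's special prefix tokens
                    # (start of transcript, language, task, no timestamps)
                    decoder_input_ids=batch["labels"][:, :4].to("cuda"),
                    max_new_tokens=255,
                )
                .cpu()
                .numpy()
            )
            labels = batch["labels"].cpu().numpy()
            # replace the -100 values used to mask padding in the loss so the labels can be decoded
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            metric.add_batch(
                predictions=decoded_preds,
                references=decoded_labels,
            )
    del generated_tokens, labels, batch
    gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")
```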
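The `evaluate` "wer" metric computes this for you. To make the definition concrete, here is a dependency-free sketch that counts word edits with a Levenshtein distance and, as we understand the `evaluate` metric does, aggregates at the corpus level (total edits over total reference words) rather than averaging per-utterance scores. The function names and sample sentences are hypothetical.

```python
def edit_distance(ref_words, hyp_words):
    """Word-level Levenshtein distance: minimum substitutions + deletions + insertions."""
    prev = list(range(len(hyp_words) + 1))  # distances from an empty reference prefix
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i] + [0] * len(hyp_words)
        for j, hyp_word in enumerate(hyp_words, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(
                prev[j] + 1,         # deletion of ref_word
                curr[j - 1] + 1,     # insertion of hyp_word
                prev[j - 1] + cost,  # substitution (or match when cost is 0)
            )
        prev = curr
    return prev[-1]


def corpus_wer(references, predictions):
    """Total word edits divided by total reference words (references must be non-empty)."""
    edits = sum(edit_distance(r.split(), p.split()) for r, p in zip(references, predictions))
    words = sum(len(r.split()) for r in references)
    return edits / words


refs = ["the cat sat on the mat"]
preds = ["the cat sit on mat"]
print(100 * corpus_wer(refs, preds))  # one substitution + one deletion over 6 words
```

Multiplying by 100, as the evaluation script does, reports WER as a percentage.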

Conclusion

In this post, we demonstrated fine-tuning Whisper, a state-of-the-art speech recognition model. In particular, we used Hugging Face’s PEFT LoRA and enabled 8-bit quantization for efficient training. We also demonstrated how to run the training job on SageMaker.

Although this is an important first step, there are several ways you can build on this work to further improve the Whisper model. Going forward, consider using SageMaker distributed training to scale training to a much larger dataset, which lets the model learn from more varied and comprehensive data and can improve accuracy. You can also optimize latency when serving the Whisper model to enable real-time speech recognition. Additionally, you could extend the work to handle longer audio transcriptions, which requires changes to the model architecture and training schemes.

Acknowledgement

The authors extend their gratitude to Paras Mehra, John Sol and Evandro Franco for their insightful feedback and review of the post.

About the Authors

Jun Shi is a Senior Solutions Architect at Amazon Web Services (AWS). His current areas of focus are AI/ML infrastructure and applications. He has over a decade of experience in the FinTech industry as a software engineer.

Dr. Changsha Ma is an AI/ML Specialist at AWS. She is a technologist with a PhD in Computer Science, a master's degree in Education Psychology, and years of experience in data science and independent consulting in AI/ML. She is passionate about researching methodological approaches for machine and human intelligence. Outside of work, she loves hiking, cooking, hunting food, and spending time with friends and family.

AWS Machine Learning Blog
