Saltar a contenido

Evaluando los modelos

Para evaluar los modelos, necesitamos un conjunto de datos que no hayamos empleado en la fase de entrenamiento y posteriormente, comparar nuestros textos de referencia con la predicción obtenida.

Dataset de validación

Para ello, vamos a coger los audios grabados desde nuestro entrenamiento hasta el día de hoy, y tras ello, de la misma forma que hasta ahora, necesitaremos crear un nuevo dataset que persistiremos en una carpeta diferentes al dataset de entrenamiento.

Por ejemplo, podemos elegir los documentos entre dos fechas de forma similar a:

# Recuperamos todos los audios
from datetime import datetime
cursor = audios_coll.find({"fecha":{"$gte": datetime(2024, 5, 8, 0, 0, 0),"$lte": datetime(2024, 5, 17, 0, 0, 0)}},
                          {"_id":0, "texto.texto":1, "aws_object_id":1})

Validación

Tras ello, ya podemos comparar los resultados de las predicciones:

from datasets import load_from_disk
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from evaluate import load
import torch

# Cargar el modelo y el procesador
model = WhisperForConditionalGeneration.from_pretrained("/home/rutaAlModelo")
processor = WhisperProcessor.from_pretrained("/home/rutaAlModelo")

# Mueve el modelo a la GPU si está disponible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

from datasets import Audio

# Cargamos los datos de evaluación
lara_eval_dataset  = load_from_disk("/home/jupyter-iabd/audioslara/dataset-eval/")
# Recodificamos el SR a 16.000
lara_eval_dataset = lara_eval_dataset.cast_column("audio", Audio(sampling_rate=16000))

forced_decoder_ids = processor.get_decoder_prompt_ids(language="spanish", task="transcribe")

def map_to_pred(batch):
    audio = batch["audio"]
    input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
    batch["reference"] = processor.tokenizer._normalize(batch['texto'])

    with torch.no_grad():
        # Asegurándose de que las características de entrada estén en el mismo dispositivo que el modelo
        input_features = input_features.to(device)
        predicted_ids = model.generate(input_features, max_new_tokens=200, forced_decoder_ids=forced_decoder_ids)[0]

    transcription = processor.decode(predicted_ids)
    batch["prediction"] = processor.tokenizer._normalize(transcription)
    return batch

resultado = lara_eval_dataset.map(map_to_pred)

wer = load("wer")
print( wer.compute(references=resultado["reference"], predictions=resultado["prediction"]))

Si obtenemos 0,0 probablemente haya overfitting ¿verdad?

Resultados en detalle

Una vez obtenidos el WER, podemos mostrar el texto referencia, la predicción y el valor del WER:

for i in range(len(resultado)):
    prediction = resultado[i]['prediction']
    audio_name = resultado[i]['reference']  # o el nombre de tu audio si está disponible
    wer_score = wer.compute(references=[resultado[i]["reference"]], predictions=[resultado[i]["prediction"]])
    print(f'[texto]: {audio_name} \t// [prediccion]: {prediction} \t // WER: {wer_score}')

Callbacks

Hasta ahora, cada vez que se crea un checkpoint, se persiste un modelo en el servidor, lo que implica que llenemos el almacenamiento del mismo.

Para evitarlo, podemos crear un callback y que se invocará cada vez que vaya al almacenar un modelo (on_save):

# Anyadimos los callbacks
from transformers import TrainerCallback

class CheckpointCleanupCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        checkpoint_dirs = [os.path.join(args.output_dir, name) for name in os.listdir(args.output_dir) if name.startswith("checkpoint-")]
        checkpoint_dirs.sort(key=os.path.getmtime, reverse=True)  
        if len(checkpoint_dirs) > 1: 
            for old_checkpoint_dir in checkpoint_dirs[1:]:
                shutil.rmtree(old_checkpoint_dir)

A la hora de crear el trainer, le pasaremos los callbacks:

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    ...
    callbacks=[CheckpointCleanupCallback()], 
)

Tareas a realizar

  • Realizar la validación de los diferentes modelos entrenados en la sesiones anteriores.
  • Ajustar los modelos con diferentes parámetros y comprobar si mejoran los resultados.
  • Emplear los callbacks para quedarse sólo con un checkpoint.

Plazo de entrega

  • Viernes 24 Mayo - 17:00 - Modelos re-entrenados y evaluados, gestionando los checkpoints mediante callbacks