Abrir en Google Colab Descargar notebook

Modelo de clasificación de imágenes usando ViT

Introducción

Los modelos basados en transformers para visión han demostrado ser ampliamente potentes en régimenes donde la cantidad de datos es suficiente, a pesar de no tener sezgo inductivo presente en su arquitectura.

En este ejemplo veremos cómo podemos utilizar un modelo de este tipo para resolver un problema de clasificación de imágenes.

Preparación del ambiente

Intalamos las librerias necesarias

[1]:
!pip install -U transformers datasets fsspec --quiet

Sobre el conjunto de datos de este ejemplo

Para ejemplificar esta técnica utilizaremos un conjunto de datos muy popular llamado CIFAR-10. CIFAR-10 es un conjunto de datos que consiste en 60.000 imagenes a color de 32x32 agrupadas en 10 clases, con 6000 imagenes cada una. Hay alrededor de 50000 imagenes para entrenamiento y 10000 para testing.

Las categorias son:

  • airplane

  • automobile

  • bird

  • cat

  • deer

  • dog

  • frog

  • horse

  • ship

  • truck

Podemos cargar este conjunto de datos facilmente utilizando datasets:

[2]:
from datasets import load_dataset

train_ds, test_ds = load_dataset('cifar10', split=['train', 'test'])
/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning:
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(

Divieremos el conjunto de datos de entrenamiento en 2 subconjuntos: uno de entrenamiento (90%) y otro de validación (10%):

[3]:
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

Veamos algunos detalles del conjunto de datos:

[4]:
train_ds
[4]:
Dataset({
    features: ['img', 'label'],
    num_rows: 45000
})
[5]:
train_ds.features
[5]:
{'img': Image(mode=None, decode=True, id=None),
 'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None)}

Veamos algunos ejemplos:

[6]:
train_ds[0]['img'].resize((224, 224))
[6]:
../../../../_images/vision_neural_vit_code_vit_classification_16_0.png

Construcción de un modelo de clasificación utilizando ViT

Antes de comenzar necesitaremos verificar que tenemos el runtime correcto en nuestro ambiente. Esta tarea se beneficiará mucho de una GPU.

[7]:
import torch
print("GPUs disponibles: ", torch.cuda.device_count())
GPUs disponibles:  1
[8]:
model_name = "google/vit-base-patch16-224-in21k"

Preprocesamiento

Como es común, deberemos de ocuparnos del preprocesamiento de nuestras imágenes:

Labels

Debemos de disponer de un mapeo entre las etiquetas y sus valores:

[9]:
label2id = { label:id for id, label in enumerate(test_ds.features["label"].names) }
id2label = { id:label for id, label in enumerate(test_ds.features["label"].names) }

Tamaño y cantidad de parches

Los transformers esperan recibir cada imagen con una resolución específica y una cantidad de parches particular. Esta tarea se puede realizar utilizando el concepto de processor en transformers.

Este tipo de components preprocesan los datos antes de ser enviados al modelo que los procesará. En este caso, utilizaremos el processor associado con un modelo en particular, google/vit-base-patch16-224-in21k. Notar como en el nombre está codificado el tamaño de los parches (16x16) y la resolución de las imágenes (224x224).

[10]:
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained(model_name)

Este componente también realizará normalización y estandarización, dos operaciones que típicamente realizamos en visión por computadora.

Generaremos una función que toma un ejemplo del conjunto de datos y aplica nuestro processor de forma que devuelve los pixeles preprocesados y su etiqueta correspondiente:

[11]:
def process_example(example):
    inputs = processor(example['img'], return_tensors='pt')
    inputs['labels'] = example['label']

    return inputs

Veamos como funciona la función con un ejemplo:

[12]:
example = process_example(train_ds[0])
example
[12]:
{'pixel_values': tensor([[[[-0.3961, -0.3961, -0.3961,  ..., -0.2549, -0.2549, -0.2549],
          [-0.3961, -0.3961, -0.3961,  ..., -0.2549, -0.2549, -0.2549],
          [-0.3961, -0.3961, -0.3961,  ..., -0.2549, -0.2549, -0.2549],
          ...,
          [ 0.2471,  0.2471,  0.2471,  ..., -0.8353, -0.8353, -0.8353],
          [ 0.2471,  0.2471,  0.2471,  ..., -0.8353, -0.8353, -0.8353],
          [ 0.2471,  0.2471,  0.2471,  ..., -0.8353, -0.8353, -0.8353]],

         [[-0.8196, -0.8196, -0.8196,  ..., -0.6627, -0.6627, -0.6627],
          [-0.8196, -0.8196, -0.8196,  ..., -0.6627, -0.6627, -0.6627],
          [-0.8196, -0.8196, -0.8196,  ..., -0.6627, -0.6627, -0.6627],
          ...,
          [ 0.1059,  0.1059,  0.1059,  ..., -0.8824, -0.8824, -0.8824],
          [ 0.1059,  0.1059,  0.1059,  ..., -0.8824, -0.8824, -0.8824],
          [ 0.1059,  0.1059,  0.1059,  ..., -0.8824, -0.8824, -0.8824]],

         [[-0.9137, -0.9137, -0.9137,  ..., -0.8510, -0.8510, -0.8510],
          [-0.9137, -0.9137, -0.9137,  ..., -0.8510, -0.8510, -0.8510],
          [-0.9137, -0.9137, -0.9137,  ..., -0.8510, -0.8510, -0.8510],
          ...,
          [ 0.1137,  0.1137,  0.1137,  ..., -0.8667, -0.8667, -0.8667],
          [ 0.1137,  0.1137,  0.1137,  ..., -0.8667, -0.8667, -0.8667],
          [ 0.1137,  0.1137,  0.1137,  ..., -0.8667, -0.8667, -0.8667]]]]), 'labels': 3}

Es útil ver las dimensiones de nuestros datos:

[13]:
example["pixel_values"].shape
[13]:
torch.Size([1, 3, 224, 224])

Aplicando las transformaciones al conjunto de datos

Apliquemos las transformaciones a todas las instancias:

[14]:
prepared_train_ds = train_ds.with_transform(process_example)
prepared_val_ds = val_ds.with_transform(process_example)

Modelado

Utilizaremos el modelo ViTForImageClassification, el cual agrega una capa de clasificación (torch.nn.Linear) sobre el token [CLS] que se agrega al final de la secuencia de parches en el modelo preentrenado de ViTModel. El último hidden state de este token puede ser considerado como una repreentación de la imágen completa. Configuraremos el número de clases (units en la última capa):

[16]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, top_k=1, id2label=id2label, label2id=label2id).to('cuda')
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Entrenamiento

Configuraremos nuestros parámeteros de entrenamiento de la siguiente forma:

[38]:
from transformers import TrainingArguments, Trainer
from transformers import default_data_collator

args = TrainingArguments(
    "test-cifar-10",
    save_strategy="epoch",
    eval_strategy="epoch",
    num_train_epochs=3,
    optim="adamw_torch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    load_best_model_at_end=True,
    report_to="none",
    remove_unused_columns=False,
)

Configuramos el proceso de entrenamiento del modelo:

[39]:
trainer = Trainer(
    model,
    args,
    data_collator=default_data_collator,
    train_dataset=prepared_train_ds,
    eval_dataset=prepared_val_ds,
)
[40]:
history = trainer.train()
[ 966/2112 36:39 < 43:35, 0.44 it/s, Epoch 1.37/3]
Epoch Training Loss Validation Loss
1 0.370300 0.149960

/usr/local/lib/python3.11/dist-packages/transformers/configuration_utils.py:394: UserWarning: Some non-default generation parameters are set in the model config. These should go into either a) `model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model).This warning will become an exception in the future.
Non-default generation parameters: {'top_k': 1}
  warnings.warn(
[2112/2112 1:21:26, Epoch 3/3]
Epoch Training Loss Validation Loss
1 0.370300 0.149960
2 0.129400 0.092566
3 0.047000 0.085288

/usr/local/lib/python3.11/dist-packages/transformers/configuration_utils.py:394: UserWarning: Some non-default generation parameters are set in the model config. These should go into either a) `model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model).This warning will become an exception in the future.
Non-default generation parameters: {'top_k': 1}
  warnings.warn(
/usr/local/lib/python3.11/dist-packages/transformers/configuration_utils.py:394: UserWarning: Some non-default generation parameters are set in the model config. These should go into either a) `model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model).This warning will become an exception in the future.
Non-default generation parameters: {'top_k': 1}
  warnings.warn(

Evaluación

Una vez que el modelo esta entrenado, veamos como se comporta en el conjunto de evaluación:

[41]:
prepared_test_ds = test_ds.with_transform(process_example)
[42]:
predictions = trainer.predict(test_dataset=prepared_test_ds).predictions

Obtenemos las clases asociadas:

[43]:
import numpy as np

predictions = np.argmax(predictions, axis=1)

Creamos un reporte de clasificación:

[44]:
from sklearn.metrics import classification_report

print(classification_report(test_ds['label'], predictions, target_names=test_ds.features["label"].names))
              precision    recall  f1-score   support

    airplane       0.99      0.98      0.99      1000
  automobile       0.98      0.98      0.98      1000
        bird       0.99      0.99      0.99      1000
         cat       0.96      0.97      0.96      1000
        deer       0.99      0.99      0.99      1000
         dog       0.98      0.95      0.97      1000
        frog       0.99      1.00      1.00      1000
       horse       0.99      0.99      0.99      1000
        ship       0.99      0.99      0.99      1000
       truck       0.98      0.98      0.98      1000

    accuracy                           0.98     10000
   macro avg       0.98      0.98      0.98     10000
weighted avg       0.98      0.98      0.98     10000

Guardar el modelo para su posterior utilización

Podemos guardar los componentes de este modelo para luego construir un pipeline con el mismo:

[45]:
model_name = "vit_cifar10"

trainer.save_model(model_name)
processor.save_pretrained(model_name)
/usr/local/lib/python3.11/dist-packages/transformers/configuration_utils.py:394: UserWarning: Some non-default generation parameters are set in the model config. These should go into either a) `model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model).This warning will become an exception in the future.
Non-default generation parameters: {'top_k': 1}
  warnings.warn(
[45]:
['vit_cifar10/preprocessor_config.json']
[46]:
from transformers import pipeline

classifier = pipeline("image-classification", model="vit_cifar10", image_processor="vit_cifar10", top_k=1, device=0)
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0
[47]:
classifier(train_ds["img"][:3])
[47]:
[[{'label': 'cat', 'score': 0.9640224575996399}],
 [{'label': 'ship', 'score': 0.9606396555900574}],
 [{'label': 'ship', 'score': 0.9603068828582764}]]