Abrir en Google Colab
|
Descargar notebook
|
Salency maps para NLP utilizando AllenNLP
PRECAUCIÓN 😱: El tema presentado en esta sección está clasificado como avanzado. El entendimiento de este contenido es totalmente opcional.
Introducción
AllenNLP es un framework general de aprendizaje profundo para NLP, establecido por el mundialmente famoso Allen Institute for AI Lab. Contiene modelos de referencia de última generación que se ejecutan sobre el PyTorch. AllenNLP es una librería que ademas busca implementar abstracciones que permitan el rápido desarrollo de modelos y reutilización de componentes al despegarse de detalles de implementación de cada modelo.
En este ejemplo, veremos como utilizar esta librería para generar salency maps utilizando los gradientes de las prediciones. Esto nos permita interpretar las predicciones de nuestros modelos basados en transformers.
Para ejecutar este notebook
Para ejecutar este notebook, instale las siguientes librerias:
[2]:
!wget https://raw.githubusercontent.com/santiagxf/M72109/master/NLP/Datasets/mascorpus/tweets_marketing.csv \
--quiet --no-clobber --directory-prefix ./Datasets/mascorpus/
!wget https://raw.githubusercontent.com/santiagxf/M72109/master/m72109/nlp/explanation.py \
--quiet --no-clobber --directory-prefix ./m72109/nlp/
!wget https://raw.githubusercontent.com/santiagxf/M72109/master/docs/nlp/neural/allennlp_interpret.txt \
--quiet --no-clobber
%pip install -r allennlp_interpret.txt --quiet
|████████████████████████████████| 3.1 MB 4.4 MB/s
|████████████████████████████████| 831.4 MB 2.6 kB/s
|████████████████████████████████| 719 kB 57.2 MB/s
|████████████████████████████████| 596 kB 53.4 MB/s
|████████████████████████████████| 880 kB 52.3 MB/s
|████████████████████████████████| 3.3 MB 39.5 MB/s
|████████████████████████████████| 86 kB 4.3 MB/s
|████████████████████████████████| 125 kB 73.3 MB/s
|████████████████████████████████| 1.8 MB 44.9 MB/s
|████████████████████████████████| 592 kB 57.0 MB/s
|████████████████████████████████| 248 kB 63.0 MB/s
Installing build dependencies ... done
Getting requirements to build wheel ... done
Installing backend dependencies ... done
Preparing wheel metadata ... done
|████████████████████████████████| 1.2 MB 47.7 MB/s
|████████████████████████████████| 132 kB 64.7 MB/s
|████████████████████████████████| 77 kB 6.5 MB/s
|████████████████████████████████| 8.9 MB 54.4 MB/s
|████████████████████████████████| 79 kB 9.1 MB/s
|████████████████████████████████| 138 kB 65.9 MB/s
|████████████████████████████████| 127 kB 67.1 MB/s
|████████████████████████████████| 21.0 MB 1.2 MB/s
|████████████████████████████████| 23.2 MB 1.3 MB/s
|████████████████████████████████| 23.3 MB 84.3 MB/s
|████████████████████████████████| 23.3 MB 1.2 MB/s
|████████████████████████████████| 22.1 MB 70.1 MB/s
|████████████████████████████████| 22.1 MB 80.7 MB/s
|████████████████████████████████| 181 kB 63.7 MB/s
|████████████████████████████████| 145 kB 67.0 MB/s
|████████████████████████████████| 63 kB 1.6 MB/s
Building wheel for fairscale (PEP 517) ... done
Building wheel for jsonnet (setup.py) ... done
Building wheel for pathtools (setup.py) ... done
Building wheel for sacremoses (setup.py) ... done
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchtext 0.12.0 requires torch==1.11.0, but you have torch 1.9.0 which is incompatible.
torchaudio 0.11.0+cu113 requires torch==1.11.0, but you have torch 1.9.0 which is incompatible.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
Si ejecuta en Google Colab, adicionalmente deberá cambiar la version de la libraria google-cloud-storage:
[3]:
%pip install -U google-cloud-storage==1.40.0 --quiet
|████████████████████████████████| 104 kB 2.6 MB/s
|████████████████████████████████| 75 kB 5.0 MB/s
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-cloud-bigquery 1.21.0 requires google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1, but you have google-resumable-media 1.3.3 which is incompatible.
Descargaremos un modelo previamente entrenando el el problema de clasificación de Tweets:
[4]:
!wget https://santiagxf.blob.core.windows.net/public/models/tweet_classification_bert.zip --no-clobber --quiet
!unzip -qq tweet_classification_bert.zip
Cargamos el set de datos
[5]:
import pandas as pd
tweets = pd.read_csv('Datasets/mascorpus/tweets_marketing.csv')
Cargando un modelo entreando con Transformers en AllenNLP
allennlp es un framework compatible con la libraría transformers lo cual resulta atractivo a la hora de utilizar modelos que son entrenados en una para luego llevarlo a la otra. Veamos entonces como podemos hacer para cargar el modelo que tenemos previamente entrenado para la clasificación de tweets utilizando una arquitectura BERT dentro de este framework. En particular, nuestro modelo se persistió en el directorio «tweet_classification».
[6]:
model_name = "tweet_classification_bert"
El detalle de como utilizar AllenNLP esta fuera de este curso, pero utilizaremos el format JSON para cargar modelos de esta libreria. El siguiente codigo carga un modelo exactamente igual al que creamos utilizando la libraria HuggingFace anteriormente.
[7]:
from allennlp.common import Params
from allennlp.data.dataset_readers import DatasetReader
params = Params({
"type": "text_classification_json",
"tokenizer": {
"type": "pretrained_transformer",
"model_name": model_name,
},
"token_indexers": {
"tokens": {
"type": "pretrained_transformer",
"model_name": model_name,
}
}
})
dataset_reader = DatasetReader.from_params(params)
[15]:
from allennlp.common import Params
from allennlp.models import Model
from transformers import AutoModelForSequenceClassification
params = Params({
"type": "basic_classifier",
"vocab": {
"type": "from_pretrained_transformer",
"model_name": model_name,
},
"text_field_embedder": {
"type": "basic",
"token_embedders": {
"tokens": {
"type": "pretrained_transformer",
"model_name": model_name
}
}
},
"seq2vec_encoder": {
"type": "bert_pooler",
"pretrained_model": model_name
},
"dropout": 0.1,
"num_labels": 5,
});
model = Model.from_params(params)
model._classification_layer.weight = AutoModelForSequenceClassification.from_pretrained(model_name).classifier.weight
model._classification_layer.bias = AutoModelForSequenceClassification.from_pretrained(model_name).classifier.bias
_ = model.eval()
[19]:
from allennlp.predictors import TextClassifierPredictor
predictor = TextClassifierPredictor(model, dataset_reader)
Recordemos que en el conjunto de datos de entrenamiento, las etiquetas se distribuyen como sigue:
[20]:
labels = [
'ALIMENTACION',
'AUTOMOCION',
'BANCA',
'BEBIDAS',
'DEPORTES',
'RETAIL',
'TELCO'
]
Interpretando nuestras predicciones
Una vez que tenemos nuestro modelo correctamente cargado, veamos como podemos interpretar una predicción computando el salency map a partir de los gradientes.
[21]:
from allennlp.interpret.saliency_interpreters import SimpleGradient, IntegratedGradient, SmoothGradient
interpreter = SmoothGradient(predictor)
Busquemos un tweet para interpretar:
[22]:
sample_text_idx = 1522
sample_text = tweets['TEXTO'][sample_text_idx]
sample_label = tweets['SECTOR'][sample_text_idx]
print("Texto:", sample_text, "\Sector:", sample_label)
Texto: @HyundaiPeru con Grupo Primax realiza este verano servicios de Inspección Digital Gratuita a vehículos Hyundai en e… https://t.co/TZ4XFziOd3 \Sector: AUTOMOCION
Calculemos los gradientes para cada token:
[23]:
import numpy as np
[24]:
interpretation = interpreter.saliency_interpret_from_json({"sentence": sample_text })
outputs = predictor.predict(sample_text)
grads = np.array(interpretation['instance_1']['grad_input_1'])
probs = np.array(outputs['probs'])
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py:974: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
[25]:
outputs.keys()
[25]:
dict_keys(['logits', 'probs', 'token_ids', 'label', 'tokens'])
Podemos graficar los resultados utilizando un mapa de calor marcando con colores más intensos aquellos tokens que tienen mayor impacto en las predicciones:
[28]:
from IPython.display import HTML
from eli5.formatters import format_as_html
from m72109.nlp.explanation import get_explanation_from_grads
[29]:
expl = get_explanation_from_grads(estimator_name="transformer",
estimator_description="NLP transformer explanation",
text=sample_text,
tokens=outputs['tokens'],
grads=grads,
probas=probs,
labels=labels)
[31]:
HTML(format_as_html(expl))
[31]:
NLP transformer explanation
y=AUTOMOCION (probability 0.900) top features
| Contribution? | Feature |
|---|---|
| +1.000 | Highlighted in text |
@HyundaiPeru con Grupo Primax realiza este verano servicios de Inspección Digital Gratuita a vehículos Hyundai en e… https://t.co/TZ4XFziOd3
[30]:
expl = get_explanation_from_grads(estimator_name="transformer",
estimator_description="NLP transformer explanation",
text=sample_text,
tokens=outputs['tokens'],
grads=grads,
probas=probs,
labels=labels,
force_weights=True)
[ ]:
HTML(format_as_html(expl, force_weights=True))
NLP transformer explanation
y=AUTOMOCION (probability 0.900, score 1.000) top features
| Weight? | Feature |
|---|---|
| +0.011 | [CLS] |
| +0.000 | @ |
| +0.000 | hyun |
| +0.029 | ##da |
| +0.062 | ##ipe |
| +0.029 | ##ru |
| +0.000 | con |
| +0.044 | grupo |
| +0.000 | prima |
| +0.000 | ##x |
| +0.029 | realiza |
| +0.022 | este |
| +0.015 | verano |
| +0.087 | servicios |
| +0.029 | de |
| +0.029 | inspección |
| +0.018 | digital |
| +0.000 | gratuita |
| +0.022 | a |
| +0.058 | vehículos |
| +0.116 | hyun |
| +0.044 | ##da |
| +0.044 | ##i |
| +0.008 | en |
| +0.022 | e |
| +0.003 | [UNK] |
| +0.029 | h |
| +0.022 | ##tt |
| +0.000 | ##ps |
| +0.000 | : |
| +0.015 | / |
| +0.007 | / |
| +0.029 | t |
| +0.011 | . |
| +0.015 | co |
| +0.005 | / |
| +0.000 | t |
| +0.002 | ##z |
| +0.044 | ##4 |
| +0.015 | ##x |
| +0.000 | ##f |
| +0.029 | ##zio |
| +0.000 | ##d |
| +0.058 | ##3 |
| +0.000 | [SEP] |
@HyundaiPeru con Grupo Primax realiza este verano servicios de Inspección Digital Gratuita a vehículos Hyundai en e… https://t.co/TZ4XFziOd3
Abrir en Google Colab
Descargar notebook