How Self-Attention Works

import torch
x = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 2.0],
    [1.0, 1.0, 1.0, 1.0],
])

w_query = torch.tensor([
    [1.0, 0.0, 1.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
])

w_key = torch.tensor([
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0],
])

w_value = torch.tensor([
    [0.0, 2.0, 0.0],
    [0.0, 3.0, 0.0],
    [1.0, 0.0, 3.0],
    [1.0, 1.0, 0.0],
])

We define the variables: an input x of three 4-dimensional token vectors, and the 4×3 query, key, and value projection matrices.

keys = torch.matmul(x, w_key)      # project inputs into key space
querys = torch.matmul(x, w_query)  # project inputs into query space
values = torch.matmul(x, w_value)  # project inputs into value space

We build the queries, keys, and values by projecting x through each weight matrix.

attn_scores = torch.matmul(querys, keys.T)
attn_scores
tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])

We compute the attention scores: the dot product of each query with every key (QKᵀ).

import numpy as np
from torch.nn.functional import softmax
key_dim_sqrt = np.sqrt(keys.shape[-1])
attn_probs = softmax(attn_scores / key_dim_sqrt, dim = -1)
attn_probs
tensor([[1.3613e-01, 4.3194e-01, 4.3194e-01],
        [8.9045e-04, 9.0884e-01, 9.0267e-02],
        [7.4449e-03, 7.5471e-01, 2.3785e-01]])

We divide the scores by the square root of the key dimension and apply a softmax to get the attention probabilities.
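
Each row is now a probability distribution over the three tokens; a quick sanity check:

print(attn_probs.sum(dim = -1))   # every row sums to 1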

weighted_values = torch.matmul(attn_probs, values)
weighted_values
tensor([[1.8639, 6.3194, 1.7042],
        [1.9991, 7.8141, 0.2735],
        [1.9926, 7.4796, 0.7359]])

We take the weighted sum of the values, using the softmax probabilities as weights. This is the final output of self-attention.
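
As a cross-check, recent PyTorch releases (2.0 and later, newer than the torch 1.10 used in this notebook) ship torch.nn.functional.scaled_dot_product_attention, which performs the same softmax(QKᵀ/√d)·V computation in one call. A minimal sketch:

import torch.nn.functional as F

# Requires PyTorch >= 2.0; the 1/sqrt(d_k) scaling is applied internally.
# A batch dimension is added because the kernels expect batched inputs.
out = F.scaled_dot_product_attention(querys[None], keys[None], values[None])
print(out[0])   # matches weighted_values above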

In self-attention, the three weight matrices (w_query, w_key, w_value) are the learnable parameters: they are updated in whatever direction helps the model perform its task.
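
A minimal sketch of the same computation packaged as a trainable module, with the three projections registered as learnable parameters via nn.Linear (a hypothetical class, not from the original code):

import torch
import torch.nn as nn

class ToySelfAttention(nn.Module):
    def __init__(self, input_dim, head_dim):
        super().__init__()
        # bias-free linear layers play the role of w_query/w_key/w_value
        self.w_query = nn.Linear(input_dim, head_dim, bias = False)
        self.w_key = nn.Linear(input_dim, head_dim, bias = False)
        self.w_value = nn.Linear(input_dim, head_dim, bias = False)

    def forward(self, x):
        q, k, v = self.w_query(x), self.w_key(x), self.w_value(x)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (k.shape[-1] ** 0.5)
        return torch.matmul(torch.softmax(scores, dim = -1), v)

attn = ToySelfAttention(input_dim = 4, head_dim = 3)
out = attn(x)   # gradients flow into all three projections during training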

Feed-Forward Neural Network: A Worked Example

import torch

x = torch.tensor([2, 1])                      # input vector
w1 = torch.tensor([[3, 2, -4], [2, -3, 1]])   # first-layer weights
b1 = 1                                        # first-layer bias
w2 = torch.tensor([[-1, 1], [1, 2], [3, 1]])  # second-layer weights
b2 = -1                                       # second-layer bias

We set up the input, weights, and biases.

h_preact = torch.matmul(x, w1) + b1     # hidden-layer pre-activation
h = torch.nn.functional.relu(h_preact)  # apply the ReLU non-linearity
y = torch.matmul(h, w2) + b2            # output layer

y
tensor([-8, 12])

This is the result. Here w1, w2, b1, and b2 are the learnable parameters.

They are updated in whatever direction helps the model perform its task.
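
The same two-layer block, sketched as a trainable PyTorch module (a hypothetical setup; the weights here start from random initialization rather than the hand-picked values above):

import torch

# Linear -> ReLU -> Linear, the same shape as the computation above
ffn = torch.nn.Sequential(
    torch.nn.Linear(2, 3),   # plays the role of w1, b1
    torch.nn.ReLU(),
    torch.nn.Linear(3, 2),   # plays the role of w2, b2
)
y = ffn(torch.tensor([2.0, 1.0]))   # all four parameter tensors receive gradients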

m = torch.nn.Dropout(p = 0.2)   # zero each element with probability 0.2
input = torch.randn(1, 10)
output = m(input)
print(input)
print(output)
tensor([[ 0.3678,  1.6664,  1.6091,  0.1445,  0.9486,  2.0537,  2.4539,  1.4420,
          0.9701, -0.2877]])
tensor([[0.0000, 2.0830, 2.0113, 0.1806, 1.1858, 2.5672, 0.0000, 1.8025, 0.0000,
         -0.0000]])

A simple dropout example. Each neuron is zeroed out with probability p and excluded from the computation; during training the surviving values are scaled up by 1/(1 - p), which is why the nonzero outputs above are larger than the corresponding inputs.
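
Dropout is active only during training; switching the module to evaluation mode turns it into the identity function:

m.eval()          # evaluation mode: dropout is disabled
print(m(input))   # identical to input, no zeroing or rescaling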

Converting Sentences to Vectors

!pip install ratsnlp
... (dependency download and resolution log omitted) ...
Successfully installed Korpora-0.2.0 PyYAML-5.4.1 aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 dataclasses-0.6 flask-cors-3.0.10 flask-ngrok-0.0.25 frozenlist-1.2.0 fsspec-2021.11.1 future-0.18.2 huggingface-hub-0.2.1 multidict-5.2.0 pyDeprecate-0.3.0 pytorch-lightning-1.3.4 ratsnlp-0.0.9999 sacremoses-0.0.46 tokenizers-0.10.3 torchmetrics-0.6.2 transformers-4.10.0 xlrd-2.0.1 yarl-1.7.2
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(
    'beomi/kcbert-base',
    do_lower_case = False,
)

We instantiate the tokenizer used by the BERT (kcbert-base) model.

from transformers import BertConfig, BertModel
pretrained_model_config = BertConfig.from_pretrained(
    'beomi/kcbert-base'
)
model = BertModel.from_pretrained(
    'beomi/kcbert-base',
    config = pretrained_model_config,
)
Some weights of the model checkpoint at beomi/kcbert-base were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

We load the pretrained BERT (kcbert-base) model. The warning above is expected: the checkpoint also carries pretraining-head weights (the cls.* parameters), which the plain BertModel does not use.

pretrained_model_config
BertConfig {
  "_name_or_path": "beomi/kcbert-base",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 300,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.10.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30000
}

pretrained_model_config holds the settings that were used when the BERT model was pretrained.

There are 12 Transformer blocks (num_hidden_layers), 12 attention heads (num_attention_heads), and a vocabulary of 30,000 tokens (vocab_size).
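
These values can be read directly off the config object:

print(pretrained_model_config.num_hidden_layers)    # 12
print(pretrained_model_config.num_attention_heads)  # 12
print(pretrained_model_config.vocab_size)           # 30000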

sentences = ['안녕하세요', '하이!']
features = tokenizer(
    sentences,
    max_length = 10,
    padding = 'max_length',
    truncation = True,
)
features
{'input_ids': [[2, 19017, 8482, 3, 0, 0, 0, 0, 0, 0], [2, 15830, 5, 3, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]}

We build the BERT model's inputs. As with the BERT model covered earlier, three fields come out: input_ids, token_type_ids, and attention_mask.
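
To inspect which subwords the ids stand for, they can be mapped back with convert_ids_to_tokens (from the config and output above, 0 is the [PAD] id, while 2 and 3 bracket each sentence as the special start and end tokens):

# Map the first sentence's ids back to subword strings.
print(tokenizer.convert_ids_to_tokens(features['input_ids'][0]))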

features = {k : torch.tensor(v) for k, v in features.items()}

Since PyTorch models require tensor inputs, we convert each feature from a Python list to a torch tensor.
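
Equivalently, the tokenizer can return PyTorch tensors directly through its return_tensors argument:

features = tokenizer(
    sentences,
    max_length = 10,
    padding = 'max_length',
    truncation = True,
    return_tensors = 'pt',   # return torch tensors instead of Python lists
)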

outputs = model(**features)
outputs
BaseModelOutputWithPoolingAndCrossAttentions([('last_hidden_state',
                                               tensor([[[-0.6969, -0.8248,  1.7512,  ..., -0.3732,  0.7399,  1.1907],
                                                        [-1.4803, -0.4398,  0.9444,  ..., -0.7405, -0.0211,  1.3064],
                                                        [-1.4299, -0.5033, -0.2069,  ...,  0.1285, -0.2611,  1.6057],
                                                        ...,
                                                        [-1.4406,  0.3431,  1.4043,  ..., -0.0565,  0.8450, -0.2170],
                                                        [-1.3625, -0.2404,  1.1757,  ...,  0.8876, -0.1054,  0.0734],
                                                        [-1.4244,  0.1518,  1.2920,  ...,  0.0245,  0.7572,  0.0080]],
                                               
                                                       [[ 0.9371, -1.4749,  1.7351,  ..., -0.3426,  0.8050,  0.4031],
                                                        [ 1.6095, -1.7269,  2.7936,  ...,  0.3100, -0.4787, -1.2491],
                                                        [ 0.4861, -0.4569,  0.5712,  ..., -0.1769,  1.1253, -0.2756],
                                                        ...,
                                                        [ 1.2362, -0.6181,  2.0906,  ...,  1.3677,  0.8132, -0.2742],
                                                        [ 0.5409, -0.9652,  1.6237,  ...,  1.2395,  0.9185,  0.1782],
                                                        [ 1.9001, -0.5859,  3.0156,  ...,  1.4967,  0.1924, -0.4448]]],
                                                      grad_fn=<NativeLayerNormBackward0>)),
                                              ('pooler_output',
                                               tensor([[-0.1594,  0.0547,  0.1101,  ...,  0.2684,  0.1596, -0.9828],
                                                       [-0.9221,  0.2969, -0.0110,  ...,  0.4291,  0.0311, -0.9955]],
                                                      grad_fn=<TanhBackward0>))])

We run features through the BERT model. Two outputs come back: last_hidden_state and pooler_output.

The former is called the word-level (per-token) embedding; the latter, the sentence-level embedding.
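
Checking the shapes makes the distinction concrete. With 2 input sentences, a sequence length of 10, and hidden size 768:

print(outputs.last_hidden_state.shape)   # torch.Size([2, 10, 768]) - one vector per token
print(outputs.pooler_output.shape)       # torch.Size([2, 768])     - one vector per sentence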