Attention is all you need (2017)
๊ธฐ๊ณ ๋ฒ์ญ์ ๋ชฉ์ ์ผ๋ก ๋ง๋ค์ด ์ก์ ๊ฒ์ผ๋ก ์์- ์ธ์ฝ๋:
- skip ์ปค๋ฅ์
(resnet), Batch Norm(yolov2) ์ ๊ฐ๊ณ ์๋ค.
- ๋์ฝ๋:
- ํฌ์ง์ ๋ ์ธ์ฝ๋ฉ
- ๋ณ๋ ฌ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ค. ์์ฐจ์ ์ผ๋ก ์ฒ๋ฆฌํ์ง ์๋๋ค.
- ์์ฐจ์ ์ด์ง ์๊ธฐ ๋๋ฌธ์ ์์น์ ๊ณ ์ ๊ฐ์ ๋ถ์ฌํ๋ค.
Attention(Q, K, V)๋ ์์ฐจ์ ์ผ ํ์๊ฐ ์๋ค.
- Q: query, t์์ (t์ํ์ค)์ ๋์ฝ๋ LTSTM์ ์๋๊ฐ (๋ฒกํฐ ํํ์ hidden)
- K: key, ๋ชจ๋ ์์ ์ ์ธ์ฝ๋ ์ ์ ์๋๊ฐ (๋ฒกํฐ ํํ์ hidden)
- V: =K
- ์ธ์ฝ๋ ์์์ ์ ์ฒด ๋จ์ด์ ๊ฐ ๋จ์ด ๊ฐ์ ์ ์ฌ๋๋ฅผ ๊ตฌํ๋ค.
- skip ์ปค๋ฅ์ (resnet), Batch Norm(yolov2) ์ ์ฌ์ฉํ๋ค.
๋ ผ๋ฌธ ์ดํด๋ณด๊ธฐ
1) ๋ชจ๋ธ์ ๊ตฌ์กฐ, 2) ํ์ต ๋ฐฉ์, 3) ์ฑ๋ฅ๊ฐ์ ์ ๋ต
RNN์ Attention์ ์ฌ์ฉํ๋ ๋ณต์กํ ๋คํธ์ํฌ ๊ตฌ์กฐ์์ Attention๋ง ์ฌ์ฉํ๋ ๋จ์ํ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๋ค. ๊ณผ์ฐ, ์ด๋ป๊ฒ ์์ ์ ๋ณด ์์ด ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ก ๊ฐ๋ฅํ ๊น? (์ดํ
์
& ํฌ์ง์
๋์ธ์ฝ๋ฉ)
๋ณ๋ ฌ ์ฒ๋ฆฌ์ ์ฅ์ ์ ๋ฌด์์ธ๊ฐ?
ํ ๋ฒ์ ์
๋ ฅ ๋ฐ๋๋ค -> ํ์ต ์๋๊ฐ ๋น ๋ฅด๋ค.
์์ ์ ๋ณด๊ฐ ํ์ต๋์ด ์์ผ๋ฉด, ์์ ์ ๋ณด ์์ด ์
๋ ฅ์ด ๋ค์ด์๋ ๋ฌธ๋งฅ ํ์
์ด ๊ฐ๋ฅํ๋ค.
Model Architecture
Encode: [two sub-layers with skip-connection] ๋ฅผ 6๊ฐ, ์ด 12์ธต
1. ์ฟผ๋ฆฌ, ํค, ๋ฐธ๋ฅ๊ฐ ์ธ์ฝ๋์์ ๋ชจ๋ ์ฒ๋ฆฌ๋๋ค.
- Queries: ํ์ฌ ์ฒ๋ฆฌ ๋จ์ด, ๊ธฐ์กด์ ๋์ฝ๋์์ ํ๋์ฉ ์ ๋ ฅํ๋ ๊ฒ
- keys: ๊ฐ ํ ํฐ๊ณผ์ ์๊ด๊ด๊ณ, ์๋ฅผ ๋ค์ด ์ ์ฌ๋
- values: ๋จ์ด์ ์๋ฏธ
2. Scaled Dot-Product Attention
> MatMul
> Scale(ํฌ๊ธฐ ์กฐ์ )
> Mask(ํน์ ํ ํฐ์ ์ ์ฌ๋ ๊ณ์ฐ ๋นํ์ฑํ. ์์ฃผ ํฐ ์์๋ฅผ ๊ณฑํจ์ผ๋ก์จ, y๊ฐ์ 0์ ๊ฐ๊น๊ฒ ๋ง๋ค์ด ๋นํ์ฑํ ์ํจ๋ค. ex:ํจ๋ฉ)
> SoftMax (์ฌ๊ธฐ๊น์ง๊ฐ ๊ฐ์ค์น๋ฅผ ๋ง๋๋ ๋จ๊ณ๋ค.)
> MatMul(SoftMax์ Value๋ฅผ ๊ณฑํ๋ค.)
> Mask(ํน์ ํ ํฐ์ ์ ์ฌ๋ ๊ณ์ฐ ๋นํ์ฑํ. ์์ฃผ ํฐ ์์๋ฅผ ๊ณฑํจ์ผ๋ก์จ, y๊ฐ์ 0์ ๊ฐ๊น๊ฒ ๋ง๋ค์ด ๋นํ์ฑํ ์ํจ๋ค. ex:ํจ๋ฉ)
> SoftMax (์ฌ๊ธฐ๊น์ง๊ฐ ๊ฐ์ค์น๋ฅผ ๋ง๋๋ ๋จ๊ณ๋ค.)
> MatMul(SoftMax์ Value๋ฅผ ๊ณฑํ๋ค.)
ํ ๋ฒ์ ์ดํ
์
์ ํ๋ ๊ฒ๋ณด๋ค ์ฌ๋ฌ ๋ฒ์ ์ดํ
์
์ ๋ณ๋ ฌ๋ก ์ฌ์ฉํ๋ ๊ฒ์ด ํจ๊ณผ์ ์ด๋ผ๊ณ ํ๋จ.
heads๋ฅผ 8๊ฐ๋ฅผ ๋๋ค. Multi-Head๋ 8๊ฐ์ ๋ฌถ์์ด๋ค.
3. Feed Forward
์๊ท๋ชจ ์์ ์ฐ๊ฒฐ์ธต
FFX(x) = max(0, xW1 + b1)W2 + b2
512 > 2048 > 512
Decoder: [multi-head attention * 2 + FC] ๊ฐ 6๊ฐ, ์ด 18์ธต
Seq2Seq์ ๋ชจ๋ธ๊ณผ ๋ฌ๋ฆฌ Attention๋ง ์๋ ํธ๋์คํฌ๋จธ์ ๊ฒฝ์ฐ,
๋์ฝ๋๋ฅผ ์์ํ๋ ์ฒซ๋ฒ์งธ Attention์ Context Vector๋ฅผ ์ฌ์ฉํ์ง ์๋๋ค.
์ธ์ฝ๋์ ๊ฒฐ๊ณผ๊ฐ(๋ฌธ๋งฅํ์ )์ ์ฐธ๊ณ ํ์ง ์๊ณ ๋จผ์ Attention์ ๊ตฌํ๊ณ
๊ทธ ๋ค์์ ์ธ์ฝ๋ ๊ฐ์ ์ฐธ์กฐํ๋ค.
์ธ์ฝ๋์ ๊ฒฐ๊ณผ๊ฐ(๋ฌธ๋งฅํ์ )์ ์ฐธ๊ณ ํ์ง ์๊ณ ๋จผ์ Attention์ ๊ตฌํ๊ณ
๊ทธ ๋ค์์ ์ธ์ฝ๋ ๊ฐ์ ์ฐธ์กฐํ๋ค.
๋ํ, ์ฒซ๋ฒ์งธ Attention์ ๋ง์คํน ์ฒ๋ฆฌ๊ฐ ์๋ค.
(์ธ์ฝ๋๋ ์๋ฐฉํฅ, ๋์ฝ๋๋ ๋จ๋ฐฉํฅ)
(์ ๋ ฅ์ ๋ณ๋ ฌ๋ก ํ๋๋ผ๋, ์ถ๋ ฅ์ ์์ฐจ์ ์ผ๋ก ํด์ผ ํ๋ค.)
(์ธ์ฝ๋๋ ์๋ฐฉํฅ, ๋์ฝ๋๋ ๋จ๋ฐฉํฅ)
(์ ๋ ฅ์ ๋ณ๋ ฌ๋ก ํ๋๋ผ๋, ์ถ๋ ฅ์ ์์ฐจ์ ์ผ๋ก ํด์ผ ํ๋ค.)
์๊ธฐ ์์ ์๋ ์ ๋ค๊น์ง๋ง ์ ์ฌ๋ ๊ณ์ฐ์ ํ๊ณ ,
๋ค์ ๋์ฌ ํ ํฐ์ ์ ์ฌ๋ ๊ณ์ฐ์ ํ์ง ์๋๋ค. (์ฐ์ธก ํ๋จ ๋ น์ ๊ณ๋จ ์ฐธ๊ณ )
๋ค์ ๋์ฌ ํ ํฐ์ ์ ์ฌ๋ ๊ณ์ฐ์ ํ์ง ์๋๋ค. (์ฐ์ธก ํ๋จ ๋ น์ ๊ณ๋จ ์ฐธ๊ณ )

๋๋ฒ์งธ Attention - ์ธ์ฝ๋ ๋์ฝ๋ ์ดํ
์
:
Query = ๋์ฝ๋ ํ๋ ฌ
key, Value = ์ธ์ฝ๋ ํ๋ ฌ
์ดํ ์ (์ ํ ์ดํ ์ , ๋จ์ด์ ๋ฌธ๋งฅ ํ์ )
ํฌ์ง์ ๋ ์ธ์ฝ๋ฉ
ex)
1. ์๋ฒ ๋ฉ: ๋๋ ์ค๋ ํ๊ต์ ๊ฐ๋ค -> ๊ฐ ๋จ์ด๋ฅผ ๋ฒกํฐํ
2. ์์น์ ๊ณ ์ ๊ฐ, PE (๋์ผํ n์ฐจ์)
- ์ด๋ค ํน์ ํ ํฐ์ ๋ค๋ฅธ ํ ํฐ์ผ๋ก ์ค๋ช ํ ์ ์๋ค.
- pos๋ ์์น ๊ฐ: 0, 1, 2, 3
- i๋ ๋จ์ด ํ๋ ์๋ฒ ๋ฉ ๋ฒกํฐ์ ์ฐจ์ ๊ฐ
- ์ sin, cos ์ธ๊ฐ? ๊ฒน์น์ง์ง๋ง ์์ผ๋ฉด ๋์ง ์์๊น? ๋๋ค ๋๋ ค๋ ๋ ํ ๋ฐ?
- ๊ฒฐ๊ณผ ๊ฐ์ด ์ ํ(๋จ์ํ) ๊ด๊ณ๊ฐ ๋ ์ ์๋๋ก,
์ ํ์ผ๋ก ์์ธก ๊ฐ๋ฅํ ๋น์ ํ์ ์ฃผ์ ํ๋ค. ํ์ต์ ๋ช ๋ฃํจ์ ์ฃผ๊ธฐ ์ํด. - sin(์ง์) ๊ณผ cos(ํ์) ๋ ๊ฐ๋ฅผ ์ฌ์ฉํ๋ฉด ๊ฒน์น ์ผ์ด ์์ํ๋ค.
- 0(2i): sin(... i=0)
- 1(2i+1): cos(...i=0)
- 2: sin(....i=1์ด๋ฏ๋ก 2)
- 3: cos(... i=1์ด๋ฏ๋ก 2)
- 10000์ ๋ฌด์์ธ๊ฐ? ๋จ์ ์คํ ๊ฐ.
- ์ง์๋ก ์ฒ๋ฆฌํ ์ด์ ๋?
- ์ธ๊ฐ์ด ์ฐ๊ตฌํ ๊ฐ ์ค, ์ต์ํ ๊ฐ์ ๋ก๊ทธ ๊ฐ. ๊ทธ๋์ ๋ก๊ทธ์ ์ญํจ์์ธ ์ง์๋ฅผ ์ฌ์ฉ
3. ๋จ์ด ๋ฒกํฐํ(n์ฐจ์) + ์์น๊ฐ(๋์ผํ n์ฐจ์) = ์ธํ(n์ฐจ์)
BERT & GPT
BERT์ ํต์ฌ์ ๋ถ๋ฅ, GPT๋ ์์ฑํ ์ฑ๋ด์ด ๋ชฉ์ ์ด์๋ค.
๋ถ๋ฅ ๋ ์ด๋ธ ์์ด ๋ฌธ์๋ง ๊ฐ์ง๊ณ ํ์ตํ๋ค.
- ๋์ฝ๋๋ง ์ฌ์ฉํ๋ ๋จ๋ฐฉํฅ
BERT
- ์ธ์ฝ๋๋ง ์ฌ์ฉํ๋ ์๋ฐฉํฅ
- ๋ฌธ์ฅ์ ํน์ ๋จ์ด์ ๋ง์คํฌ ์ฒ๋ฆฌ๋ฅผ ํด์, ๋จ์ด๋ฅผ ๋ง์ถ๋ ํ์ต์ ํ๋ค.
- ์๋ธ์๋ ํ ํฌ๋์ด์ ์ฌ์ฉ
- NSP: ๋ค์ ๋ฌธ์ฅ ์์ธก. ๋ ๊ฐ์ ๋ฌธ์ฅ์ ์ค ํ์ ์ด ๋ฌธ์ฅ์ด ์ด์ด์ง๋ ๋ฌธ์ฅ์ธ์ง ์๋์ง ๋ง์ถ๋ ๋ฐฉ์
์ค์ต
# ํ ํฐํ๋ฅผ ์ํ ํ ํฌ๋์ด์
from transformers import BertTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
์ ์ฒ๋ฆฌ
์ง๋ฌธ๊ณผ ๋ต์ด ์์ ์ด๋ฃจ๋ ํ์ต๋ฐ์ดํฐ
ํจ๋ฉ ๋ฐ ์ธ์ฝ๋ฉ
# ์ง๋ฌธ ํจ๋ฉ ๋ฐ ์ธ์ฝ๋ฉ
def src_encoding(tokens):
# (src_length - ํ ํฐ ์) ๋งํผ ํจ๋ฉํ ํฐ ์ถ๊ฐ - ํจ๋ฉํ ํฐ์ ์ถ๊ฐํ์ฌ ์ต๋๊ธธ์ด๋ก ๋ง์ถค
tokens = tokens + ['<PAD>'] * (src_length - len(tokens))
# ์ธ์ฝ๋ฉ ๋ ์ซ์๋ฅผ ๋ด์ ๋ ๋ฆฌ์คํธ
index_sequences = []
# ๋ฌธ์ฅ์์ ํ ํฐ์ ๊บผ๋ด์ค๋ฉฐ
for word in tokens:
try: # ํ ํฐ ์ธ์ฝ๋ฉ (๋จ์ด ์ฌ์ ์ ์๋ ํ ํฐ์ด ๋ค์ด์ค๋ฉด except๋ก)
index_sequences.append(vocab[word])
except KeyError: # ๋จ์ด ์ฌ์ ์ ์๋ ํ ํฐ์ด ๋ค์ด์ค๋ฉด '<UNK>' ํ ํฐ์ ์ซ์๋ก ๋ณํ
index_sequences.append(vocab['<UNK>'])
return index_sequences
# ๋ต๋ณ๋ฌธ ํจ๋ฉ ๋ฐ ์ธ์ฝ๋ฉ
def trg_encoding(tokens):
# (src_length - ํ ํฐ ์) ๋งํผ ํจ๋ฉํ ํฐ ์ถ๊ฐ
tokens = tokens + ['<PAD>'] * (trg_length - len(tokens))
# ์ธ์ฝ๋ฉ ๋ ์ซ์๋ฅผ ๋ด์ ๋ ๋ฆฌ์คํธ
index_sequences = []
# ๋ฌธ์ฅ์์ ํ ํฐ์ ๊บผ๋ด์ค๋ฉฐ
for word in tokens:
try: # ํ ํฐ ์ธ์ฝ๋ฉ (๋จ์ด ์ฌ์ ์ ์๋ ํ ํฐ์ด ๋ค์ด์ค๋ฉด except๋ก)
index_sequences.append(vocab[word])
except KeyError: # ๋จ์ด ์ฌ์ ์ ์๋ ํ ํฐ์ด ๋ค์ด์ค๋ฉด '<UNK>' ํ ํฐ์ ์ซ์๋ก ๋ณํ
index_sequences.append(vocab['<UNK>'])
return index_sequences
๋ฐ์ดํฐ์
ํด๋์ค ์์ฑ -> ๋ฐ์ดํฐํ๋ ์์์ ์ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ ธ์์ ํ
์๋ก ๋ณํ
ํธ๋์คํฌ๋จธ ๋ชจ๋ธ์ ๊ตฌ์ถ ์์
1. ํฌ์ง์
๋ ์ธ์ฝ๋ฉ ํด๋์ค
2. ๋ง์คํน ํจ์
3. ๋ชจ๋ธ ์ค๊ณ
# ์ํ ์ฐ์ฐ์ ์ํ math ๋ผ์ด๋ธ๋ฌ๋ฆฌ
import math
class PositionalEncoding(nn.Module):
# ํด๋์ค ์์ฑ ์ emb_size, dropout, maxlen์ ์
๋ ฅ์ผ๋ก ๋ฐ์
def __init__(self, emb_size, dropout, maxlen=5000):
super(PositionalEncoding, self).__init__()
# (maxlen * emb_size)ํฌ๊ธฐ์ PostionalEncoding ํ๋ ฌ ์์ฑ ํ ์๋ณธ๊ณผ ํ๊ฒ์
๋ ฅ๊ฐ์ ๋ํด์ค
# den :10000^(2i/d_model) ๊ตฌํ
den = 10000 ** (torch.arange(0, emb_size, 2) / emb_size)
# position
pos = torch.arange(0, maxlen).reshape(maxlen, 1)
# ์ ํ๋ ฌ ์์ฑ(maxlen, emb_size)
pos_embedding = torch.zeros((maxlen, emb_size))
# ์ ํ๋ ฌ์ ์ง์ ์ด์ ์ฌ์ธํจ์ ์ ์ฉ(0,2,4,6 .....)
pos_embedding[:, 0::2] = torch.sin(pos * den)
# ์ ํ๋ ฌ์ ํ์ ์ด์ ์ฝ์ฌ์ธ ํจ์ ์ ์ฉ(1,3,5,7)
pos_embedding[:, 1::2] = torch.cos(pos * den)
# pistional embedding์ ๋ฐฐ์น ์ฐจ์ ์ถ๊ฐ(1, maxlen, embsize)
pos_embedding = pos_embedding.unsqueeze(0)
# ๋
ผ๋ฌธ์์๋ ํฌ์ง์
๋์ธ์ฝ๋ฉ์ ๊ฑฐ์น ํ ๋๋กญ์์์ ์ ์ฉ
self.dropout = nn.Dropout(dropout)
# PyTorch์ nn.Module์์ ์ ๊ณตํ๋ ๊ธฐ๋ฅ์ผ๋ก, ๋ชจ๋ธ์ ์ํ์ ํฌํจ๋์ง๋ง ํ์ต๋์ง ์๋ ํ
์๋ฅผ ๋ฑ๋กํ ๋ ์ฌ์ฉ
# PositionalEncoding ๋ฒกํฐ๋ ๋ณํ์ง ์๊ณ ๊ณ ์
self.register_buffer('pos_embedding', pos_embedding)
def forward(self, token_embedding):
# ์๋ฒ ๋ฉ ๋ฒกํฐ์(batch, seq_length, embsize) PE๋ฒกํฐ(1, maxlen, embsize) ๋ํจ
# ๋ฐฐ์น์ ๊ฐ ๋ฐ์ดํฐ๋ค์๋ํด ๊ฐ๊ฐ PE ๋ํจ
# ์ดํ ๋๋กญ์์ ์ ์ฉ
return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1), :])
class Seq2SeqTransformer(nn.Module):
def __init__(self, num_encoder_layers, num_decoder_layers, emb_size,
nhead, src_vocab_size, tgt_vocab_size, dim_feedforward=512,
dropout=0.1):
super(Seq2SeqTransformer, self).__init__()
# ๋ชจ๋ธ์์ ์ฌ์ฉํ ๋ ์ด์ด ์ ์
# ์๋ฒ ๋ฉ ๋ ์ด์ด
self.emb_size = emb_size
self.embedding = nn.Embedding(vocab_size, emb_size)
# ํฌ์ง์
๋ ์ธ์ฝ๋ฉ ๋ ์ด์ด
self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)
# ํธ๋์คํฌ๋จธ ๋ ์ด์ด(์ธ์ฝ๋+๋์ฝ๋)
# d_model : ์
์ถ๋ ฅ ์ฐจ์ ์, num_encoder(decoder)_layers : ์ธ์ฝ๋(๋์ฝ๋)์ ์ธต ์)
# dim_feedforward : FFNN์ ์ฐจ์ ์(attention ๊ฑฐ์น ํ FFNN์์ dmodel -> dim_feedforward -> dmodel)
self.transformer = nn.Transformer(d_model=emb_size,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True)
# ์ถ๋ ฅ ๋ ์ด์ด(๋จ์ด์ฌ์ ์ ์๋ ๋จ์ด๋ก ์ถ๋ ฅ๋ ์ ์๋๋ก)
self.fc = nn.Linear(emb_size, tgt_vocab_size)
def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
# ์ง๋ฌธ, ๋ต๋ณ๋ฌธ ์๋ฒ ๋ฉ
# ๋
ผ๋ฌธ์์๋ ์๋ฒ ๋ฉ ๋ฒกํฐ์ ์ค์ผ์ผ์ ์กฐ์ ํ์ฌ ์ด๊ธฐํ ๊ณผ์ ์์์ ๋ถ์์ ์ฑ์ ์ค์ด๊ธฐ ์ํด ์๋ฒ ๋ฉ ํฌ๊ธฐ์ ์ ๊ณฑ๊ทผ์ ๊ณฑํจ
src, tgt = src.long(), tgt.long()
src = self.embedding(src) * math.sqrt(self.emb_size)
tgt = self.embedding(tgt) * math.sqrt(self.emb_size)
# ํฌ์ง์
๋ ์ธ์ฝ๋ฉ ์ ์ฉ
src_emb = self.positional_encoding(src)
tgt_emb = self.positional_encoding(tgt)
# ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ํต๊ณผ
# src_emb : ์ง๋ฌธ ๋ฐ์ดํฐ, tgt_emb : ๋ต๋ณ๋ฌธ ๋ฐ์ดํฐ, src_mask : ์ง๋ฌธ ๋ง์คํฌ, tgt_mask : ๋ต๋ณ๋ฌธ ๋ง์คํฌ(์ดํ
์
์ ๋ฏธ๋ ์์ ๋ชป๋ณด๋๋ก)
# src(tgt)_padding_mask : ์ง๋ฌธ(๋ต๋ณ๋ฌธ)์ด ์ดํ
์
์ ํจ๋ฉํ ํฐ์ ์ ์ฉ๋์ง ์๋๋ก ํจ๋ฉ ํ ํฐ์ ๋ง์คํน-> self-atteniton์ ์ ์ฉ
# memory_key_padding_mask : ๋์ฝ๋๊ฐ ์ธ์ฝ๋์ ์ถ๋ ฅ ์ ๋ณด ํ์ฉ ์ ํจ๋ฉ๋ ๋ถ๋ถ ํ์ฉํ์ง ์๋๋ก ๋ง์ค -> encoder-decoder attention์ ์ ์ฉ
outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, # None ๋ถ๋ถ์ memory ๋ง์คํฌ๋ก ๋ณดํต None์ผ๋ก ๋
src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
outs = self.fc(outs)
return outs # (๋ฐฐ์น์ฌ์ด์ฆ, ์ถ๋ ฅ์ํ์ค ๊ธธ์ด, ๋จ์ด ์ฌ์ ์)
# ์์ธก ์ ์ ์ฉํ ์ธ์ฝ๋ฉ ํจ์
def encode(self, src, src_mask):
# ์์ธก ์ ์ฌ์ฉํ ์ธ์ฝ๋ ํต๊ณผ(์ง๋ฌธ -> ์๋ฒ ๋ฉ -> ํฌ์ง์
๋์ธ์ฝ๋ฉ -> ํธ๋์คํฌ๋จธ ์ธ์ฝ๋ ํต๊ณผ)
src = src.long()
src = self.embedding(src) * math.sqrt(self.emb_size)
src_emb = self.positional_encoding(src)
# ์ค์ ์์ธก ์ ํจ๋ฉ์ ๋ฃ์ง ์๊ธฐ์ ํจ๋ฉ ๋ง์คํฌ๋ ์ ์ฉ X
outs = self.transformer.encoder(src_emb, src_mask)
return outs
# ์์ธก ์ ์ ์ฉํ ๋์ฝ๋ฉ ํจ์
def decode(self, tgt, memory, tgt_mask):
# ์์ธก ์ ๋์ฝ๋ ํต๊ณผ(๋ต๋ณ๋ฌธ -> ์๋ฒ ๋ฉ -> ํฌ์ง์
๋์ธ์ฝ๋ฉ -> ํธ๋์คํฌ๋จธ ๋์ฝ๋ ํต๊ณผ)
# ์์ธก ์ tgt๋ sos ํ ํฐ, memory๋ ์ธ์ฝ๋์ ์ถ๋ ฅ(๋งค ์ธต์์ ์ธ์ฝ๋ ๋์ฝ๋ ์ดํ
์
์ ํ์ฉ)
tgt = tgt.long()
tgt = self.embedding(tgt) * math.sqrt(self.emb_size)
tgt_emb = self.positional_encoding(tgt)
outs = self.transformer.decoder(tgt_emb, memory, tgt_mask)
return outs
# ๋ฏธ๋ ์์ ๋ง์คํฌ๋ฅผ ์ํ ํจ์ ์์ฑ
def generate_square_subsequent_mask(sz):
# torch.ones((sz, sz) : (sz, sz)ํฌ๊ธฐ์ 1๋ก ์ฑ์์ง ํ๋ ฌ ์์ฑ
# torch.triu() : ๋๊ฐ์ ๊ธฐ์ค์ผ๋ก ์๋์ ์์๋ฅผ 0์ผ๋ก ๋ณ๊ฒฝ
# 1 1 1
# 0 1 1
# 0 0 1
# ์ ์น
# 1 0 0
# 1 1 0
# 1 1 1
mask = torch.triu(torch.ones((sz, sz), device=device)).transpose(0, 1)
# 0์ธ ๋ถ๋ถ์ ์์ ๋ฌดํ ๊ฐ์ผ๋ก, 1์ธ ๋ถ๋ถ์ 0.0์ผ๋ก ๋ณ๊ฒฝ
# ํ์ผ๋ฌธ์ ์ํ์ค ๊ธธ์ด ๋งํผ์ ๋ง์คํฌ ํ๋ ฌ์ ์์ฑํ์ฌ self-attention ์ ๋ฏธ๋ ์์ ์ ๊ฐ๋ค์ ์์ ๋ฌดํ๋๋ก ๊ฐ ์ ์๋๋ก
# ์์ ๋ฌดํ๋ ๊ฐ์ softmax ํต๊ณผ ์ 0์ผ๋ก ์๋ ด
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
# ์ ์ฒด ๋ง์คํฌ(ํจ๋ฉ, ๋ฏธ๋์์ )๋ฅผ ์์ฑํ๊ธฐ ์ํ ํจ์
def create_mask(src, tgt):
# ์ง๋ฌธ๊ณผ ๋ต๋ณ๋ฌธ์ ์ํ์ค ๊ธธ์ด ํ์ธ
src_seq_len = src.shape[1]
tgt_seq_len = tgt.shape[1]
# ๋ฏธ๋์์ ๋ง์คํฌ
tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
# ์ง๋ฌธ์ ๋ง์คํน์ ์ ์ฉํ์ง ์๊ธฐ์ False๋ก ์ด๋ฃจ์ด์ง ํ๋ ฌ ์์ฑ
src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)
# ํจ๋ฉ๋ง์คํฌ ์์ฑ(์
๋ ฅ ๋ฐ์ดํฐ ํ๋ ฌ์์ ์
๋ ฅ ํ ํฐ์ด 0(<PAD>)์ธ ๋ถ๋ถ์ True๋ก ๋๋จธ์ง๋ False์ธ ํํ๋ก ๋ฐํ)
src_padding_mask = (src == 0)
tgt_padding_mask = (tgt == 0)
return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
Tags:
AI๊ฐ๋ฐ_๊ต์ก