NLP Transformer

Attention is all you need (2017)

๊ธฐ๊ณ„ ๋ฒˆ์—ญ์„ ๋ชฉ์ ์œผ๋กœ ๋งŒ๋“ค์–ด ์กŒ์„ ๊ฒƒ์œผ๋กœ ์˜ˆ์ƒ
  • ์ธ์ฝ”๋”: 
    • skip ์ปค๋„ฅ์…˜(resnet), Batch Norm(yolov2) ์„ ๊ฐ–๊ณ  ์žˆ๋‹ค.

  • ๋””์ฝ”๋”:
    • ํฌ์ง€์…”๋„ ์ธ์ฝ”๋”ฉ
      • ๋ณ‘๋ ฌ์ ์œผ๋กœ ์ฒ˜๋ฆฌํ•œ๋‹ค. ์ˆœ์ฐจ์ ์œผ๋กœ ์ฒ˜๋ฆฌํ•˜์ง€ ์•Š๋Š”๋‹ค.
      • ์ˆœ์ฐจ์ ์ด์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ์œ„์น˜์  ๊ณ ์œ  ๊ฐ’์„ ๋ถ€์—ฌํ•œ๋‹ค.


Attention(Q, K, V)๋Š” ์ˆœ์ฐจ์ ์ผ ํ•„์š”๊ฐ€ ์—†๋‹ค.

  • Q: query, t์‹œ์ (t์‹œํ€€์Šค)์˜ ๋””์ฝ”๋” LTSTM์˜ ์€๋‹‰๊ฐ’ (๋ฒกํ„ฐ ํ˜•ํƒœ์˜ hidden)
  • K: key, ๋ชจ๋“  ์‹œ์ ์˜ ์ธ์ฝ”๋” ์…€์˜ ์€๋‹‰๊ฐ’ (๋ฒกํ„ฐ ํ˜•ํƒœ์˜ hidden)
  • V: =K

Self-Attention

  • ์ธ์ฝ”๋” ์•ˆ์—์„œ ์ „์ฒด ๋‹จ์–ด์™€ ๊ฐ ๋‹จ์–ด ๊ฐ„์— ์œ ์‚ฌ๋„๋ฅผ ๊ตฌํ•œ๋‹ค.
  • skip ์ปค๋„ฅ์…˜(resnet), Batch Norm(yolov2) ์„ ์‚ฌ์šฉํ•œ๋‹ค.

๋…ผ๋ฌธ ์‚ดํŽด๋ณด๊ธฐ

1) ๋ชจ๋ธ์˜ ๊ตฌ์กฐ, 2) ํ•™์Šต ๋ฐฉ์‹, 3) ์„ฑ๋Šฅ๊ฐœ์„  ์ „๋žต

RNN์— Attention์„ ์‚ฌ์šฉํ–ˆ๋˜ ๋ณต์žกํ•œ ๋„คํŠธ์›Œํฌ ๊ตฌ์กฐ์—์„œ Attention๋งŒ ์‚ฌ์šฉํ•˜๋Š” ๋‹จ์ˆœํ•œ ๊ตฌ์กฐ๋ฅผ ์‚ฌ์šฉํ–ˆ๋‹ค. ๊ณผ์—ฐ, ์–ด๋–ป๊ฒŒ ์ˆœ์„œ ์ •๋ณด ์—†์ด ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋กœ ๊ฐ€๋Šฅํ• ๊นŒ? (์–ดํ…์…˜ & ํฌ์ง€์…”๋„์ธ์ฝ”๋”ฉ)

๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ์˜ ์žฅ์ ์€ ๋ฌด์—‡์ธ๊ฐ€? 
ํ•œ ๋ฒˆ์— ์ž…๋ ฅ ๋ฐ›๋Š”๋‹ค -> ํ•™์Šต ์†๋„๊ฐ€ ๋น ๋ฅด๋‹ค.

์ˆœ์„œ ์ •๋ณด๊ฐ€ ํ•™์Šต๋˜์–ด ์žˆ์œผ๋ฉด, ์ˆœ์„œ ์ •๋ณด ์—†์ด ์ž…๋ ฅ์ด ๋“ค์–ด์™€๋„ ๋ฌธ๋งฅ ํŒŒ์•…์ด ๊ฐ€๋Šฅํ•˜๋‹ค.

Model Architecture

Encode: [two sub-layers with skip-connection] ๋ฅผ 6๊ฐœ, ์ด 12์ธต

1. ์ฟผ๋ฆฌ, ํ‚ค, ๋ฐธ๋ฅ˜๊ฐ€ ์ธ์ฝ”๋”์—์„œ ๋ชจ๋‘ ์ฒ˜๋ฆฌ๋œ๋‹ค.
  • Queries: ํ˜„์žฌ ์ฒ˜๋ฆฌ ๋‹จ์–ด, ๊ธฐ์กด์— ๋””์ฝ”๋”์—์„œ ํ•˜๋‚˜์”ฉ ์ž…๋ ฅํ•˜๋˜ ๊ฒƒ
  • keys: ๊ฐ ํ† ํฐ๊ณผ์˜ ์ƒ๊ด€๊ด€๊ณ„, ์˜ˆ๋ฅผ ๋“ค์–ด ์œ ์‚ฌ๋„
  • values: ๋‹จ์–ด์˜ ์˜๋ฏธ

2. Scaled Dot-Product Attention

> MatMul
> Scale(ํฌ๊ธฐ ์กฐ์ ˆ)
> Mask(ํŠน์ • ํ† ํฐ์„ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ ๋น„ํ™œ์„ฑํ™”. ์•„์ฃผ ํฐ ์Œ์ˆ˜๋ฅผ ๊ณฑํ•จ์œผ๋กœ์จ, y๊ฐ’์„ 0์— ๊ฐ€๊น๊ฒŒ ๋งŒ๋“ค์–ด ๋น„ํ™œ์„ฑํ™” ์‹œํ‚จ๋‹ค. ex:ํŒจ๋”ฉ)
> SoftMax (์—ฌ๊ธฐ๊นŒ์ง€๊ฐ€ ๊ฐ€์ค‘์น˜๋ฅผ ๋งŒ๋“œ๋Š” ๋‹จ๊ณ„๋‹ค.)
> MatMul(SoftMax์— Value๋ฅผ ๊ณฑํ•œ๋‹ค.)


ํ•œ ๋ฒˆ์˜ ์–ดํ…์…˜์„ ํ•˜๋Š” ๊ฒƒ๋ณด๋‹ค ์—ฌ๋Ÿฌ ๋ฒˆ์˜ ์–ดํ…์…˜์„ ๋ณ‘๋ ฌ๋กœ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ํšจ๊ณผ์ ์ด๋ผ๊ณ  ํŒ๋‹จ.
heads๋ฅผ 8๊ฐœ๋ฅผ ๋‘”๋‹ค. Multi-Head๋Š” 8๊ฐœ์˜ ๋ฌถ์Œ์ด๋‹ค.


3. Feed Forward
์†Œ๊ทœ๋ชจ ์™„์ „์—ฐ๊ฒฐ์ธต

FFX(x) = max(0, xW1 + b1)W2 + b2
512 > 2048 > 512



Decoder: [multi-head attention * 2 + FC] ๊ฐ€ 6๊ฐœ, ์ด 18์ธต

Seq2Seq์˜ ๋ชจ๋ธ๊ณผ ๋‹ฌ๋ฆฌ Attention๋งŒ ์žˆ๋Š” ํŠธ๋žœ์Šคํฌ๋จธ์˜ ๊ฒฝ์šฐ,
๋””์ฝ”๋”๋ฅผ ์‹œ์ž‘ํ•˜๋Š” ์ฒซ๋ฒˆ์งธ Attention์— Context Vector๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š๋Š”๋‹ค.
์ธ์ฝ”๋”์˜ ๊ฒฐ๊ณผ๊ฐ’(๋ฌธ๋งฅํŒŒ์•…)์„ ์ฐธ๊ณ ํ•˜์ง€ ์•Š๊ณ  ๋จผ์ € Attention์„ ๊ตฌํ•˜๊ณ 
๊ทธ ๋‹ค์Œ์— ์ธ์ฝ”๋” ๊ฐ’์„ ์ฐธ์กฐํ•œ๋‹ค.

๋˜ํ•œ, ์ฒซ๋ฒˆ์งธ Attention์€ ๋งˆ์Šคํ‚น ์ฒ˜๋ฆฌ๊ฐ€ ์žˆ๋‹ค.
(์ธ์ฝ”๋”๋Š” ์–‘๋ฐฉํ–ฅ, ๋””์ฝ”๋”๋Š” ๋‹จ๋ฐฉํ–ฅ)
(์ž…๋ ฅ์€ ๋ณ‘๋ ฌ๋กœ ํ•˜๋”๋ผ๋„, ์ถœ๋ ฅ์€ ์ˆœ์ฐจ์ ์œผ๋กœ ํ•ด์•ผ ํ•œ๋‹ค.)
์ž๊ธฐ ์•ž์— ์žˆ๋Š” ์• ๋“ค๊นŒ์ง€๋งŒ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ์„ ํ•˜๊ณ ,
๋’ค์— ๋‚˜์˜ฌ ํ† ํฐ์€ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ์„ ํ•˜์ง€ ์•Š๋Š”๋‹ค. (์šฐ์ธก ํ•˜๋‹จ ๋…น์ƒ‰ ๊ณ„๋‹จ ์ฐธ๊ณ )


๋‘๋ฒˆ์งธ Attention - ์ธ์ฝ”๋” ๋””์ฝ”๋” ์–ดํ…์…˜: 
Query = ๋””์ฝ”๋” ํ–‰๋ ฌ
key, Value = ์ธ์ฝ”๋” ํ–‰๋ ฌ


์–ดํ…์…˜ (์…€ํ”„ ์–ดํ…์…˜, ๋‹จ์–ด์˜ ๋ฌธ๋งฅ ํŒŒ์•…)


ํฌ์ง€์…”๋„ ์ธ์ฝ”๋”ฉ

ex) 
1. ์ž„๋ฒ ๋”ฉ: ๋‚˜๋Š” ์˜ค๋Š˜ ํ•™๊ต์— ๊ฐ„๋‹ค -> ๊ฐ ๋‹จ์–ด๋ฅผ ๋ฒกํ„ฐํ™”

2. ์œ„์น˜์  ๊ณ ์œ ๊ฐ’, PE (๋™์ผํ•œ n์ฐจ์›)
  • ์–ด๋–ค ํŠน์ • ํ† ํฐ์„ ๋‹ค๋ฅธ ํ† ํฐ์œผ๋กœ ์„ค๋ช…ํ•  ์ˆ˜ ์žˆ๋‹ค.

  • pos๋Š” ์œ„์น˜ ๊ฐ’: 0, 1, 2, 3
  • i๋Š” ๋‹จ์–ด ํ•˜๋‚˜ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ์˜ ์ฐจ์› ๊ฐ’
  • ์™œ sin, cos ์ธ๊ฐ€? ๊ฒน์น˜์ง€์ง€๋งŒ ์•Š์œผ๋ฉด ๋˜์ง€ ์•Š์„๊นŒ? ๋žœ๋ค ๋Œ๋ ค๋„ ๋ ํ…๋ฐ?
    • ๊ฒฐ๊ณผ ๊ฐ’์ด ์„ ํ˜•(๋‹จ์ˆœํ•œ) ๊ด€๊ณ„๊ฐ€ ๋  ์ˆ˜ ์žˆ๋„๋ก,
      ์„ ํ˜•์œผ๋กœ ์˜ˆ์ธก ๊ฐ€๋Šฅํ•œ ๋น„์„ ํ˜•์„ ์ฃผ์ž…ํ•œ๋‹ค. ํ•™์Šต์‹œ ๋ช…๋ฃŒํ•จ์„ ์ฃผ๊ธฐ ์œ„ํ•ด.
    • sin(์ง์ˆ˜) ๊ณผ cos(ํ™€์ˆ˜) ๋‘ ๊ฐœ๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ๊ฒน์น  ์ผ์ด ์ƒ์‡„ํ•œ๋‹ค.
      • 0(2i): sin(... i=0)
      • 1(2i+1): cos(...i=0)
      • 2: sin(....i=1์ด๋ฏ€๋กœ 2)
      • 3: cos(... i=1์ด๋ฏ€๋กœ 2)

  • 10000์€ ๋ฌด์—‡์ธ๊ฐ€? ๋‹จ์ˆœ ์‹คํ—˜ ๊ฐ’.
  • ์ง€์ˆ˜๋กœ ์ฒ˜๋ฆฌํ•œ ์ด์œ ๋Š”?
    • ์ธ๊ฐ„์ด ์—ฐ๊ตฌํ•œ ๊ฐ’ ์ค‘, ์ต์ˆ™ํ•œ ๊ฐ’์€ ๋กœ๊ทธ ๊ฐ’. ๊ทธ๋ž˜์„œ ๋กœ๊ทธ์˜ ์—ญํ•จ์ˆ˜์ธ ์ง€์ˆ˜๋ฅผ ์‚ฌ์šฉ

3. ๋‹จ์–ด ๋ฒกํ„ฐํ™”(n์ฐจ์›) + ์œ„์น˜๊ฐ’(๋™์ผํ•œ n์ฐจ์›) = ์ธํ’‹(n์ฐจ์›)

BERT & GPT



BERT์˜ ํ•ต์‹ฌ์€ ๋ถ„๋ฅ˜, GPT๋Š” ์ƒ์„ฑํ˜• ์ฑ—๋ด‡์ด ๋ชฉ์ ์ด์—ˆ๋‹ค.


๋ถ„๋ฅ˜ ๋ ˆ์ด๋ธ” ์—†์ด ๋ฌธ์„œ๋งŒ ๊ฐ€์ง€๊ณ  ํ•™์Šตํ•œ๋‹ค.


GPT
  • ๋””์ฝ”๋”๋งŒ ์‚ฌ์šฉํ•˜๋Š” ๋‹จ๋ฐฉํ–ฅ

BERT
  • ์ธ์ฝ”๋”๋งŒ ์‚ฌ์šฉํ•˜๋Š” ์–‘๋ฐฉํ–ฅ
  • ๋ฌธ์žฅ์˜ ํŠน์ • ๋‹จ์–ด์— ๋งˆ์Šคํฌ ์ฒ˜๋ฆฌ๋ฅผ ํ•ด์„œ, ๋‹จ์–ด๋ฅผ ๋งž์ถ”๋Š” ํ•™์Šต์„ ํ•œ๋‹ค.
  • ์„œ๋ธŒ์›Œ๋“œ ํ† ํฌ๋‚˜์ด์ € ์‚ฌ์šฉ
  • NSP: ๋‹ค์Œ ๋ฌธ์žฅ ์˜ˆ์ธก. ๋‘ ๊ฐœ์˜ ๋ฌธ์žฅ์„ ์ค€ ํ›„์— ์ด ๋ฌธ์žฅ์ด ์ด์–ด์ง€๋Š” ๋ฌธ์žฅ์ธ์ง€ ์•„๋‹Œ์ง€ ๋งž์ถ”๋Š” ๋ฐฉ์‹


์‹ค์Šต

# ํ† ํฐํ™”๋ฅผ ์œ„ํ•œ ํ† ํฌ๋‚˜์ด์ €
from transformers import BertTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

์ „์ฒ˜๋ฆฌ

์งˆ๋ฌธ๊ณผ ๋‹ต์ด ์Œ์„ ์ด๋ฃจ๋Š” ํ•™์Šต๋ฐ์ดํ„ฐ

ํ† ํฐ ๊ธธ์ด


ํŒจ๋”ฉ ๋ฐ ์ธ์ฝ”๋”ฉ
# ์งˆ๋ฌธ ํŒจ๋”ฉ ๋ฐ ์ธ์ฝ”๋”ฉ
def src_encoding(tokens):
    # (src_length - ํ† ํฐ ์ˆ˜) ๋งŒํผ ํŒจ๋”ฉํ† ํฐ ์ถ”๊ฐ€ - ํŒจ๋”ฉํ† ํฐ์„ ์ถ”๊ฐ€ํ•˜์—ฌ ์ตœ๋Œ€๊ธธ์ด๋กœ ๋งž์ถค
    tokens = tokens + ['<PAD>'] *  (src_length - len(tokens))

    # ์ธ์ฝ”๋”ฉ ๋œ ์ˆซ์ž๋ฅผ ๋‹ด์•„ ๋‘˜ ๋ฆฌ์ŠคํŠธ
    index_sequences = []

    # ๋ฌธ์žฅ์—์„œ ํ† ํฐ์„ ๊บผ๋‚ด์˜ค๋ฉฐ
    for word in tokens:
      try: # ํ† ํฐ ์ธ์ฝ”๋”ฉ (๋‹จ์–ด ์‚ฌ์ „์— ์—†๋Š” ํ† ํฐ์ด ๋“ค์–ด์˜ค๋ฉด except๋กœ)
          index_sequences.append(vocab[word])
      except KeyError: # ๋‹จ์–ด ์‚ฌ์ „์— ์—†๋Š” ํ† ํฐ์ด ๋“ค์–ด์˜ค๋ฉด '<UNK>' ํ† ํฐ์˜ ์ˆซ์ž๋กœ ๋ณ€ํ™˜
          index_sequences.append(vocab['<UNK>'])

    return index_sequences

# ๋‹ต๋ณ€๋ฌธ ํŒจ๋”ฉ ๋ฐ ์ธ์ฝ”๋”ฉ
def trg_encoding(tokens):
    # (src_length - ํ† ํฐ ์ˆ˜) ๋งŒํผ ํŒจ๋”ฉํ† ํฐ ์ถ”๊ฐ€
    tokens = tokens + ['<PAD>'] *  (trg_length - len(tokens))

    # ์ธ์ฝ”๋”ฉ ๋œ ์ˆซ์ž๋ฅผ ๋‹ด์•„ ๋‘˜ ๋ฆฌ์ŠคํŠธ
    index_sequences = []

    # ๋ฌธ์žฅ์—์„œ ํ† ํฐ์„ ๊บผ๋‚ด์˜ค๋ฉฐ
    for word in tokens:
      try: # ํ† ํฐ ์ธ์ฝ”๋”ฉ (๋‹จ์–ด ์‚ฌ์ „์— ์—†๋Š” ํ† ํฐ์ด ๋“ค์–ด์˜ค๋ฉด except๋กœ)
          index_sequences.append(vocab[word])
      except KeyError: # ๋‹จ์–ด ์‚ฌ์ „์— ์—†๋Š” ํ† ํฐ์ด ๋“ค์–ด์˜ค๋ฉด '<UNK>' ํ† ํฐ์˜ ์ˆซ์ž๋กœ ๋ณ€ํ™˜
          index_sequences.append(vocab['<UNK>'])

    return index_sequences


๋ฐ์ดํ„ฐ์…‹ ํด๋ž˜์Šค ์ƒ์„ฑ -> ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„์—์„œ ์ „์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์™€์„œ ํ…์„œ๋กœ ๋ณ€ํ™˜




ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์˜ ๊ตฌ์ถ• ์ˆœ์„œ

1. ํฌ์ง€์…”๋„ ์ธ์ฝ”๋”ฉ ํด๋ž˜์Šค
2. ๋งˆ์Šคํ‚น ํ•จ์ˆ˜
3. ๋ชจ๋ธ ์„ค๊ณ„


# ์ˆ˜ํ•™ ์—ฐ์‚ฐ์„ ์œ„ํ•œ math ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
import math

class PositionalEncoding(nn.Module):
    # ํด๋ž˜์Šค ์ƒ์„ฑ ์‹œ emb_size, dropout, maxlen์„ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์Œ
    def __init__(self, emb_size, dropout, maxlen=5000):
        super(PositionalEncoding, self).__init__()

        # (maxlen * emb_size)ํฌ๊ธฐ์˜ PostionalEncoding ํ–‰๋ ฌ ์ƒ์„ฑ ํ›„ ์›๋ณธ๊ณผ ํƒ€๊ฒŸ์ž…๋ ฅ๊ฐ’์— ๋”ํ•ด์คŒ

        # den :10000^(2i/d_model) ๊ตฌํ˜„
        den = 10000 ** (torch.arange(0, emb_size, 2) / emb_size)

        # position
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)

        # ์˜ ํ–‰๋ ฌ ์ƒ์„ฑ(maxlen, emb_size)
        pos_embedding = torch.zeros((maxlen, emb_size))

        # ์˜ ํ–‰๋ ฌ์˜ ์ง์ˆ˜ ์—ด์€ ์‚ฌ์ธํ•จ์ˆ˜ ์ ์šฉ(0,2,4,6 .....)
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # ์˜ ํ–‰๋ ฌ์˜ ํ™€์ˆ˜ ์—ด์€ ์ฝ”์‚ฌ์ธ ํ•จ์ˆ˜ ์ ์šฉ(1,3,5,7)
        pos_embedding[:, 1::2] = torch.cos(pos * den)

        # pistional embedding์— ๋ฐฐ์น˜ ์ฐจ์› ์ถ”๊ฐ€(1, maxlen, embsize)
        pos_embedding = pos_embedding.unsqueeze(0)

        # ๋…ผ๋ฌธ์—์„œ๋Š” ํฌ์ง€์…”๋„์ธ์ฝ”๋”ฉ์„ ๊ฑฐ์นœ ํ›„ ๋“œ๋กญ์•„์›ƒ์„ ์ ์šฉ
        self.dropout = nn.Dropout(dropout)

        # PyTorch์˜ nn.Module์—์„œ ์ œ๊ณตํ•˜๋Š” ๊ธฐ๋Šฅ์œผ๋กœ, ๋ชจ๋ธ์˜ ์ƒํƒœ์— ํฌํ•จ๋˜์ง€๋งŒ ํ•™์Šต๋˜์ง€ ์•Š๋Š” ํ…์„œ๋ฅผ ๋“ฑ๋กํ•  ๋•Œ ์‚ฌ์šฉ
        # PositionalEncoding ๋ฒกํ„ฐ๋Š” ๋ณ€ํ•˜์ง€ ์•Š๊ณ  ๊ณ ์ •
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding):
        # ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ์™€(batch, seq_length, embsize) PE๋ฒกํ„ฐ(1, maxlen, embsize) ๋”ํ•จ
        # ๋ฐฐ์น˜์˜ ๊ฐ ๋ฐ์ดํ„ฐ๋“ค์—๋Œ€ํ•ด ๊ฐ๊ฐ PE ๋”ํ•จ
        # ์ดํ›„ ๋“œ๋กญ์•„์›ƒ ์ ์šฉ
        return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1), :])

class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size,
                 nhead, src_vocab_size, tgt_vocab_size, dim_feedforward=512,
                 dropout=0.1):
        super(Seq2SeqTransformer, self).__init__()
        # ๋ชจ๋ธ์—์„œ ์‚ฌ์šฉํ•  ๋ ˆ์ด์–ด ์ •์˜

        # ์ž„๋ฒ ๋”ฉ ๋ ˆ์ด์–ด
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

        # ํฌ์ง€์…”๋„ ์ธ์ฝ”๋”ฉ ๋ ˆ์ด์–ด
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

        # ํŠธ๋žœ์Šคํฌ๋จธ ๋ ˆ์ด์–ด(์ธ์ฝ”๋”+๋””์ฝ”๋”)
        # d_model : ์ž…์ถœ๋ ฅ ์ฐจ์› ์ˆ˜, num_encoder(decoder)_layers : ์ธ์ฝ”๋”(๋””์ฝ”๋”)์˜ ์ธต ์ˆ˜)
        # dim_feedforward : FFNN์˜ ์ฐจ์› ์ˆ˜(attention ๊ฑฐ์นœ ํ›„ FFNN์—์„œ dmodel -> dim_feedforward -> dmodel)
        self.transformer = nn.Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True)

        # ์ถœ๋ ฅ ๋ ˆ์ด์–ด(๋‹จ์–ด์‚ฌ์ „์— ์žˆ๋Š” ๋‹จ์–ด๋กœ ์ถœ๋ ฅ๋  ์ˆ˜ ์žˆ๋„๋ก)
        self.fc = nn.Linear(emb_size, tgt_vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        # ์งˆ๋ฌธ, ๋‹ต๋ณ€๋ฌธ ์ž„๋ฒ ๋”ฉ
        # ๋…ผ๋ฌธ์—์„œ๋Š” ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ์˜ ์Šค์ผ€์ผ์„ ์กฐ์ •ํ•˜์—ฌ ์ดˆ๊ธฐํ™” ๊ณผ์ •์—์„œ์˜ ๋ถˆ์•ˆ์ •์„ฑ์„ ์ค„์ด๊ธฐ ์œ„ํ•ด ์ž„๋ฒ ๋”ฉ ํฌ๊ธฐ์˜ ์ œ๊ณฑ๊ทผ์„ ๊ณฑํ•จ
        src, tgt = src.long(), tgt.long()
        src = self.embedding(src) * math.sqrt(self.emb_size)
        tgt = self.embedding(tgt) * math.sqrt(self.emb_size)

        # ํฌ์ง€์…”๋„ ์ธ์ฝ”๋”ฉ ์ ์šฉ
        src_emb = self.positional_encoding(src)
        tgt_emb = self.positional_encoding(tgt)

        # ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ ํ†ต๊ณผ
        # src_emb : ์งˆ๋ฌธ ๋ฐ์ดํ„ฐ, tgt_emb : ๋‹ต๋ณ€๋ฌธ ๋ฐ์ดํ„ฐ, src_mask : ์งˆ๋ฌธ ๋งˆ์Šคํฌ, tgt_mask : ๋‹ต๋ณ€๋ฌธ ๋งˆ์Šคํฌ(์–ดํ…์…˜ ์‹œ ๋ฏธ๋ž˜ ์‹œ์  ๋ชป๋ณด๋„๋ก)
        # src(tgt)_padding_mask : ์งˆ๋ฌธ(๋‹ต๋ณ€๋ฌธ)์ด ์–ดํ…์…˜ ์‹œ ํŒจ๋”ฉํ† ํฐ์€ ์ ์šฉ๋˜์ง€ ์•Š๋„๋ก ํŒจ๋”ฉ ํ† ํฐ์— ๋งˆ์Šคํ‚น-> self-atteniton์— ์ ์šฉ
        # memory_key_padding_mask : ๋””์ฝ”๋”๊ฐ€ ์ธ์ฝ”๋”์˜ ์ถœ๋ ฅ ์ •๋ณด ํ™œ์šฉ ์‹œ ํŒจ๋”ฉ๋œ ๋ถ€๋ถ„ ํ™œ์šฉํ•˜์ง€ ์•Š๋„๋ก ๋งˆ์Šค -> encoder-decoder attention์— ์ ์šฉ
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, # None ๋ถ€๋ถ„์€ memory ๋งˆ์Šคํฌ๋กœ ๋ณดํ†ต None์œผ๋กœ ๋‘ 
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        outs = self.fc(outs)
        return outs # (๋ฐฐ์น˜์‚ฌ์ด์ฆˆ, ์ถœ๋ ฅ์‹œํ€€์Šค ๊ธธ์ด, ๋‹จ์–ด ์‚ฌ์ „ ์ˆ˜)


    # ์˜ˆ์ธก ์‹œ ์ ์šฉํ•  ์ธ์ฝ”๋”ฉ ํ•จ์ˆ˜
    def encode(self, src, src_mask):
        # ์˜ˆ์ธก ์‹œ ์‚ฌ์šฉํ•  ์ธ์ฝ”๋” ํ†ต๊ณผ(์งˆ๋ฌธ -> ์ž„๋ฒ ๋”ฉ -> ํฌ์ง€์…”๋„์ธ์ฝ”๋”ฉ -> ํŠธ๋žœ์Šคํฌ๋จธ ์ธ์ฝ”๋” ํ†ต๊ณผ)
        src = src.long()
        src = self.embedding(src) * math.sqrt(self.emb_size)
        src_emb = self.positional_encoding(src)

        # ์‹ค์ œ ์˜ˆ์ธก ์‹œ ํŒจ๋”ฉ์„ ๋„ฃ์ง€ ์•Š๊ธฐ์— ํŒจ๋”ฉ ๋งˆ์Šคํฌ๋Š” ์ ์šฉ X
        outs = self.transformer.encoder(src_emb, src_mask)
        return outs

    # ์˜ˆ์ธก ์‹œ ์ ์šฉํ•  ๋””์ฝ”๋”ฉ ํ•จ์ˆ˜
    def decode(self, tgt, memory, tgt_mask):
        # ์—์ธก ์‹œ ๋””์ฝ”๋” ํ†ต๊ณผ(๋‹ต๋ณ€๋ฌธ -> ์ž„๋ฒ ๋”ฉ -> ํฌ์ง€์…”๋„์ธ์ฝ”๋”ฉ -> ํŠธ๋žœ์Šคํฌ๋จธ ๋””์ฝ”๋” ํ†ต๊ณผ)
        # ์˜ˆ์ธก ์‹œ tgt๋Š” sos ํ† ํฐ, memory๋Š” ์ธ์ฝ”๋”์˜ ์ถœ๋ ฅ(๋งค ์ธต์—์„œ ์ธ์ฝ”๋” ๋””์ฝ”๋” ์–ดํ…์…˜์— ํ™œ์šฉ)
        tgt = tgt.long()
        tgt = self.embedding(tgt) * math.sqrt(self.emb_size)
        tgt_emb = self.positional_encoding(tgt)

        outs = self.transformer.decoder(tgt_emb, memory, tgt_mask)
        return outs

# ๋ฏธ๋ž˜ ์‹œ์  ๋งˆ์Šคํฌ๋ฅผ ์œ„ํ•œ ํ•จ์ˆ˜ ์ƒ์„ฑ
def generate_square_subsequent_mask(sz):
    # torch.ones((sz, sz) : (sz, sz)ํฌ๊ธฐ์˜ 1๋กœ ์ฑ„์›Œ์ง„ ํ–‰๋ ฌ ์ƒ์„ฑ
    # torch.triu() : ๋Œ€๊ฐ์„  ๊ธฐ์ค€์œผ๋กœ ์•„๋ž˜์˜ ์›์†Œ๋ฅผ 0์œผ๋กœ ๋ณ€๊ฒฝ
    # 1 1 1
    # 0 1 1
    # 0 0 1

    # ์ „์น˜
    # 1 0 0
    # 1 1 0
    # 1 1 1
    mask = torch.triu(torch.ones((sz, sz), device=device)).transpose(0, 1)

    # 0์ธ ๋ถ€๋ถ„์„ ์Œ์˜ ๋ฌดํ•œ ๊ฐ’์œผ๋กœ, 1์ธ ๋ถ€๋ถ„์„ 0.0์œผ๋กœ ๋ณ€๊ฒฝ
    # ํƒ€์ผ“๋ฌธ์˜ ์‹œํ€€์Šค ๊ธธ์ด ๋งŒํผ์˜ ๋งˆ์Šคํฌ ํ–‰๋ ฌ์„ ์ƒ์„ฑํ•˜์—ฌ self-attention ์‹œ ๋ฏธ๋ž˜ ์‹œ์ ์˜ ๊ฐ’๋“ค์€ ์Œ์˜ ๋ฌดํ•œ๋Œ€๋กœ ๊ฐˆ ์ˆ˜ ์žˆ๋„๋ก
    # ์Œ์˜ ๋ฌดํ•œ๋Œ€ ๊ฐ’์€ softmax ํ†ต๊ณผ ์‹œ 0์œผ๋กœ ์ˆ˜๋ ด
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# ์ „์ฒด ๋งˆ์Šคํฌ(ํŒจ๋”ฉ, ๋ฏธ๋ž˜์‹œ์ )๋ฅผ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•œ ํ•จ์ˆ˜
def create_mask(src, tgt):
    # ์งˆ๋ฌธ๊ณผ ๋‹ต๋ณ€๋ฌธ์˜ ์‹œํ€€์Šค ๊ธธ์ด ํ™•์ธ
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    # ๋ฏธ๋ž˜์‹œ์  ๋งˆ์Šคํฌ
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)

    # ์งˆ๋ฌธ์€ ๋งˆ์Šคํ‚น์„ ์ ์šฉํ•˜์ง€ ์•Š๊ธฐ์— False๋กœ ์ด๋ฃจ์–ด์ง„ ํ–‰๋ ฌ ์ƒ์„ฑ
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)

    # ํŒจ๋”ฉ๋งˆ์Šคํฌ ์ƒ์„ฑ(์ž…๋ ฅ ๋ฐ์ดํ„ฐ ํ–‰๋ ฌ์—์„œ ์ž…๋ ฅ ํ† ํฐ์ด 0(<PAD>)์ธ ๋ถ€๋ถ„์„ True๋กœ ๋‚˜๋จธ์ง€๋Š” False์ธ ํ˜•ํƒœ๋กœ ๋ฐ˜ํ™˜)
    src_padding_mask = (src == 0)
    tgt_padding_mask = (tgt == 0)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask






๋Œ“๊ธ€ ์“ฐ๊ธฐ

๋‹ค์Œ ์ด์ „