Seq2Seq와 Attention

Attention

  • 입력 토큰의 길이 (seq의 길이)가 길어지면, Context Vector가 장기기억을 감당하는 데 한계가 있다. 그러므로 미리 중요한 정보라는 힌트를 알려준다. (ex: 지금 중요한 내용이야, 또는 지금은 대충 들어 등)

  • 동작 원리
    • 기존 seq2seq:
      • context 벡터 + 이전 토큰의 결과(번역어)
    • Attention:
      • + 이전 토큰의 결과(번역어)와 연관성 있는 입력 토큰 정보를 입력함
    • 연관성?
      • 이전 토크의 결과(번역어)와 모든 입력 토큰 정보의 유사도 계산을 한다.
  • Attention Value
    • Attention(Q, K, V)의 결과
      1. Q: query, t시점(t시퀀스)의 디코더 LTSTM의 은닉값 (벡터 형태의 hidden)
      2. K: key, 모든 시점의 인코더 셀의 은닉값 (벡터 형태의 hidden)
      3. 하나의 Q가 각각의 K와 곱해짐. (Q*K1), (Q*K2), (Q*K3), (Q*K4)
        곱 뿐만 아니라, Q*K 연산 종류는 다양하다.
      4. 3의 결과에 SoftMax 결과 값이 나옴 (각 입력 토큰의 가중치)
      5. V: =K
      6. 3의 결과 (Q*Kn) 에 다시 V(=K)를 곱함. (실제값 * 가중치)
      7. 최종 Value = (실제값1*가중치1) + (실제값2*가중치2) .... 

    • 결국, Q(t시점의 디코더 은닉값)에 Attention Value를 Concatenate(연결)한 뒤에 fc(완전연결계층)의 입력값으로 사용한다.

Attention 실습

Seq2Seq 부분만 정리

import torch
import torch.nn as nn
import torch.optim as optim

embedding_dim = 256
hidden_units = 512

인코더
class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embedding_dim, hidden_units):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(src_vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_units, batch_first=True)

    def forward(self, x):
        x = self.embedding(x)
        # x.shape == (batch_size, seq_len, embedding_dim)
        outputs, (hidden, cell) = self.lstm(x)
        # hidden.shape == (1, batch_size, hidden_units), cell.shape == (1, batch_size, hidden_units)
        return outputs, hidden, cell

디코더 (Attention이 추가됐다.)
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, tar_vocab_size, embedding_dim, hidden_units):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(tar_vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_units, batch_first=True)
        self.fc = nn.Linear(hidden_units + hidden_units, tar_vocab_size)  # 컨텍스트 벡터와 결합된 상태로 출력하기 위해 크기를 조정
        self.softmax = nn.Softmax(dim=-1)  # Attention weights 계산

    def forward(self, x, encoder_outputs, hidden, cell):
        # x.shape: (batch_size, target_seq_len)
        batch_size = x.shape[0]
        seq_len = x.shape[1]

        # 임베딩을 통해 입력 변환
        x = self.embedding(x)  # x.shape: (batch_size, target_seq_len, embedding_dim)

        # 전체 출력 저장할 텐서 초기화
        outputs = torch.zeros(batch_size, seq_len, self.fc.out_features).to(x.device)

        for t in range(seq_len):
            # LSTM 통과 # (배치, 시퀀스길이, 임베딩차원)
            # x_t = x[:, t, :] -> (배치, 1, 임베딩차원)

            x_t = x[:, t, :].unsqueeze(1)  # 현재 타임스텝의 입력 (batch_size, 1, embedding_dim)

            output, (hidden, cell) = self.lstm(x_t, (hidden, cell))  # output.shape: (batch_size, 1, hidden_units)

            # Attention 계산
            # hidden : (층수, batch, hidden_units)
            hidden_current = hidden[-1].unsqueeze(1)  # hidden_current.shape: (batch_size), (1, hidden_units)
            # encoder_outputs.shape  : (배치사이즈) (시퀀스길이, hidden_units)
            encoder_outputs_permuted = encoder_outputs.permute(0, 2, 1)  # (batch_size, hidden_units, src_seq_len)
                                                                         # 행렬을 회전함, 행렬곱을 위해서
            # 어텐션 스코어 계산
            # 배치단위로 (1,hidden) * (hidden, src_seq_len) -> (1, src_seq_len)
            attention_scores = torch.bmm(hidden_current, encoder_outputs_permuted)  # (batch_size, 1, src_seq_len)
                                                                                    # bmm 행렬곱
            # 어텐션 가중치 계산
            attention_weights = self.softmax(attention_scores)  # (batch_size, 1, src_seq_len)

            # 컨텍스트 벡터 계산
            # 배치단위로 (1,src_seq_len) * (src_seq_len, hidden_units)
            context_vector = torch.bmm(attention_weights, encoder_outputs)  # (batch_size, 1, hidden_units)  
                                                                            # bmm 과정에서 행렬곱 결과가 sum 된다.
            context_vector = context_vector.squeeze(1)  # (batch_size, hidden_units)

            # LSTM 출력과 컨텍스트 벡터 결합
            output_combined = torch.cat((output.squeeze(1), context_vector), dim=1)  # (batch_size, hidden_units * 2)

            # FC 레이어를 통해 최종 출력 계산
            output = self.fc(output_combined)  # output.shape: (batch_size, tar_vocab_size)
            outputs[:, t, :] = output

        return outputs, hidden, cell


Seq2Seq : 인코더와 디코더 결합
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg):
        encoder_outputs, hidden, cell = self.encoder(src)
        output, _, _ = self.decoder(trg, encoder_outputs, hidden, cell)
        return output

encoder = Encoder(src_vocab_size, embedding_dim, hidden_units)
decoder = Decoder(tar_vocab_size, embedding_dim, hidden_units)
model = Seq2Seq(encoder, decoder)

loss_function = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters())








댓글 쓰기

다음 이전