Attention
- 입력 토큰의 길이 (seq의 길이)가 길어지면, Context Vector가 장기기억을 감당하는 데 한계가 있다. 그러므로 미리 중요한 정보라는 힌트를 알려준다. (ex: 지금 중요한 내용이야, 또는 지금은 대충 들어 등)
- 동작 원리
- 기존 seq2seq:
- context 벡터 + 이전 토큰의 결과(번역어)
- Attention:
- + 이전 토큰의 결과(번역어)와 연관성 있는 입력 토큰 정보를 입력함
- 연관성?
- 이전 토크의 결과(번역어)와 모든 입력 토큰 정보의 유사도 계산을 한다.
- Attention Value
- Attention(Q, K, V)의 결과
- Q: query, t시점(t시퀀스)의 디코더 LTSTM의 은닉값 (벡터 형태의 hidden)
- K: key, 모든 시점의 인코더 셀의 은닉값 (벡터 형태의 hidden)
- 하나의 Q가 각각의 K와 곱해짐. (Q*K1), (Q*K2), (Q*K3), (Q*K4)
곱 뿐만 아니라, Q*K 연산 종류는 다양하다. - 3의 결과에 SoftMax 결과 값이 나옴 (각 입력 토큰의 가중치)
- V: =K
- 3의 결과 (Q*Kn) 에 다시 V(=K)를 곱함. (실제값 * 가중치)
- 최종 Value = (실제값1*가중치1) + (실제값2*가중치2) ....
- 결국, Q(t시점의 디코더 은닉값)에 Attention Value를 Concatenate(연결)한 뒤에 fc(완전연결계층)의 입력값으로 사용한다.
Attention 실습
Seq2Seq 부분만 정리
import torch
import torch.nn as nn
import torch.optim as optim
embedding_dim = 256
hidden_units = 512
인코더
class Encoder(nn.Module):
def __init__(self, src_vocab_size, embedding_dim, hidden_units):
super(Encoder, self).__init__()
self.embedding = nn.Embedding(src_vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_units, batch_first=True)
def forward(self, x):
x = self.embedding(x)
# x.shape == (batch_size, seq_len, embedding_dim)
outputs, (hidden, cell) = self.lstm(x)
# hidden.shape == (1, batch_size, hidden_units), cell.shape == (1, batch_size, hidden_units)
return outputs, hidden, cell
디코더 (Attention이 추가됐다.)
import torch
import torch.nn as nn
class Decoder(nn.Module):
def __init__(self, tar_vocab_size, embedding_dim, hidden_units):
super(Decoder, self).__init__()
self.embedding = nn.Embedding(tar_vocab_size, embedding_dim, padding_idx=0)
self.lstm = nn.LSTM(embedding_dim, hidden_units, batch_first=True)
self.fc = nn.Linear(hidden_units + hidden_units, tar_vocab_size) # 컨텍스트 벡터와 결합된 상태로 출력하기 위해 크기를 조정
self.softmax = nn.Softmax(dim=-1) # Attention weights 계산
def forward(self, x, encoder_outputs, hidden, cell):
# x.shape: (batch_size, target_seq_len)
batch_size = x.shape[0]
seq_len = x.shape[1]
# 임베딩을 통해 입력 변환
x = self.embedding(x) # x.shape: (batch_size, target_seq_len, embedding_dim)
# 전체 출력 저장할 텐서 초기화
outputs = torch.zeros(batch_size, seq_len, self.fc.out_features).to(x.device)
for t in range(seq_len):
# LSTM 통과 # (배치, 시퀀스길이, 임베딩차원)
# x_t = x[:, t, :] -> (배치, 1, 임베딩차원)
x_t = x[:, t, :].unsqueeze(1) # 현재 타임스텝의 입력 (batch_size, 1, embedding_dim)
output, (hidden, cell) = self.lstm(x_t, (hidden, cell)) # output.shape: (batch_size, 1, hidden_units)
# Attention 계산
# hidden : (층수, batch, hidden_units)
hidden_current = hidden[-1].unsqueeze(1) # hidden_current.shape: (batch_size), (1, hidden_units)
# encoder_outputs.shape : (배치사이즈) (시퀀스길이, hidden_units)
encoder_outputs_permuted = encoder_outputs.permute(0, 2, 1) # (batch_size, hidden_units, src_seq_len)
# 행렬을 회전함, 행렬곱을 위해서
# 어텐션 스코어 계산
# 배치단위로 (1,hidden) * (hidden, src_seq_len) -> (1, src_seq_len)
attention_scores = torch.bmm(hidden_current, encoder_outputs_permuted) # (batch_size, 1, src_seq_len)
# bmm 행렬곱
# 어텐션 가중치 계산
attention_weights = self.softmax(attention_scores) # (batch_size, 1, src_seq_len)
# 컨텍스트 벡터 계산
# 배치단위로 (1,src_seq_len) * (src_seq_len, hidden_units)
context_vector = torch.bmm(attention_weights, encoder_outputs) # (batch_size, 1, hidden_units)
# bmm 과정에서 행렬곱 결과가 sum 된다.
context_vector = context_vector.squeeze(1) # (batch_size, hidden_units)
# LSTM 출력과 컨텍스트 벡터 결합
output_combined = torch.cat((output.squeeze(1), context_vector), dim=1) # (batch_size, hidden_units * 2)
# FC 레이어를 통해 최종 출력 계산
output = self.fc(output_combined) # output.shape: (batch_size, tar_vocab_size)
outputs[:, t, :] = output
return outputs, hidden, cell
Seq2Seq : 인코더와 디코더 결합
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder):
super(Seq2Seq, self).__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, src, trg):
encoder_outputs, hidden, cell = self.encoder(src)
output, _, _ = self.decoder(trg, encoder_outputs, hidden, cell)
return output
encoder = Encoder(src_vocab_size, embedding_dim, hidden_units)
decoder = Decoder(tar_vocab_size, embedding_dim, hidden_units)
model = Seq2Seq(encoder, decoder)
loss_function = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters())
Tags:
AI개발_교육