논문 리뷰

BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (code review)

해파리냉채무침 2024. 3. 7. 17:38

BERT 코드 찾아보니까 huggingface나 berttokenizer 자체를 임포트 해서 진행하는 예시가 많았다.

파이토치로 바닥부터 구현하고 싶은데 튜토리얼이 따로 없어서 깃허브 보고 코드를 리뷰해보려고 한다.(내가 못찾는거 일수도...)

모델 아키텍처의 구조를 파악하고 공부하는것이 코드리뷰의 목적이기 때문에 모델파일 위주로 리뷰해보려고 한다. 

https://github.com/codertimo/BERT-pytorch 

 

GitHub - codertimo/BERT-pytorch: Google AI 2018 BERT pytorch implementation

Google AI 2018 BERT pytorch implementation. Contribute to codertimo/BERT-pytorch development by creating an account on GitHub.

github.com

위의 깃허브 보고 리뷰하였다.

 

dataset.py

from torch.utils.data import Dataset
import tqdm
import torch
import random


class BERTDataset(Dataset):
    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        self.vocab = vocab #단어
        self.seq_len = seq_len #시퀀스 길이

        self.on_memory = on_memory #메모리 전부 로드할지 여부
        self.corpus_lines = corpus_lines #말뭉치 길이 수
        self.corpus_path = corpus_path #텍스트 파일 경로
        self.encoding = encoding #인코딩만 

        with open(corpus_path, "r", encoding=encoding) as f:# 데이터를 로드하거나 데이터의 총 라인 수를 계산
            if self.corpus_lines is None and not on_memory: 
                for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    self.corpus_lines += 1

            if on_memory: #on_memory가 True이면, 파일의 모든 라인을 메모리에 로드하고, 그렇지 않으면 파일의 총 라인 수만 계산 
                self.lines = [line[:-1].split("\t")
                              for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
                self.corpus_lines = len(self.lines)

        if not on_memory:#데이터를 읽기 위한 파일 객체를 두 개 열어줌 
            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)#랜덤한 위치의 데이터를 읽음

            for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
                self.random_file.__next__()

    def __len__(self):
        return self.corpus_lines #말뭉치 문장수의 갯수 가져옴 

    def __getitem__(self, item):
        t1, t2, is_next_label = self.random_sent(item) #두 개의 문장과 해당 문장들이 이어지는지 여부를 라벨로 가져옴 
        t1_random, t1_label = self.random_word(t1) #각 문장에서 무작위로 단어를 선택하고, 해당 단어의 라벨가져옴 
        t2_random, t2_label = self.random_word(t2)

        # [CLS] tag = SOS tag, [SEP] tag = EOS tag
        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index] # 특수 토큰([CLS], [SEP])을 추가
        t2 = t2_random + [self.vocab.eos_index]

        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index] #문장 라벨에도 패딩을 추가
        t2_label = t2_label + [self.vocab.pad_index]

        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len] #각 세그먼트에 대한 라벨 생성 라벨 1and 2
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)  #합친 후 시퀀스 길이에 맞게 패딩을 추가

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()} # Tensor로 변환하여 딕셔너리 형태로 반환

    def random_word(self, sentence):
        tokens = sentence.split() #입력 문장을 공백을 기준으로 토큰화
        output_label = []

        for i, token in enumerate(tokens): #각 토큰에 대해 15%의 확률로 세 가지 조건 중 하나를 선택
            prob = random.random()
            if prob < 0.15:
                prob /= 0.15

                # 80% randomly change token to mask token  MASK 토큰으로 대체
                if prob < 0.8:
                    tokens[i] = self.vocab.mask_index

                # 10% randomly change token to random token 다른 랜덤 토큰으로 대체
                elif prob < 0.9:
                    tokens[i] = random.randrange(len(self.vocab))

                # 10% randomly change token to current token 해당 토큰을 그대로 둠
                else:
                    tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)

                output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index)) # 토큰이 MASK로 변경되거나 랜덤 토큰으로 변경된 경우, output_label에는 원래 토큰의 인덱스가 저장

            else:
                tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
                output_label.append(0) #토큰이 변경되지 않은 경우, output_label에는 0이 저장

        return tokens, output_label

    def random_sent(self, index):
        t1, t2 = self.get_corpus_line(index) #주어진 인덱스에 해당하는 두 연속된 문장 가져옴

        # output_text, label(isNotNext:0, isNext:1)
        if random.random() > 0.5: #random수가0.5보다 크면  t1과 t2를 그대로 반환하고 (isNext)
            return t1, t2, 1
        else: #0.5보다 작으면 t1과 무작위로 선택된 문장을 반환 (NotNext)
            return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        if self.on_memory: #on_memory가 True면, 메모리에 저장된 lines 리스트에서 해당 인덱스의 문장을 반환
            return self.lines[item][0], self.lines[item][1]
        else: #on_memory가 False라면, 파일에서 다음 줄을 읽
            line = self.file.__next__()
            if line is None: #파일의 끝에 도달하면 
                self.file.close() #파일을 닫음 
                self.file = open(self.corpus_path, "r", encoding=self.encoding) #다시 열고 
                line = self.file.__next__() #다음줄 다시 읽음

            t1, t2 = line[:-1].split("\t") #\t를 시준으로 문자열을 분리하여 t1,t2얻음(마지막 문자 제외 )
            return t1, t2

    def get_random_line(self):
        if self.on_memory:# on_memory가 True라면
            return self.lines[random.randrange(len(self.lines))][1] #lines 리스트에서 무작위로 한 줄을 선택

        line = self.file.__next__() # on_memory가 False,파일에서 다음 줄을 읽음 
        if line is None:
            self.file.close()
            self.file = open(self.corpus_path, "r", encoding=self.encoding)
            for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
                self.random_file.__next__()
            line = self.random_file.__next__()
        return line[:-1].split("\t")[1]

random_sent와 get_corpus_line, get_random_line은 논문 next sentence prediction 작업에 쓰인다. 

random_word 는 Masked LM task에 해당한다. 

 

dataset

vocab.py

수치화하는데 필요한 속성들이 있다.

freqs: 토큰의 빈도수

stoi: string to index

itos: index to string

어휘객체를 만드는 인자

counter: 데이터의 빈도

max_size: 디폴트 값은 None, 단어의 최대 길이

min_freq: 디폴트값은 1, 토큰을 어휘에 포함시키는 데 필요한 최소 빈도

specials: 특수토큰, eos, padding, <unk>

vectors: 사용 가능한 사전 학습된 벡터, 사용자 정의 사전학습 벡터 또는  앞서 언급한 벡터들

unk_init: 디폴트값은 Tensor.zero_,어휘에 없는 단어벡터를 0으로 초기화

vectors_cache: 캐시된 벡터의 디렉토리

import pickle
import tqdm
from collections import Counter


class TorchVocab(object):
    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                 vectors=None, unk_init=None, vectors_cache=None):

        self.freqs = counter #단어 빈도
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)  #특수토큰 리스트에 담고  인덱스에 대응하는 단어를 저장
        # frequencies of special tokens are not counted when building vocabulary
        # in frequency order
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos) 

        # sort by frequency, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) #단어를 빈도수에 따라 정렬 
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) #같은 빈도수의 단어는 알파벳 순으로 정렬

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size: #빈도수가 min_freq보다 작고, 단어장의 수가 max_size와 같은 단어
                break
            self.itos.append(word) #단어장에 추가 

        # stoi is simply a reverse dict for itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None #단어 벡터를 저장할 변수를 초기화
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other): #vocab에 있는 객체가 동일한지 비교 
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self): 
        return len(self.itos)

    def vocab_rerank(self):  #stoi를 itos로 다시 생성 
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False): #주어진 어휘집 v의 단어를 현재 어휘집에 추가
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:#추가된 단어는 현재 어휘집에 없는 단어만 선택
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1


class Vocab(TorchVocab):
    def __init__(self, counter, max_size=None, min_freq=1): #특수토큰 인덱스 정의 
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
                         max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list: 
        pass

    def from_seq(self, seq, join=False, with_pad=False):
        pass

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab': #pickle 파일에서 vocab 객체 불러옴 
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path): #주어진 경로에 vocab 객체 저장 
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)


# Building Vocab with text files
class WordVocab(Vocab):
    def __init__(self, texts, max_size=None, min_freq=1):
        print("Building Vocab")
        counter = Counter() #각 단어의 등장 빈도를 Counter 객체를 이용해서 계산 
        for line in tqdm.tqdm(texts):
            if isinstance(line, list): #각 라인이 리스트인 경우
                words = line #그대로 사용
            else:
                words = line.replace("\n", "").replace("\t", "").split() #줄바꿈과 탭 제거 , 공백 기준으로 단어 분리 

            for word in words:
                counter[word] += 1 #단어 빈도 계
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False): #주어진 문장을 인덱스의 시퀀스로 변환
        if isinstance(sentence, str): 
            sentence = sentence.split()

        seq = [self.stoi.get(word, self.unk_index) for word in sentence] #문장의 각 단어를 해당하는 인덱스로 변환

        if with_eos:
            seq += [self.eos_index]  # this would be index 1 , eos 토큰
        if with_sos:
            seq = [self.sos_index] + seq #sos 토큰 

        origin_seq_len = len(seq) 

        if seq_len is None: #주어진 시퀀스 길이에 맞게 패딩
            pass
        elif len(seq) <= seq_len:
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False): #주어진 인덱스의 시퀀스를 문장으로 변환
        words = [self.itos[idx] #인덱스를 해당하는 단어로 매핑
                 if idx < len(self.itos) #인덱스가 단어 집합의 크기 내에 있는지 확인
                 else "<%d>" % idx #그렇지 않으면 인덱스 번호를 그대로 사용
                 for idx in seq
                 if not with_pad or idx != self.pad_index] #패딩 토큰을 제외

        return " ".join(words) if join else words #join이 True면 단어 리스트를 공백으로 연결하여 하나의 문자열로 반환

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab': #pickle 파일에서 WordVocab객체 불러
        with open(vocab_path, "rb") as f:
            return pickle.load(f)


def build():
    import argparse

    parser = argparse.ArgumentParser() #커맨드라인 인자 파싱 
    parser.add_argument("-c", "--corpus_path", required=True, type=str) #필요한 커맨드라인 인자 추가 
    parser.add_argument("-o", "--output_path", required=True, type=str)
    parser.add_argument("-s", "--vocab_size", type=int, default=None)
    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
    parser.add_argument("-m", "--min_freq", type=int, default=1)
    args = parser.parse_args()

    with open(args.corpus_path, "r", encoding=args.encoding) as f:
        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)

    print("VOCAB SIZE:", len(vocab))
    vocab.save_vocab(args.output_path)

 

TorchVocab 클래스는 텍스트 데이터를 전처리하고, 단어를 모델이 이해할 수 있는 형태인 인덱스로 변환하는데 사용

Wordvocab 클래스는 주어진 텍스트 데이터로부터 단어집합을 생성하는 작업을 수행

 

model

attention

multi_head.py

Multihead attention은 transformer 모델에서 정의했던 모델이다.

import torch.nn as nn
from .single import Attention


class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, h, d_model, dropout=0.1): #h는 attention head의 수 , d_model은 차원수
        super().__init__()
        assert d_model % h == 0 #차원수는 head로 나눠떨어져야 한다(나머지가 0). false이면 error

        # We assume d_v always equals d_k
        self.d_k = d_model // h #차원을 head로 나눈값
        self.h = h

        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) #각각 query, key, value에 대해 transformation을 수행 , 차원수는 d_model로 동일
        self.output_linear = nn.Linear(d_model, d_model) #attention 후의 결과를  d_model 차원으로 매핑
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None): #순전파 
        batch_size = query.size(0) #입력 query의 배치 크기를 가져옴

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 
                             for l, x in zip(self.linear_layers, (query, key, value))]
        # 각각의 query, key, value에 대해 linear transformation을 수행, 결과를 h개의 head로 나눔,각 head의 차원수는 d_k
        # 2) Apply attention on all the projected vectors in batch.
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout) 
        #attention 수행, mask, dropout 적용

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        #attention 결과 concat 최종 선형 적용

assert함수는 True일 경우 아무것도 나타나지 않고 , False일 때, assertionerror가 난다. 

 

single.py

Scaled Dot Product Attention을 구현한 것이다.

import torch.nn as nn
import torch.nn.functional as F
import torch

import math


class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) #query와 key의 행렬 곱을 계산한후, query의 마지막 차원의 크기의 제곱근으로 나눠
                 / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9) #마스크가 있으면 매우큰 음수값(마이너스 무한대)를 넣어 해당 위치가 계산에서 배제되도록

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn

 

embedding

bert.py

bert embedding은 다음과 같이 이루어졌다

1. Token embedding:  임베딩 행렬

2.  positional embedding: sin, cos 를 이용해 위치 정보를 추가함

3.  segment embedding : 문장 segment 정보 추가 (sent_A:1, sent_B:2)

이 임베딩들을 다 더하면 된다.

 

__init__ 파라미터 설명

vocab_size: vocab_size의 총합

embed_size:  토큰 임베딩의 크기

dropout: 드롭아웃 rate

import torch.nn as nn
from .token import TokenEmbedding
from .position import PositionalEmbedding
from .segment import SegmentEmbedding


class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x) #모든 임베딩을 다 더하고 dropout 실행

 

positional embedding

attention is all you need의 positional encoding과 같은것 같다. 

import torch.nn as nn
import torch
import math


class PositionalEmbedding(nn.Module):

    def __init__(self, d_model, max_len=512):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float() #max_len(문장최대길이)과 d_model 차원을 가진 0으로 이루어진 텐서 생성
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1) #차원을 추가하여 position 생성
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        #0부터 d_model까지 2씩 증가하여 수열생성, float 타입 변환, log와 나눗셈, 지수 연산을 적용

        pe[:, 0::2] = torch.sin(position * div_term) #pe 짝수 인덱스열에 대해 sin 함수 적용
        pe[:, 1::2] = torch.cos(position * div_term) #pe 홀수 인덱스열에 대해 cos 함수 적용

        pe = pe.unsqueeze(0) #pe 텐서에 차원추가
        self.register_buffer('pe', pe) #계산한 Positional Encoding 값을 모델의 buffer로 등록

    def forward(self, x):
        return self.pe[:, :x.size(1)]

segment embedding

세그먼트 임베딩은 두 개 이상의 문장을 입력으로 받을 때, 각 문장이 어떤 세그먼트에 속하는지를 구분하기 위해 사용된다.  여기서는 세그먼트의 종류가 3가지로 가정한다. 패딩에 해당하는 세그먼트의 인덱스를 0으로 설정함.

import torch.nn as nn


class SegmentEmbedding(nn.Embedding):
    def __init__(self, embed_size=512):
        super().__init__(3, embed_size, padding_idx=0)

token embedding

import torch.nn as nn


class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512):
        super().__init__(vocab_size, embed_size, padding_idx=0)

utils

bert.py

import torch.nn as nn

from .transformer import TransformerBlock
from .embedding import BERTEmbedding


class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1): #단어 총합의 사이즈, 은닉층 차원, transformer block수, attention head 수, dropout
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # paper noted they used 4*hidden_size for ff_network_hidden_size 피드포워드 신경망 4개의 은닉층
        self.feed_forward_hidden = hidden * 4

        # embedding for BERT, sum of positional, segment, token embeddings  bert 임베딩-positional, segment, token 임베딩 합
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # multi-layers transformer blocks, deep network 
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded token 패딩된 토큰의 attention 마스킹
        # torch.ByteTensor([batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1) #입력 x에서 0보다 큰요소(패딩이 아닌 토큰 위치 ) mask 생성

        # embedding the indexed sequence to sequence of vectors 벡터의 시퀀스에 index 시퀀스 임베딩
        x = self.embedding(x, segment_info)

        # running over multiple transformer blocks 
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)

        return x

language_model.py

BERTLM 클래스는 Bert 기반으로하는 언어 모델을 정의한다. 

이 클래스는 next prediction sentence, masked language model을 포함한다.

next sentence prediction은 isnext, notnext  두개로 classification 한다.

masked language model은 mask된 토큰을 예측한다. 

import torch.nn as nn

from .bert import BERT


class BERTLM(nn.Module):


    def __init__(self, bert: BERT, vocab_size): #훈련되어야하는 bert 모델,masked_lm의 총 vocab_size

        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.hidden)
        self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label) #x와 segment_label을 BERT 모델에 통과시킨 후, 그 출력을 다음 문장 예측 모델과 마스킹된 언어 모델에 각각 통과
        return self.next_sentence(x), self.mask_lm(x)


class NextSentencePrediction(nn.Module): #2개의 classification : isnext, notnext

    def __init__(self, hidden): #hidden: bert 모델의 아웃풋 사이즈
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x[:, 0]))


class MaskedLanguageModel(nn.Module): #mask된 토큰에서 원래의 토큰을 예측
  
    def __init__(self, hidden, vocab_size):
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

transformer.py

양방향 인코더인 bert의 transformerblock은 다음과 같은 구성요소를가지고 있다

MultiHeadedAttention( 여러개의 attention head로 구성)

PositionwiseFeedForward( fully connected feed-forward network)

SublayerConnection( 서브 레이어에 residual connection을 적용하고, 그 결과를 normalization)

 

multihead attention과 pointwise feedforward layer는 아래 attention is all you need의 코드를 참고하면 된다.

sublayer connection은 EncoderLayer이다.

https://coldjellyfish0227.tistory.com/76

 

Attention is all you need (NeurIPS, 2017) code review

https://www.youtube.com/watch?v=AA621UofTUA&t=2706s 동빈나 님의 [딥러닝 기계 번역] Transformer: Attention Is All You Need (꼼꼼한 딥러닝 논문 리뷰와 코드 실습) 을 보고 코드를 리뷰해봤다. 설명을 너무 잘해주신

coldjellyfish0227.tistory.com

import torch.nn as nn

from .attention import MultiHeadedAttention
from .utils import SublayerConnection, PositionwiseFeedForward


class TransformerBlock(nn.Module):


    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):# hidden: transformer의 은닉층 사이즈,multi-head attention의 head 크기, feed_forward_hidden:  4* 은닉층

        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

 

attention에서 한층 더 발전된 모델이다 보니 정의해야할 함수들이 늘어났다. 

유튜브 보고 bert 자체 프레임워크 써서 코드 하나 따라해볼까 싶다,,,, 논문은 이해가는데 class문으로 통으로 구현하라고 하면 어려운것 같다 ,,,, 더 공부해야지