NLPで変わる私の人生

NLPにおけるTransformerの基本:Self-AttentionメカニズムとPythonでの実装

Tags: NLP, Transformer, Self-Attention, Python, 深層学習

はじめに:なぜ今、Transformerを学ぶのか

自然言語処理(NLP)の分野は近年、目覚ましい発展を遂げています。特に、テキストデータに含まれる複雑な文脈や依存関係を捉える能力は飛躍的に向上し、機械翻訳、テキスト要約、質問応答といった様々なタスクで人間レベルの性能に迫る成果が報告されています。この発展を牽引しているのが、2017年にGoogleが発表した「Transformer」モデルです。

これまでの系列データ処理では、RNN(Recurrent Neural Network)やその発展形であるLSTM(Long Short-Term Memory)、GRU(Gated Recurrent Unit)が主流でした。これらのモデルは、時系列データの特性を捉えるのに有効でしたが、長い系列の情報を処理する際に勾配消失問題や計算の並列化が困難であるといった課題を抱えていました。

Transformerは、これらの課題を克服するために「Attentionメカニズム」を全面的に採用し、Recurrent(再帰的)な構造を排除した画期的なアーキテクチャです。本記事では、NLP分野の学習を始めたばかりの皆様に向けて、Transformerの核となるSelf-Attentionメカニズムの概念をわかりやすく解説し、Pythonを用いた具体的な実装例を通じて、その理解を深めていきます。

Transformerの全体像

Transformerモデルは、主にEncoderとDecoderという2つの主要なブロックで構成されています。この構造は、Seq2Seq(Sequence-to-Sequence)モデルに似ていますが、RNNのような再帰的な接続は持ちません。

そして、このEncoderとDecoderの各層で中心的な役割を果たすのが「Self-Attentionメカニズム」です。

Attentionメカニズムの基礎:Self-Attentionとは

Attention(注意)メカニズムは、系列データ処理において、入力系列のどの部分が重要であるか、またはどの部分に注目すべきかを学習する手法です。Transformerでは、このAttentionメカニズムを「Self-Attention」として進化させ、自身の入力系列内でどの単語が他の単語と関連性が高いかを計算します。これにより、離れた単語間の依存関係(長距離依存性)も効率的に捉えることができるようになります。

Self-Attentionの計算には、3つの異なるベクトルが用いられます。

  1. Query (Q) ベクトル: 現在処理している単語が、他の単語を探すための「問い合わせ(Query)」を表します。
  2. Key (K) ベクトル: 他の単語が、現在のQueryに対してどれだけ関連性があるかを示す「鍵(Key)」を表します。
  3. Value (V) ベクトル: 関連性が高いと判断された単語から抽出される「情報(Value)」を表します。

これらのQ, K, Vベクトルは、入力単語の埋め込みベクトル(Word Embedding)から、それぞれ異なる線形変換(重み行列をかける)によって生成されます。

Scaled Dot-Product Attention

Self-Attentionの中核をなす計算は、Scaled Dot-Product Attentionと呼ばれます。これは、以下の数式で表されます。

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

この数式を分解して見ていきましょう。

  1. $QK^T$ (QueryとKeyの内積): 各Queryベクトルが、すべてのKeyベクトルとどれだけ似ているか(関連性が高いか)を計算します。内積が大きいほど、関連性が高いとみなされます。
  2. $\sqrt{d_k}$ でスケーリング: $d_k$ はKeyベクトルの次元数です。内積の値が大きくなりすぎると、softmax関数が勾配消失を起こしやすくなるため、これを緩和するためにスケーリングを行います。
  3. softmax関数: スケーリングされた内積の結果を、合計が1になるような確率分布に変換します。これにより、どの単語にどれくらいの「注意(Attention)」を向けるべきかが数値化されます。
  4. $V$ (Valueとの乗算): 計算されたAttentionスコア(確率分布)を、対応するValueベクトルにかけ合わせます。これにより、関連性の高い単語のValue情報が強く、低い単語のValue情報が弱く反映された、新たな表現ベクトルが生成されます。

PythonによるSelf-Attentionの概念的な実装

ここでは、PyTorchライブラリを用いて、Scaled Dot-Product Attentionの計算プロセスを概念的に見ていきましょう。実際の実装では、効率化のためにバッチ処理や並列計算が考慮されますが、ここでは単一のデータに焦点を当てます。

まず、ダミーのQ, K, Vベクトルを準備します。ここでは、各単語の埋め込み次元を$d_{model}=512$と仮定し、シーケンス長を4とします。Keyの次元数 $d_k$ も埋め込み次元と同じく512とします。

import torch
import torch.nn.functional as F

# ダミーの入力系列(例: 4単語、各単語の埋め込み次元は512)
# Q, K, Vは同じ入力から線形変換で生成されると仮定
# shape: (sequence_length, d_model)
d_model = 512
d_k = 512
sequence_length = 4

# 例として、ランダムなQ, K, Vを生成
# 実際のモデルでは、nn.Linearレイヤーを通じて入力からこれらのベクトルが生成されます。
query = torch.randn(sequence_length, d_model) # (4, 512)
key = torch.randn(sequence_length, d_model)   # (4, 512)
value = torch.randn(sequence_length, d_model) # (4, 512)

print("Query shape:", query.shape)
print("Key shape:", key.shape)
print("Value shape:", value.shape)

# 1. QとKの転置の内積を計算 (Query-Key similarity)
# (sequence_length, d_model) @ (d_model, sequence_length) -> (sequence_length, sequence_length)
scores = torch.matmul(query, key.transpose(-2, -1))
print("\nScores shape (QK^T):", scores.shape)
print("Scores (部分):\n", scores[0, :]) # 各単語が他の単語との関連性を示すスコア

# 2. スケーリング
scaled_scores = scores / (d_k ** 0.5)
print("\nScaled Scores (部分):\n", scaled_scores[0, :])

# 3. Softmaxを適用してAttention Weightsを計算
attention_weights = F.softmax(scaled_scores, dim=-1)
print("\nAttention Weights shape:", attention_weights.shape)
print("Attention Weights (部分):\n", attention_weights[0, :]) # 各行の合計が1になる

# 4. Attention WeightsとValue行列を乗算
# (sequence_length, sequence_length) @ (sequence_length, d_model) -> (sequence_length, d_model)
output = torch.matmul(attention_weights, value)
print("\nOutput shape (Attention(Q,K,V)):", output.shape)
print("Output (部分):\n", output[0, :5]) # 最初の単語のAttention出力の一部

上記のコードでは、各単語が他のすべての単語(自分自身も含む)との関連性を計算し、その関連性に基づいて情報を集約していることがわかります。これがSelf-Attentionの基本的な仕組みです。

Multi-Head Attention

Transformerでは、Attentionメカニズムを一度だけ適用するのではなく、並列に複数回実行します。これを「Multi-Head Attention(マルチヘッド・アテンション)」と呼びます。

各「ヘッド」は、異なる線形変換(異なるQ, K, Vの重み行列)を適用することで、入力情報から異なる種類の関連性やパターンを学習します。例えば、あるヘッドは文法的な関係に注目し、別のヘッドは意味的な関連性に注目するといった具合です。

複数のヘッドから得られたAttentionの出力を結合(concatenate)し、再度線形変換することで、より豊かな表現(Representation)を獲得します。

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads # 各ヘッドの次元
        self.num_heads = num_heads
        self.d_model = d_model

        # Q, K, Vのための線形変換層
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        # 最終的な出力のための線形変換層
        self.fc_out = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        # Q, K, Vを線形変換し、num_headsに分割
        # shape: (batch_size, sequence_length, d_model) -> (batch_size, sequence_length, num_heads, d_k)
        # -> (batch_size, num_heads, sequence_length, d_k)
        query = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attentionを計算
        # (batch_size, num_heads, sequence_length, d_k) @ (batch_size, num_heads, d_k, sequence_length)
        # -> (batch_size, num_heads, sequence_length, sequence_length)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)

        # Attention WeightsとValueの乗算
        # (batch_size, num_heads, sequence_length, sequence_length) @ (batch_size, num_heads, sequence_length, d_k)
        # -> (batch_size, num_heads, sequence_length, d_k)
        output = torch.matmul(attention_weights, value)

        # 各ヘッドの出力を結合し、最終的な線形変換
        # shape: (batch_size, sequence_length, num_heads, d_k)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.fc_out(output)
        return output

# Multi-Head Attentionの利用例
d_model = 512
num_heads = 8
batch_size = 2
sequence_length = 10

# ダミーの入力テンソル
# (batch_size, sequence_length, d_model)
input_tensor = torch.randn(batch_size, sequence_length, d_model)

mha = MultiHeadAttention(d_model, num_heads)
mha_output = mha(input_tensor, input_tensor, input_tensor)

print("\nMulti-Head Attention Output shape:", mha_output.shape)

このMultiHeadAttentionの実装では、query, key, valueが同じinput_tensorから来ているため、Self-Attentionとして機能します。

Positional Encoding:単語の位置情報の付加

Transformerは再帰的な構造を持たないため、系列内の単語の順序に関する情報を直接は認識できません。この問題を解決するために、「Positional Encoding(位置エンコーディング)」という手法が用いられます。

Positional Encodingは、単語の埋め込みベクトルに、その単語の絶対的な位置や相対的な位置を示すベクトルを加算するものです。これにより、各単語は自身のセマンティックな意味だけでなく、系列内での位置情報も含むようになります。Transformerの論文では、サイン関数とコサイン関数を用いたPositional Encodingが提案されており、学習することなく位置情報を表現できるという特徴があります。

import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0) # バッチ次元を追加
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x.shape: (batch_size, sequence_length, d_model)
        # pe.shape: (1, max_len, d_model)
        # xのシーケンス長に合わせてpeを切り出す
        x = x + self.pe[:, :x.size(1)]
        return x

# Positional Encodingの利用例
d_model = 512
sequence_length = 10
batch_size = 2

# ダミーの単語埋め込み入力
word_embeddings = torch.randn(batch_size, sequence_length, d_model)

pos_encoder = PositionalEncoding(d_model)
output_with_pos = pos_encoder(word_embeddings)

print("\nWord Embeddings shape:", word_embeddings.shape)
print("Output with Positional Encoding shape:", output_with_pos.shape)
print("最初の単語の埋め込みと位置エンコーディングの合計 (一部):\n", output_with_pos[0, 0, :5])

このPositional Encodingは、単語埋め込み層の出力に加算され、TransformerのEncoderやDecoderの入力として使用されます。

まとめと次のステップ

本記事では、Transformerモデルの中核をなすSelf-Attentionメカニズムについて、その概念とScaled Dot-Product Attention、Multi-Head Attention、そして位置情報を付与するPositional Encodingの基本的な仕組みを、Pythonによる実装例とともに解説しました。

これらの要素が組み合わさることで、Transformerは多様なNLPタスクで優れた性能を発揮します。本記事で得た基礎知識を基に、より複雑なTransformerアーキテクチャ(Encoder-Decoder全体)や、BERT、GPTといったTransformerベースの事前学習モデルについて学習を進めることで、最先端のNLP研究や開発に一歩踏み出すことができるでしょう。

参考文献