CS336_Assignment 1 (basics)
Assignment Overview
course web:https://stanford-cs336.github.io/spring2025/
这个homework有一定的算力要求毕竟要训练model
  CS336: Language Modeling from Scratch,是Stanford开的一门课,目的是带领学生从头构建llm,通过尽量少调用现有库的方式手搓llm。
   事实上由于我的大部分知识学习与工作都是通过即用即学的方式完成的,这种系统的学习确实是我所缺乏的,我将通过五个homework尽量补全llm所缺乏的知识,并希望以此提升我的code skill,方便开发新的work。
   CS336的第一个homework被叫做basic,但是我看到这个homework的overview就觉得不是很basic,事实上这个homework涉及到的theory确实是basic的,但是完成的方式和basic还是有一定区别。
   这个homework要完成以下几件事:
- 实现一个BPE分词器
- 实现transformer的各个组件
- 实现loss function、optimizer(经典的Adam)、scheduler和整个train过程
- 使用不同数据集、调整超参和消融实验
  以上这些任务都强调不使用封装好的torch库真是充满手搓的暴力美学,一看就不是很basic...
Byte-Pair Encoding (BPE) Tokenizer
  这部分就是经典的tokenizer的实现了,就是要实现字符 -> 字节 -> 整数序列的转换,事实上就是要实现从自然语言到机器语言的转换,具体可以参考这两篇经典文章:
Neural Machine Translation of Rare Words with Subword Units
Neural Machine Translation with Byte-Level Subwords
The Unicode Standard
  Unicode标准是规定字符->整数的一种标准。
   在Python中,可以使用ord()函数将单个Unicode字符转换为其整数表示形式。chr()函数将整数Unicode代码点转换为具有相应字符的字符串。
# Unicode example
>>> ord('牛')
29275
>>> chr(29275)
'牛'Answer Problem (unicode1):
# A: What Unicode character does chr(0) return?
# Q: '\x00'
>>> chr(0)
'\x00'
# A: How does this character’s string representation (__repr__()) differ from its printed representa-tion?
# Q: It's a invisible string, so no output.
>>> print(chr(0))
# A: What happens when this character occurs in text? 
# It may be helpful to play around with the following in your Python interpreter and see if it matches your expectations:
# >>> chr(0)
# >>> print(chr(0))
# >>> "this is a test" + chr(0) + "string"
# >>> print("this is a test" + chr(0) + "string")
# Q:
>>> chr(0)
'\x00'
>>> print(chr(0))
>>> "this is a test" + chr(0) + "string"
'this is a test\x00string'
>>> print("this is a test" + chr(0) + "string")
this is a teststringUnicode Encodings
  Unicode似乎已经可以成功帮我们把string转成interger了,似乎已经可以帮我们转成可训练的机器语言了。
   但是Unicode的dictionary本身有两个缺陷:
- dictionary本身很大(150K)
- dictionary会很稀疏(经常使用的词汇命中率会很高)
  我们将使用Unicode编码,它将Unicode字符转换为字节序列。
   Unicode标准本身定义了三种编码:UTF-8、UTF-16和UTF-32,UTF-8是Internet的主要编码。
UTF-8: 变长字符编码,被定义为将码点编码为 1 至 4 个字节,具体取决于码点数值中有效二进制位的数量UTF-16: 变长字符编码, 这种编码方式比较特殊, 它将字符编码成 2 字节 或者 4 字节UTF-32: 固定长度的编码,始终占用 4 个字节,足以容纳所有的 Unicode 字符,所以直接存储 Unicode 码即可,不需要任何编码转换
Unicode字符串编码与反编码,可以通过Python自带的原语实现:
- encode(): Unicode字符串编码为UTF-8
- decode(): UTF-8字节字符串解码为Unicode字符串
>>> test_string = "hello! こんにちは!"
>>> utf8_encoded = test_string.encode("utf-8")
>>> print(utf8_encoded)
b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'
>>> print(type(utf8_encoded))
<class 'bytes'>
>>> list(utf8_encoded)
[104, 101, 108, 108, 111, 33, 32, 227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175, 33]
>>> print(len(test_string))
13
>>> print(len(utf8_encoded))
23
>>> print(utf8_encoded.decode("utf-8"))
hello! こんにちは!通过将我们的Unicode代码点转换为字节序列,我们实际上是在获取代码点序列(0到154,997范围内的整数)并将其转换为字节值序列(0到255范围内的整数)。
Answer Problem (unicode2):
# Q: What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? 
# It may be helpful to compare the output of these encodings for various input strings.
# A: UTF-8 is a dynamic length code, so it just need a byte for ASCII, making it efficient for English.
# Q: Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into a Unicode string. 
# Why is this function incorrect? Provide an example of an input byte string that yields incorrect results.
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])
>>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))
'hello'
# A: The functino just support to encode utf-8 byte by byte, but it will go wrong when we decode muti-byte string such as "你好". 
# Q: Give a two byte sequence that does not decode to any Unicode character(s).
# A: b'\xc3\x41' is a illegal two byte sequence.Subword Tokenization
字节级标记化可以缓解单词级标记器面临的词汇表外问题,但将文本标记化为字节会导致极长的输入序列。这会带来两点坏处:
- 处理这些较长的序列需要模型的每一步都需要更多的计算。
- 较长的输入序列会在数据中产生长期依赖关系。
  为了应对这个问题,我们采用子词标记化的方式,这是单词级标记器和字节级标记器之间的中点。
  子词标记器权衡了更大的词汇表大小,以更好地压缩输入字节序列。如果字节序列the经常出现在我们的原始文本训练数据中,那么在词汇表中为其分配一个条目,会将这个3个标记序列减少为单个标记。
  我们可以通过BPE算法选择这些子词单元来添加到我们的词汇表中,具体来说,就是迭代地用单个新的未使用索引合并出现最频繁的字节对。构建BPE标记器词汇表的过程称为train BPE tokenizer。
BPE Tokenizer Training
BPE的train过程主要包括三个step。
Vocabulary initialization
  tokenizer vocabulary是从字节串标记到整数ID的一对一映射, 由于我们使用byte级别的BPE tokenizer,所以初始vocabulary大小是256(UFT-8的0-255号字节值)
Pre-tokenization
有vocabulary之后,我们就可以计算在文本中字节相邻出现的频率,从出现最频繁的byte pair开始合并,但是这存在两点问题,使我们一般不这么做:
- 每次合并都需要对text全面检查,这会造成巨大的开销
- byte pair可能只会在标点上有区别,却会被标记为不同的byte pair,这会导致很多byte pair其实会有很高的相似性
   为了避免这个问题,我们使用pre-tokenization的方法进行处理。事实上就是对text的粗粒度标记。text可能是出现10次的pre-tokenization,那在train BPE tokenizer 统计t e 相邻出现次数时,直接加10就好了,无需再查看context。
   原始的BPE实现用(s""),即空格实现pre-tokenization的划分,我们可以采用更为经典的基于正则表达式的pre-tokenization(used by GPT-2 from https://github.com/openai/tiktoken/pull/234/files)
# a useful regex
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""上文的regex是一个非常便捷的拆分文本的正则表达式,我们可以通过一个例子进行演示:
>>> import regex as re
>>> PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# use findall
>>> re.findall(PAT, "some text that i'll pre-tokenize")
['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']
# use finditer is recommended, for flexble usage of memory
>>> [match.group(0) for match in re.finditer(PAT, "some text that i'll pre-tokenize")]         
['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']Compute BPE merges
现在我们已经将text转换为pre-tokenization,并将每个pre-tokenization表示为UTF-8字节的序列,我们可以开始计算BPE merge(即train BPE tokenizer)。
   BPE算法迭代计算byte pair并识别频率最高的pair("A"、"B")。然后合并这个最频繁pair("A"、"B")的每次出现,即替换为一个新的标记"AB"。这个新的合并token被添加到我们的vocabulary中;因此,BPE训练后的最终vocabulary是初始vocabulary的大小,加上训练期间执行的BPE合并操作的数量。
注意:在计算合并时,通过优先选择dictionary上更大的pair来确定地打破对频率上的联系。 eg:max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")]) ('BA', 'A')
Special tokens
   一些special token 被加入vocabulary是必要的(例如<|endoftext|>),special token可以用作区分text开始与结尾的label等等。
# bpe example
def pre_tokenize(input_text: str, special_tokens: set):
    split_words = input_text.split(' ')
    pre_tokenization = {}
    
    for word in split_words:
        if word in special_tokens:
            pre_tokenization[(word,)] = pre_tokenization.get((word,), 0) + 1
        else:
            word_as_tuple = tuple(word) 
            pre_tokenization[word_as_tuple] = pre_tokenization.get(word_as_tuple, 0) + 1
            
    return pre_tokenization
def get_pair_frequencies(input : dict):
    output = {}
    for word, count in input.items():
        for i in range(len(word) - 1):
            left_char = word[i]
            right_char = word[i + 1]
            pair = (left_char, right_char)
            output[pair] = output.get(pair, 0) + count
    return output
def find_key(input : dict):
    max_value = max(input.values())
    max_key = [k for k, v in input.items() if v == max_value]
    return max(max_key)
def merge(vocabulary : dict, pair):
    new_vocab = {}
    for word_tuple, count in vocabulary.items():
        new_word_tuple = merge_single_word(word_tuple, pair)
        new_vocab[new_word_tuple] = count
    return new_vocab
def merge_single_word(word_tuple, pair_to_merge):
    new_word = list()
    i = 0
    while i < len(word_tuple):
        if i < len(word_tuple) - 1 and (word_tuple[i], word_tuple[i + 1]) == pair_to_merge:
            new_word.append(word_tuple[i] + word_tuple[i + 1])
            i += 2
        else:
            new_word.append(word_tuple[i])
            i += 1
    return tuple(new_word)
def main():
    special_tokens = {'<|endoftext|>'}
    input_text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest"
    num_merges = 6 # 执行6次合并
    
    # 1. 初始化
    vocab = pre_tokenize(input_text, special_tokens) # 你的 pre_tokenization 字典
    
    # 2. 循环执行合并
    for i in range(num_merges):
        print(f"--- 合并第 {i+1} 次 ---")
        
        # 2a. 统计当前词汇表中的字节对频率
        pair_freqs = get_pair_frequencies(vocab)
        
        # 2b. 找到最佳对
        if not pair_freqs: # 如果没有对可以合并了
            break
        best_pair = find_key(pair_freqs)
        
        # 2c. 执行合并,更新词汇表
        vocab = merge(vocab, best_pair)
        
        print(f"合并了: {best_pair}")
        print("合并后的词汇表:")
        print(vocab)
    print("\n--- 最终结果 ---")
    print(vocab)
if __name__ == '__main__':
    main()Experimenting with BPE Tokenizer Training
Parallelizing pre-tokenization
  在训练前,我们可以发现一个主要的瓶颈是pre-tokenization,大文本量和不确定的分界线造就了这一点。
   幸运的是,课程已经为我们准备好了一个example程序cs336_basics/pretokenization_example.py供我们参考。
Removing special tokens before pre-tokenization
  在使用正则表达式模式(使用re.finditer)运行pre-tokenization之前,应该从text中删除所有特殊标记。
   确保对特殊标记进行拆分,以便它们分隔的文本之间不会发生合并。 函数test_train_bpe_special_tokens将检查这一点。
Optimizing the merging step
  BPE在每次合并时,它都会迭代所有字节对以识别最频繁的pair。然而,每次合并后唯一改变的pair count是那些与合并对重叠的对。
   BPE训练速度可以通过索引所有对的计数并逐步更新这些计数来提高,而不是显式地迭代每对字节来计算对频率。
Problem (train_bpe):
提供一个通过test的example,建议自行完成,尤其需要思考merge的优化和正则匹配的分割:
from .pretokenization_example import find_chunk_boundaries
import regex as re
import os
from typing import Dict, List, Tuple
def pre_tokenize(input_text: str, special_tokens: List[str]) -> Dict[Tuple[bytes, ...], int]:
    word_counts = {}
    
    # 构建一个能分割文本的、包含所有特殊token的正则表达式
    # re.escape确保了特殊token中的元字符被正确处理
    # f"({ ... })" 创建了一个捕获组,这样 re.split 可以在结果中保留分隔符
    special_pattern = "|".join(re.escape(s) for s in special_tokens)
    special_splitter = re.compile(f'({special_pattern})')
    base_pattern = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
    # 1. 按照特殊token分割文本
    text_parts = special_splitter.split(input_text)
    # 2. 遍历分割后的部分
    for i, part in enumerate(text_parts):
        if not part:  # 跳过 re.split 可能产生的空字符串
            continue
            
        if i % 2 == 1:  # 奇数索引是分隔符本身,即特殊token
            token_bytes = part.encode('utf-8')
            word_tuple = (token_bytes,)
            word_counts[word_tuple] = word_counts.get(word_tuple, 0) + 1
        else:  # 偶数索引是普通文本
            # 对普通文本部分应用基础的正则表达式
            for match in base_pattern.finditer(part):
                token_str = match.group(0)
                token_bytes = token_str.encode('utf-8')
                word_tuple = tuple(bytes([b]) for b in token_bytes)
                word_counts[word_tuple] = word_counts.get(word_tuple, 0) + 1
                
    return word_counts
def get_pair_frequencies(input : dict):
    output = {}
    for word, count in input.items():
        for i in range(len(word) - 1):
            left_char = word[i]
            right_char = word[i + 1]
            pair = (left_char, right_char)
            output[pair] = output.get(pair, 0) + count
    return output
def find_key(input : dict):
    max_value = max(input.values())
    max_key = [k for k, v in input.items() if v == max_value]
    return max(max_key)
def merge_single_word(word_tuple, pair_to_merge):
    new_word = list()
    i = 0
    while i < len(word_tuple):
        if i < len(word_tuple) - 1 and (word_tuple[i], word_tuple[i + 1]) == pair_to_merge:
            new_word.append(word_tuple[i] + word_tuple[i + 1])
            i += 2
        else:
            new_word.append(word_tuple[i])
            i += 1
    return tuple(new_word)
def bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str],
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """
    训练一个字节级别的BPE分词器。
    """
    # --- 1. 初始化 ---
    final_vocab: Dict[int, bytes] = {}
    merges: List[Tuple[bytes, bytes]] = []
    
    # 首先添加特殊tokens到词汇表开头
    current_id = 0
    for token_str in special_tokens:
        final_vocab[current_id] = token_str.encode('utf-8')
        current_id += 1
    
    # 然后填充基础词汇表 (0-255字节)
    for i in range(256):
        final_vocab[current_id] = bytes([i])
        current_id += 1
        
    # --- 2. 全局预分词和频率统计 ---
    word_counts: Dict[Tuple[bytes, ...], int] = {}
    
    with open(input_path, "rb") as f:
        boundaries = find_chunk_boundaries(
            f, 32, "<|endoftext|>".encode("utf-8"))
            
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            chunk_text = f.read(end - start).decode("utf-8", errors="ignore")
            
            # 对每个块进行预分词
            chunk_word_counts = pre_tokenize(chunk_text, special_tokens)
            
            # 聚合到全局词频统计中
            for word, count in chunk_word_counts.items():
                word_counts[word] = word_counts.get(word, 0) + count
    # --- 3. BPE 迭代合并 ---
    num_merges_needed = vocab_size - len(final_vocab)
    for i in range(num_merges_needed):
        pair_freqs = get_pair_frequencies(word_counts)
        if not pair_freqs:
            break
        
        best_pair = find_key(pair_freqs)
        merges.append(best_pair)
        
        words_to_update = {}
        # 1. 找出所有包含 best_pair 的词
        for word, count in word_counts.items():
            if best_pair[0] in word and best_pair[1] in word:
                # 这是一个潜在的候选词,但我们还需要检查它们是否相邻
                for j in range(len(word) - 1):
                    if (word[j], word[j+1]) == best_pair:
                        words_to_update[word] = count
                        break # 找到一个匹配就够了,处理下一个词
        
        # 2. 更新这些词
        for word, count in words_to_update.items():
            # 从旧字典中移除
            del word_counts[word]
            # 创建新词并添加到字典中
            new_word = merge_single_word(word, best_pair)
            word_counts[new_word] = word_counts.get(new_word, 0) + count
        new_token_bytes = best_pair[0] + best_pair[1]
        final_vocab[current_id] = new_token_bytes
        current_id += 1
    return final_vocab, merges