标签 | 数据集 | 指标 | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
|
该论文介绍了 LongRoPE,一种将大型语言模型 (LLM) 的上下文窗口扩展到超过 200 万个标记的方法。
关键思想是:
识别并利用位置嵌入中的两种形式的不均匀性,以最大程度地减少插值期间的信息丢失。这无需微调即可实现 8 倍上下文扩展。
使用高效的渐进式扩展策略,通过 256k 微调达到 2048k 上下文,而不是直接微调极大的上下文。
调整较短上下文的嵌入以恢复原始窗口大小内的性能。
该方法适用于 LLaMA2 和 Mistral。跨各种任务的实验证明了 LongRoPE 在保持 4k 到 2048k 上下文长度的性能方面的有效性。
Transformer 架构面临着自注意力的二次计算复杂性以及缺乏对训练时未见的 token 位置的泛化的困扰。为了将自注意力计算扩展到大的上下文中,人们提出了各种方法,例如 RoPE、AliBi、注意力池等。尽管如此,这些解决方案都不能有效地扩展到具有数百万个令牌的上下文,同时保持模型的准确性。
本文提出了一种新技术 LongRoPE,将 LLM 的上下文窗口扩展到超过 200 万个令牌。
LongRoPE 利用渐进式扩展策略来获得 2048k 上下文窗口,而无需对极长的文本进行直接微调,而极长的文本既罕见又难以获得。该策略首先对预训练的 LLM 进行 256k 扩展,然后在此长度上进行微调。
为了解决原始(较短)上下文窗口中潜在的性能下降问题,LongRoPE 进一步调整扩展 LLM 上的 RoPE 重新缩放因子,使用其搜索算法在 256k 微调 LLM 上缩小到 4k 和 8k 上下文窗口,以最大限度地减少位置插值。在对长度低于 8k 的序列进行推理期间,RoPE 会使用这些精心搜索的重缩放因子进行更新。
对各种法学硕士和需要长上下文的任务进行的测试验证了 LongRoPE 的功效。该方法在从 4k 到 2048k 令牌的评估长度上显着保持了较低的复杂性,在密钥检索中实现了 90% 以上的准确度,并提供了与 4096 个上下文窗口内的标准基准相当的准确度
深入研究结构修改及其对模型性能的影响。
LongRoPE模型架构旨在将大型语言模型(LLM)的上下文窗口扩展到超过200万个令牌,解决传统Transformer架构的局限性。关键创新在于渐进式扩展策略和位置嵌入的调整。
关键组件包括:
class RoPEPositionalEncoding ( nn . Module ):
def __init__ ( self , d_model , max_len = 1000000 , base = 10000 ):
super (). __init__ ()
self . d_model = d_model
self . max_len = max_len
self . base = base
self . theta = torch . tensor ([ base ** ( - 2 * ( i // 2 ) / d_model ) for i in range ( d_model )])
def forward ( self , positions ):
angles = positions . unsqueeze ( - 1 ) * self . theta
sin_cos = torch . stack ([ angles . cos (), angles . sin ()], dim = - 1 )
return sin_cos . view ( * sin_cos . shape [: - 2 ], - 1 )
def non_uniform_interpolation ( pos_embed , extension_ratio , lambda_factors , n_hat ):
d_model = pos_embed . shape [ - 1 ]
interpolated_pos = pos_embed . clone ()
for i in range ( d_model // 2 ):
mask = torch . arange ( pos_embed . shape [ - 2 ], device = pos_embed . device ) < n_hat
scale = torch . where ( mask , torch . ones_like ( pos_embed [..., 0 ], device = pos_embed . device ),
1 / ( lambda_factors [ i ] * extension_ratio ))
interpolated_pos [..., 2 * i ] *= scale
interpolated_pos [..., 2 * i + 1 ] *= scale
return interpolated_pos
def progressive_extension ( model , data , base_length , target_length , population_size , num_mutations , num_crossovers , max_iterations ):
# Extend to 128k
lambda_factors_128k , n_hat_128k = search_lambda_factors ( model , data , 128000 / base_length , population_size , num_mutations , num_crossovers , max_iterations )
model = fine_tune ( model , data , 128000 , lambda_factors_128k , n_hat_128k , steps = 400 )
# Extend to 256k
lambda_factors_256k , n_hat_256k = search_lambda_factors ( model , data , 256000 / base_length , population_size , num_mutations , num_crossovers , max_iterations )
model = fine_tune ( model , data , 256000 , lambda_factors_256k , n_hat_256k , steps = 600 )
# Extend to target length
if target_length > 256000 :
final_lambda_factors , final_n_hat = search_lambda_factors ( model , data , target_length / base_length , population_size // 2 , num_mutations // 2 , num_crossovers // 2 , max_iterations // 2 )
model . lambda_factors [ "2048k" ] = final_lambda_factors
model . n_hat [ "2048k" ] = final_n_hat
return model , final_lambda_factors , final_n_hat , lambda_factors_256k , n_hat_256k
该架构从预先训练的法学硕士开始,并逐步扩展其上下文窗口。最初,模型经过微调以处理 256k 令牌的上下文长度。这种渐进方法避免了对极长文本进行直接微调的需要,这种文本很少见,而且处理起来计算成本很高。通过逐渐增加上下文长度,模型可以更有效地适应更长的序列。
为了在不同的上下文长度下保持性能,LongRoPE 调整了旋转位置嵌入 (RoPE)。该模型识别并利用位置嵌入中的不均匀性来最大程度地减少插值期间的信息丢失。这允许 8 倍上下文扩展,而无需进行微调。此外,该模型采用搜索算法在 256k 微调的 LLM 上找到较短上下文(例如 4k 和 8k 令牌)的最佳重新调整因子。这些调整确保模型即使在原始上下文窗口大小内也能保持高性能。
该架构结合了多项结构修改,以有效处理增加的上下文长度:
层缩放:对层的缩放进行调整,以确保上下文窗口增长时的稳定性和性能。
内存管理:采用高效的内存管理技术来处理大的上下文大小,而不会压垮系统资源。
注意力机制:集成了增强的注意力机制,以确保模型能够专注于输入序列的相关部分,即使有扩展的上下文。
Token-wise Attention :引入了 Token-wise 注意力机制来捕获 token 之间的上下文关系,使模型能够更好地理解输入的语义。
实验表明,LongRoPE 在 4k 到 2048k 标记的评估长度上保持了较低的复杂度,并在需要长上下文的任务中实现了高精度。这使得它适用于各种应用,包括上下文学习、长文档摘要和小样本学习。
欲了解更多详细信息,请参阅此处的全文。
深入了解支持 LongRoPE 功能的编码和操作细节。这可能包括说明关键组件的片段或伪代码。
欲了解更多详细信息,请参阅论文。
全面的示例演示了如何利用 LongRoPE 进行各种应用,从文本分析到生成大量文档。
# Example usage
data_path = "path/to/your/dataset"
d_model = 512
n_heads = 8
num_layers = 6
base_length = 4096
target_length = 2048 * 1024
data = load_data ( data_path )
model = LongRoPEModel ( d_model , n_heads , num_layers , base_length )
model = model . extend_context ( data , target_length )
input_ids = torch . randn ( 2 , target_length , d_model )
output = model ( input_ids )
print ( output . shape ) # Expected shape: (batch_size, target_length, d_model)
自定义数据集训练
要训练自定义数据集:
超参数调整 LongRoPE 的性能可能对超参数很敏感。要调整的关键参数包括:
lambda 因子搜索中的population_size
、 num_mutations
和num_crossovers
用于微调gradient_accumulation_steps
以实现训练稳定性的学习率和调度程序参数
我的 LongRoPE 实施实现了以下结果:
困惑:
密钥检索准确性:
准确性:
与基线模型的比较:
@article { ding2024longrope ,
title = { LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens } ,
author = { Ding, Yiran and Zhang, Li Lyna and Zhang, Chengruidong and Xu, Yuanyuan and Shang, Ning and Xu, Jiahang and Yang, Fan and Yang, Mao } ,
journal = { arXiv preprint arXiv:2402.13753 } ,
year = { 2024 }
}
注意:此存储库正在开发中,尚未准备好用于生产使用。请参阅论文了解更多详细信息。