兄弟们,最近搞Llama2-13B部署,单卡A100总显存吃满还报OOM?别急着加卡,这几个骚操作实测能省40%-60%显存。
第一招:量化剪枝。FP16转INT8/INT4,权重直接砍半。用bitsandbytes或GPTQ,模型精度掉不到1个点,输出质量基本不变。但注意量化后要校准数据集,否则某些任务崩得妈都不认。
第二招:梯度重计算。训练时把中间激活值扔掉,反向传播再重新算。显存占用从O(n)降到O(1),但训练时间增加15%-20%。适合单卡穷玩党,土豪无视。
第三招:FlashAttention。这玩意儿把注意力矩阵分块计算,省掉O(n²)内存。HuggingFace 4.36以上直接支持,效果立竿见影。配合vLLM搞推理,吞吐能翻倍。
以上三板斧全上,13B模型能在24G卡上跑出16K上下文。别问我怎么知道的,刚踩完坑。
提问:你们遇到的最大显存瓶颈是哪个环节?是层数太深、序列太长,还是优化器状态?评论区聊聊。 |