提速的关键:
1. 批量处理
Transformers 实现了一种批处理算法,其中单个音频样本被分成 30 秒的片段,然后分批转录这些块。这种批处理算法比 OpenAI(按顺序转录块)提供高达 7 倍的增益
2. JAX优于PyTorch
JAX 是一个用于高性能机器学习研究的自动微分库,通过即时 (JIT) 编译 Whisper,比PyTorch在 GPU 上获得了 2 倍的速度提升
3. TPUs 优于 GPUs
张量处理单元 (TPU) 是由 Google 设计的 ML 加速器, TPU 专为矩阵乘法而构建,与更通用的 GPU 相比具有显着优势。在 TPU v4-8 上运行 Whisper JAX 比在 NVIDIA A100 上快 5 倍!
全部加在一起:批处理 7 倍 JAX 2 倍 TPU 5 倍速度增益 => 整体速度提升 70 倍
paper | demo | repo