TensorFlow XLA初步接触

Tensorflow越来越像一个编译器,把计算图编译为可执行代码。其中关键的部分就是XLA (Accelerated Linear Algebra)。我在实际使用中真切感受到了XLA带来的提升,希望对XLA能有更多的了解,因此花了点时间探索了一下。


XLA框架

关于XLA,Tensorflow给出了比较简略的说明。XLA主要是用来提升计算速度、节省内存(显存)等。XLA的输入语言称为“HLO (High Level Optimizer)”,HLO定义了整个计算图。随后,XLA对HLO进行一些机器无关、高层的优化,然后用LLVM等进行机器相关、底层的优化并生成代码。这个流程如下图所示。

  1. 描绘计算图HLO,这一步可以通过tf2xla、xla_client等完成
  2. 对HLO进行广义优化(机器无关),如CSELoop Fusion等编译常用优化策略
  3. 针对特定设备,对HLO进行优化
  4. LLVM等生成可执行代码
XLA流程,见https://www.tensorflow.org/xla/overview

Operation Semantics

我不是非常理解Operation Semantics是什么意思,有兴趣可以看看。我只知道HLO支持很多操作,其中比较容易接受的是Element-wise unary functions(包括abs、cos等)、Element-wise binary arithmetic operations(如相加、相乘等)……

跟这一部分比较相似的是TensorRT定义的操作,如IActivationLayerIElementWiseLayer等。其实神经网络很简单,靠这25个操作就能定义大部分网络。
有了这些定义之后,我们就可以描述一个网络即计算图。有了精确定义的计算图,就可以对其进行优化。


计算图优化

前面也提过,XLA首先进行计算图优化主要是跟机器无关的、高层的。如CSELoop Fusion等编译常用优化策略。下面出现的XLA计算图和及其优化的中间结果,可以通过设置环境变量来导出,然后转换一下即可。

CSE

我们先看看CSE(Common subexpression elimination)。我也把这出成了Byte Camp的题,由于我难以描述清楚题目,没被采用。其实这题很能打“我是搞深度学习的,为什么让我做这么多编程题”的脸。

假设我们在使用Tensorflow等编写神经网络时,为了使代码逻辑清晰,可能会写出如下运行时低效的计算:

在Tensorflow中可以表示为如下左图,其中 p1 / (p0 + (p3 - p4)) 计算了两次。XLA就能对此进行了优化,只需计算一次,计算流程被优化为如下右图形式。
   
我们可以通过简单的程序来完成这一过程,可以看到真实的Tensorflow代码才200行不到。这也是ICPC的一道题,有兴趣可以尝试一下。

Fusion

Fusion可能带来提升,有可能会降低效率。这跟计算和架构相关。但在神经网络和Nvidia的GPU架构下,很难出现效率降低。看一个简单的例子, np.sin(np.cos(a * b) + c) ,其中 a,b,c 都是矩阵。显然通过fusion,我们可以,

  1. 节省存储,提高cache利用率
  2. 减少kernel数

Fusion后的计算表示为CUDA代码,大概是:

可以看到本来要4个CUDA Kernel要完成的计算,Fusion之后1个就行了。这提升还是非常靠谱的,XLA也对此做了优化,如下图所示。显然可以带来计算效率的提升 (gpu额外开销比较大)。

Fusion后Fusion后


BERT XLA

前面的都是随便写的计算,现在可以看看BERT开启XLA后发生了什么。先放一张图。整个BERT的计算图太大了,放不下。这里是一层Transformer,不带训练的情况,其实也够看了(放大看)。

BERT,一层Transformer的计算流程

从上图,我们可以看到Layer norm、GELU都有很多细碎的操作,这如果没有优化会产生很多额外开销和中间结果,带来的后果就是显存占用高。而XLA将这些细碎操作都Fusion在一起了,形成了一个大的Kernel。开启XLA和FP16之后,训练效率是原来的4倍,直接起飞,可能这个加速比还不是理论极限。


XLA Client

还有一个比较有意思的是XLA Client。这非常硬核,我们可以直接将numpy代码转成在GPU上运行的代码,并且附带计算图优化,完成了下面几个项目大部分功能。

  1. https://github.com/andersbll/cudarray
  2. https://github.com/cupy/cupy
  3. https://devblogs.nvidia.com/numba-python-cuda-acceleration/
  4. https://github.com/dmlc/minpy

具体可以参考,

  1. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/BUILD#L228
  2. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/python/xla.py
  3. https://github.com/google/jax

这里简单使用一下xla client(基于tf 1.13.1),由于这个功能还不稳定,这个代码随时跑不起来。

计算速度如下,

计算图我就不放了,这个页面太大了,已经很卡了。。。有兴趣可以自己输出计算图看看。

 

参考链接

  1. https://www.tensorflow.org/xla/overview
  2. http://pages.di.unipi.it/corradini/Didattica/PR2-B-14/OpSem.pdf
  3. https://www.cs.cmu.edu/~rjsimmon/15411-f15/lec/18-commonsub.pdf