# CINN 基础算子体系

## 基础算子定义

**基础算子：**深度学习领域的复杂算子计算过程一般可被拆解为一系列简单的线性代数操作，它们在不同框架中被称为基础算子、元算子或者元张量函数。

**类比：**复杂算子体系可类比为 CISC 架构，基础算子体系可类比为 RISC 架构。使用复杂算子可最小化模型中的算子数量，但每个算子执行时间较长，且编译器等可做的自动优化空间很小；使用基础算子将导致模型中算子数量显著增多，但每个算子的执行时间较短，其为编译器等留有较为灵活的自动优化空间。

## 基础算子选取准则

* **[P0]原子性：**基础算子应该处于计算逻辑拆分的底层，定义为基础算子的操作不可继续拆分。
    - **理由：**如果一个算子可以继续拆分，那么应该将其组件算子定义为基础算子。

* **[P1]可组合性：**基础算子可用作组件算子，即从中可以构建更复杂算子。
    - **理由：**如果一个基础算子不能被其他复杂算子所复用，需要考虑其存在的必要性和单一构建成本。

* **[P2]可优化性：**将复杂算子拆解为基础算子并非最终目的，而是希望通过更细粒度的算子表示获得更多的优化契机。
    - **理由：**以 BatchNorm 为例，其可被拆解为加/减/乘/平方根运算等。拆解后的细粒度操作可以与其他基础算子进行联合代数化简以及算子融合（可以将隶属于不同复杂算子的基础算子融合到同一个操作中）。

* **[P3]频繁依赖性：**若一个简单算子虽然可由其他基础算子组合得到，但是其表示形式出现在各种复杂算子组件中的频次很高，可考虑将该简单算子也视为基础算子。例如：ex–1、log(x+1)等。
    - **理由：**P3 虽然违背了 P0，但是设计基础算子的初衷是为了组合得到复杂算子并且具备可优化能力。将频繁使用到的简单算子设计为基础算子，可先对其进行良好的优化。谨记，P3 仅针对频繁使用的简单算子。

* **[P4]硬件计算库支持度：**Dot、Convolution、BatchNorm 等操作虽然计算复杂，但大多数硬件厂商提供了对应 API 调用。为了方便使用硬件厂商提供的库函数，这类操作也可定义为基础算子。
    - **理由：**对于硬件厂商提供的常见张量操作，将其视为基础算子，则可被直接替换为库函数调用。若其调用性能不高或根本未提供调用，则可在后续的 DecomposerPass 中进一步执行拆解，即延迟拆解时机。
    - **P4 补充说明：**P4 实质上违背了原则 P0-P2，其可视为一种 fallback 机制。最终是选择使用库函数调用还是进一步拆解取决于拆解行为对整个网络模型的计算性能影响。一般来说，Dot 和 Convolution 操作会偏向使用硬件厂商优化好的库函数，而计算逻辑较为简单的 BatchNorm 操作会在 DecomposerPass 中被进一步拆解以挖掘更多的优化契机。

|        操作       |                                         Intel CPU 库函数支持                                        |                                                 Nvidia GPU 库函数支持                                                 |
|:-----------------:|:--------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------:|
|    Convolution    | dnnl::convolution_forward / dnnl::convolution_backward_weights / dnnl::convolution_backward_data 等 | cudnnConvolutionForward / cudnnConvolutionBackwardFilter / cudnnConvolutionBackwardData 等                            |
|        Dot        | cblas_sgemm / cblas_sgemm_batch 等                                                                  | cublasSgemm / cublasSgemmBatched / cublasGemmEx 等                                                                    |
| BatchNormalzation | dnnl::batch_normalization_forward / dnnl::batch_normalization_backward 等                           | cudnnBatchNormalizationForwardInference / cudnnBatchNormalizationForwardTraining / cudnnBatchNormalizationBackward 等 |

## 基础算子分类

参照 TOSA 和 mhlo 中定义的算子集，并结合 CINN 之前的开发经验，选取 114 个基础算子，按照如下维度进行划分：

1. 常见张量算子（12 个）
2. 激活函数（5 个）
3. elementwise 一元算子（22 个）
4. elementwise 二元算子（20 个）
5. elementwise 三元算子（1 个）
6. 比较算子（3 个）
7. reduction 算子（6 个）
8. broadcast 算子（2 个）
9. data layout 变换算子（6 个）
10. 图像操作算子（1 个）
11. 集合[通信]算子（11 个）
12. 数据节点算子（3 个）
13. 类型转换算子（2 个）
14. 量化操作算子（3 个）
15. 随机数生成算子（3 个）
16. 控制流算子（3 个）
17. 动态 shape 算子（8 个）
18. 爱因斯坦求和约定（2 个）
19. 自定义算子（1 个）

### 1. 常见张量算子（12 个）

以下算子可替换为硬件厂商库函数调用：

* `batch_norm_grad`：批正则化操作的反向计算逻辑。
* `batch_norm_inference`：用于前向推理过程中的批正则化计算逻辑。
* `batch_norm_training`：用于训练过程中的批正则化前向计算逻辑。
* `cholesky`：对输入的对称正定矩阵计算 Cholesky 分解。
* `conv`：对给定输入张量进行卷积计算，包括二维卷积和三维卷积。
* `dot`：对两个输入张量进行矩阵乘法操作。
* `fft`：对输入复数张量进行快速傅里叶变换。
* `rfft`：对输入实数张量进行快速傅里叶变换，输出为复数张量。
* `pool`：对输入张量进行池化操作。
* `sort`：对输入张量沿给定轴进行排序，输出排序好的数据，其形状和输入相同。
* `triangular_solve`：计算具有唯一解的线性方程组，其中系数方阵 A 为上（下）三角系数矩阵。若系数方阵 A 不可逆，则线性方程不可解。
* `argmax`：沿给定轴计算输入张量在该轴上的最大元素的索引值张量。

### 2. 激活函数（5 个）

* `sin`：对输入张量逐元素求正弦值。计算公式为$output = \sin(input)$。
* `cos`：对输入张量逐元素求余弦值。计算公式为$output = \cos(input)$。
* `tanh`：对输入张量逐元素求双曲正切值。计算公式为$output = \frac{e^x - e^{-x}}{e^x + e^{-x}}$。
* `atan2`：对输入张量 x、y 逐元素求原点(0,0)至点(x,y)的方位角，即与横坐标轴 x 轴的夹角，并将各个位置的输出元素保存到返回结果中。 计算公式如下：
$$
output = \begin{cases} \arctan(\frac{x}{y}) & y > 0 \\ \arctan(\frac{x}{y}) + \pi & x \ge 0, y < 0 \\ \arctan(\frac{x}{y}) - \pi & x < 0, y < 0 \\ +\frac{\pi}{2} & x > 0, y = 0 \\  - \frac{\pi}{2}& x < 0, y = 0 \\ undefined & x = 0, y = 0  \end{cases}
$$
* `clamp`：对输入张量逐元素进行裁剪，使得输出元素限制在指定的[min_value, max_value]范围内，并将各个位置的输出元素保存到返回结果中。计算公式为$output = \min(\max(input, min\_value), max\_value)$。

### 3. elementwise 一元算子（22 个）

* `abs`：对输入张量逐元素求绝对值，并将各个位置的输出元素保存到返回结果中。计算公式为$output = |input|$。
* `ceil`：对输入张量逐元素向上取整，并将各个位置的输出元素保存到返回结果中。 计算公式为$output = \lceil input \rceil$。
* `floor`：对输入张量逐元素向下取整，并将各个位置的输出元素保存到返回结果中。计算公式为$output = \lfloor input \rfloor$。
* `round`：对输入张量逐元素四舍五入到最接近的整数数值，并将各个位置的输出元素保存到返回结果中。
* `clz`：clz 是 count leading zeros 的缩写，主要用于对输入张量逐元素统计前导零的数目。即在二进制表达下，第一个 1 比特位前的 0 比特位的数目（均包含符号位），并将各个位置的统计数目保存到返回结果中。
* `population_count`：在二进制表达下，对输入张量逐元素统计 1 比特位的数目（包含符号位），并将各个位置的统计数目保存到返回结果中。
* `exp`：对输入张量逐元素求以自然数 e 为底的指数值，并将各个位置的输出元素保存到返回结果中。计算公式为$output = e^{input}$。
* `expm1`：对输入张量逐元素求以自然数 e 为底的指数值并减 1，并将各个位置的输出元素保存到返回结果中。计算公式为$output = e^{input} - 1$。
* `is_finite`：对输入张量逐元素判断是否 Finite，换而言之，是否既非+/-INF，亦非+/-NaN，并将各个位置的判断结果保存到返回结果中。计算公式为$output = (input \ne +\inf) \&\& (input \ne -\inf) \&\& (input \ne \text{nan})$。
* `log`：对输入张量逐元素求自然对数，并将各个位置的输出元素保存到返回结果中。计算公式为$output = \ln(input)$。
* `log1p`：对输入张量逐元素求加一的自然对数，并将各个位置的输出元素保存到返回结果中。计算公式为$output = \ln(input + 1)$。
* `logistic`：对输入张量逐元素求标准逻辑斯谛函数（又称 sigmoid 函数）值，并将各个位置的输出元素保存到返回结果中。计算公式为$output = \frac{1}{1 + e^{-input}}$。
* `sign`：对输入张量逐元素进行正负判断，并将各个位置的正负判断值保存到返回结果中。计算公式如下：
$$
output = \begin{cases}
         -1 & input \lt 0 \\
         -0 & input = -0 \\
         NaN & input = NaN \\
         +0 & input = +0 \\
         1 & input \gt 0
\end{cases}
$$
* `bitwise_not`：对输入张量逐元素按位取反，并将各个位置的输出元素保存到返回结果中。计算公式为$output =  \sim input$。
* `logical_not`：对输入张量逐元素进行逻辑非运算，并将各个位置的输出元素保存到返回结果中。计算公式为$output =  ! input$。
* `negate`：对输入张量逐元素取相反数，并将各个位置的输出元素保存到返回结果中。计算公式为$output =  - input$。
* `reciprocal`：对输入张量逐元素取倒数，并将各个位置的输出元素保存到返回结果中。计算公式为$output = \frac{1}{input}$。
* `imag`：对输入张量逐元素取复数的虚部数值，并将各个位置的输出元素保存到返回结果中。
* `real`：对输入张量逐元素取复数的实部数值，并将各个位置的输出元素保存到返回结果中。
* `rsqrt`：对输入张量逐元素取平方根的倒数，并将各个位置的输出元素保存到返回结果中。计算公式为$output = \frac{1}{\sqrt{input}}$。
* `sqrt`：对输入张量逐元素取平方根，并将各个位置的输出元素保存到返回结果中。计算公式为$output =\sqrt{input}$。
* `cbrt`：对输入张量逐元素取立方根，并将各个位置的输出元素保存到返回结果中。计算公式为$output =\sqrt[3]{input}$。

### 4. elementwise 二元算子（20 个）

* `add`：逐元素相加算子，对两个输入张量逐元素相加，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x + y$。
* `sub`：逐元素相减算子，对两个输入张量逐元素相减，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x - y$。
* `mul`：逐元素相乘算子，对两个输入张量逐元素相乘，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x \odot y$。
* `div`：逐元素相除算子，对两个输入张量逐元素相除，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x / y$。
* `complex`：给定实部输入张量和虚部输入张量，逐元素转换为复数，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x + yi$。
* `mod`：逐元素取模算子，对两个输入张量逐元素取模，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x \% y$。
* `remainder`：逐元素取余算子。不同于取模在取余数过程中尽可能让商向负无穷靠近，取余尽可能让商向零靠近。
* `arithmetic_right_shift`：逐元素算术右移算子，逐元素对第一个输入张量 x 右移 y 个位置，空出的高位用最高位（符号位）填补，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x \gg y$。
* `logical_right_shift`：逐元素逻辑右移算子，逐元素对第一个输入张量 x 右移 y 个位置，空出的高位用 0 填补，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x \ggg y$。
* `logical_left_shift`：逐元素逻辑左移算子，逐元素对第一个输入张量 x 左移 y 个位置，空出的低位用 0 填补，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x \ll y$。
* `bitwise_and`：逐元素按位与算子，对两个输入张量逐元素进行按位与运算，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x \& y$。
* `bitwise_or`：逐元素按位或算子，对两个输入张量逐元素进行按位或运算，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x | y$。
* `bitwise_xor`：逐元素按位异或算子，对两个输入张量逐元素进行按位异或运算，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x \verb!^! y$。
* `logical_and`：逐元素逻辑与算子，对两个输入张量逐元素进行逻辑与运算，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x \&\& y$。
* `logical_or`：逐元素逻辑或算子，对两个输入张量逐元素进行逻辑或运算，并将各个位置的输出元素保存到返回结果中。计算公式为$output = x || y$。
* `logical_xor`：逐元素逻辑异或算子，对两个输入张量逐元素进行逻辑异或运算，并将各个位置的输出元素保存到返回结果中。计算公式为$output = (x || y) \&\& !(x \&\& y)$。
* `maximum`：逐元素对比两个输入张量，并且把各个位置更大的元素保存到返回结果中。计算公式为$output = \max(x,y)$。
* `minimum`：逐元素对比两个输入张量，并且把各个位置更小的元素保存到返回结果中。计算公式为$output = \min(x,y)$。
* `pow`：逐元素计算输入张量 x 的 y 次幂的值，并且把各个位置的输出元素保存到返回结果中。计算公式为$output = x^y$。
* `table`：表查找操作，主要用于实现嵌入层(Embedding Layer)。

### 5. elementwise 三元算子（1 个）

* `select`：逐元素判断`condition`输入张量中的值，并且根据各个位置的判断结果把`true_value`输入张量或`false_value`输入张量中的元素保存到返回结果中。计算公式如下：
$$
output = condition ?  true\_value : false\_value
$$

### 6. 比较算子（3 个）

* `equal`：逐元素判断两个输入 Tensor 是否相等，并且把各个位置的判断结果保存到返回结果中。计算公式为$output = (x == y)$。
* `greater`：逐元素判断第一个输入 Tensor 是否大于第二个输入 Tensor，并且把各个位置的判断结果保存到返回结果中。计算公式为$output = (x > y)$。
* `greater_equal`：逐元素判断第一个输入 Tensor 是否大于等于第二个输入 Tensor，并且把各个位置的判断结果保存到返回结果中。计算公式为$output = (x \ge y)$。

### 7. reduction 算子（6 个）

* `reduce_all`：沿给定的坐标轴对输入张量元素进行逻辑与运算。
* `reduce_any`：沿给定的坐标轴对输入张量元素进行逻辑或运算。
* `reduce_max`：沿给定的坐标轴对输入张量元素取最大值。
* `reduce_min`：沿给定的坐标轴对输入张量元素取最小值。
* `reduce_product`：沿给定的坐标轴对输入张量元素进行规约求乘积操作。
* `reduce_sum`：沿给定的坐标轴对输入张量元素进行规约求和操作。

### 8. broadcast 算子（2 个）

* `broadcast`：将输入张量广播到指定的形状。
* `broadcast_in_dim`：将输入张量广播到指定的形状。相比 broadcast，broadcast_in_dim 通过给定轴映射列表，支持了广播维度不连续情况下的广播。

### 9. datalayout 变换算子（6 个）

* `concat`：沿给定的坐标轴将所有输入张量合并为一个张量。
* `pad`：根据`padding_mode`对输入张量进行扩充，扩充区域填充值`value`。
* `reshape`：在保持输入数据不变的情况下，改变张量的形状。
* `reverse`：沿给定的坐标轴反转输入张量中的数据，数据类型和形状均保持不变。
* `slice`：根据给定的起点（包含）、终点（不包含）和步长值，对输入张量进行切片，并将切片后的张量保存到返回结果中。
* `transpose`：根据给定重排轴列表，对输入张量进行数据重排。

### 10. 图像操作算子（1 个）

* `resize`：将输入图片通过指定插值方法调整为指定大小，输入图片应该是 4-D 张量，且形状为[N, C, H, W]，注意调整仅适用于 H、W 对应维度。

### 11. 集合[通信]算子（11 个）

* `gather`：根据指定的偏移量数组将多个操作数片段进行拼接。
* `scatter`：根据输入的`updates`张量和`indices`数组，对输入张量进行相应切片的更新。
* `select_and_scatter`：`select`和`scatter`的复合操作。即在给定窗口大小中从输入张量中选择目标元素，然后将数值`scatter`到输出张量中。

* `send`：发送指定操作数的值到另一个共享相同通道的`recv`指令。
* `recv`：从另一个共享相同通道的`send`指令中接收指定 shape 的数据。
* `all_gather`：将来自不同分布式节点的数据进行`concat`操作。
* `all_reduce`：将来自不同分布式节点的数据进行`reduce`操作。
* `reduce_scatter`：在执行完`all_reduce`操作后执行`scatter`操作。
* `all_to_all`：将数据从所有分布式节点发送到所有分布式节点。
* `after_all`：接收可变数量的令牌，生成单一令牌。主要用于控制特定操作的执行在一组操作完成之后进行。
* `collective_permute`：在不同分布式节点之间发送和接收数据。

### 12. 数据节点算子（3 个）

* `const`：将输入的常量数据保存到输出张量中。
* `identity`：拷贝输入张量中的所有元素到输出张量中，类似`copy`操作。
* `iota`：创建一个一维数组，其值从 0 开始并以 1 递增。

### 13. 类型转换算子（2 个）

* `cast`：将输入张量中的所有元素转换为指定类型并保存到输出张量中，亦可称为`convert`。
* `bitcast_convert`：在不改变底层存储的情况下，强制转换数据类型。若转换前后数据类型的字节大小不相同，则形状会改变。比如一个 shape 为[10]的 float32 类型数据被强制转换为 float16 类型后，其 shape 应为[10, 2]。

### 14. 量化操作算子（3 个）

* `quantize`：量化操作，即将浮点数量化为定点数，会生成`scale`（比例系数）和`zero_point`（零点偏移值）。
* `dequantize`：反量化操作，即根据`scale`和`zero_point`的值将定点数反量化为浮点数。
* `rescale`：使用给定的`multiplier`（系数）和`shift`（偏移）将量化值类型转换为其他定点数类型。如将`int8_t`量化值转为`int16_t`类型。

### 15. 随机数生成算子（3 个）

* `rng_bit_generator`：使用指定的随机数生成算法生成随机数，用于填充给定 shape 的张量。
* `rng_normal`：生成符合正态分布的随机数。
* `rng_uniform`：生成符合均匀分布的随机数。

### 16. 控制流算子（3 个）

* `if`：`if`条件算子。
* `case`：`switch-case`条件算子。
* `while`：`while`循环算子。

### 17. 动态 shape 算子（8 个）

* `dynamic_conv`：具有动态 shape 语义的卷积操作。
* `dynamic_gather`：具有动态 shape 语义的`gather`操作。
* `dynamic_iota`：具有动态 shape 语义的`iota`操作。
* `dynamic_pad`：具有动态 shape 语义的`padding`操作。
* `dynamic_reshape`：具有动态 shape 语义的`reshape`操作。
* `dynamic_slice`：具有动态 shape 语义的切片操作。
* `dynamic_update_slice`：具有动态 shape 语义的`update_slice`操作，即先进行切片操作在根据`update`输入张量对切片进行更新。
* `dynamic_broadcast_in_dim`：具有动态 shape 语义的`broadcast_in_dim`操作。

> 动态 shape 语义主要是将对应静态 shape 算子中编译时确定的属性输入替换为运行时可变的张量输入。

### 18. 爱因斯坦求和约定（2 个）

* `einsum`：二元爱因斯坦求和约定算子。
* `unary_einsum`：一元爱因斯坦求和约定算子。

### 19. 自定义算子（1 个）

* `custom_call`：`custom_call`主要用于调用设备相关的外部函数。
