expand_dims作用:给定张量“ input”,此操作将在“ input”形状的尺寸索引“ axis”处插入尺寸为1的尺寸。 尺寸索引“轴”从零开始; 如果为“ axis”指定负数,则从末尾开始算起。
如果要将批次尺寸添加到单个元素,此操作很有用。 例如,如果您有一个形状为[[height,width,channels]`的图像,则可以将其与具有`expand_dims(image,0)`的1张图像一起批处理,这将使形状为[[1,height ,width,channels]。
# ‘t‘ is a tensor of shape [2]
tf.shape(tf.expand_dims(t, 0)) # [1, 2]
tf.shape(tf.expand_dims(t, 1)) # [2, 1]
tf.shape(tf.expand_dims(t, -1)) # [2, 1]
# ‘t2‘ is a tensor of shape [2, 3, 5]
tf.shape(tf.expand_dims(t2, 0)) # [1, 2, 3, 5]
tf.shape(tf.expand_dims(t2, 2)) # [2, 3, 1, 5]
tf.shape(tf.expand_dims(t2, 3)) # [2, 3, 5, 1]
```
This operation requires that:
`-1-input.dims() <= dim <= input.dims()`
This operation is related to `squeeze()`, which removes dimensions of
size 1.
Args:
input: A `Tensor`.
axis: 0-D (scalar). Specifies the dimension index at which to
expand the shape of `input`. Must be in the range
`[-rank(input) - 1, rank(input)]`.
name: The name of the output `Tensor`.
dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.
Returns:
A `Tensor` with the same data as `input`, but its shape has an additional
dimension of size 1 added.
Raises:
ValueError: if both `dim` and `axis` are specified.
bert中源码:
# 该函数默认输入的形状为【batch_size, seq_length, input_num】 # 如果输入为2D的【batch_size, seq_length】,则扩展到【batch_size, seq_length, 1】 if input_ids.shape.ndims == 2: input_ids = tf.expand_dims(input_ids, axis=[-1])
原文:https://www.cnblogs.com/nxf-rabbit75/p/12095669.html