
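Basic Tensor Operations

Basic operations on a Tensor as an array

Some of us (for example... me) start learning high-level APIs or building neural networks before getting comfortable with basic Tensor operations, and then keep tripping over the dim parameter. These problems are best sorted out while laying the foundations.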

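Reading values

A printed Tensor is displayed as a very Pythonic nested list, but as a C programmer I still prefer to think of it as a two-dimensional array. For example, the x below is a \(3\times 4\) two-dimensional array: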

In [3]: x = torch.arange(12).reshape(3, 4)

In [4]: x
Out[4]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

Taking x[0] indexes the first dimension of the \(3\times 4\) two-dimensional array, which gives a one-dimensional array of length \(4\):

In [5]: x[0]
Out[5]: tensor([0, 1, 2, 3])

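To read a single value, use x[0, 1]. The x[0][1] form is discouraged: it first builds the one-dimensional array x[0] and then indexes into that intermediate result. In some corner cases (for example when indexing with Tensors), an assignment written this way fails to take effect.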

In [6]: x[0][1]  # discouraged form
Out[6]: tensor(1)

In [7]: x[0, 1] 
Out[7]: tensor(1)

Here is one such corner case:

In [7]: a = torch.tensor((1, 2))

In [8]: b = torch.tensor((0, 2, 3))

In [9]: a
Out[9]: tensor([1, 2])

In [10]: b
Out[10]: tensor([0, 2, 3])

In [11]: x[a, b]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[11], line 1
----> 1 x[a, b]

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [2], [3]

In [12]: x.shape
Out[12]: torch.Size([3, 4])

In [13]: a = a[:, None]

In [14]: b = b[None, a]

In [15]: x[a, b]
Out[15]:
tensor([[[ 6],
         [11]]])

In [16]: x[a, b][1] = 99
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[16], line 1
----> 1 x[a, b][1] = 99

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [17]: x[a, b][0][0][1] = 99
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[17], line 1
----> 1 x[a, b][0][0][1] = 99

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [18]: x[a, b][0][0]
Out[18]: tensor([6])

In [19]: x[a, b][0][0][0] = 99

In [20]: x
Out[20]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [21]: x[a, b][0, 0, 0] = 99

In [22]: x
Out[22]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

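In the end, x was never changed: indexing with Tensors returns a copy rather than a view, so every write landed in a temporary. To keep things honest, I have kept all the intermediate output.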
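For contrast, here is a minimal sketch of a write that does reach x: assigning through a single index expression, x[a, b] = 99, instead of indexing the result of x[a, b] a second time. The index tensors are simplified one-dimensional ones chosen for illustration, not the ones from the session above.

```python
import torch

x = torch.arange(12).reshape(3, 4)
a = torch.tensor([1, 2])  # row indices (illustrative values)
b = torch.tensor([2, 3])  # column indices (illustrative values)

# Chained form: x[a, b] is a new tensor, so this write is lost.
x[a, b][0] = 99
after_chained = x.clone()  # still identical to the original x

# Single index expression: the write goes into x itself,
# setting x[1, 2] and x[2, 3].
x[a, b] = 99
```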

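A side note on C-order and F-order, also called row-major and column-major storage. In a C array, elements that differ only in the last index are adjacent in memory; in Fortran, elements that differ only in the first index are adjacent.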
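One way to see the storage order from Python is to look at a tensor's strides. This is a small sketch, assuming the same \(3\times 4\) tensor as above: a freshly created tensor is laid out in C-order, and transposing it gives a view with the strides swapped over the same memory.

```python
import torch

x = torch.arange(12).reshape(3, 4)

# C-order: moving one step along the last dimension moves one element in memory.
print(x.stride())          # (4, 1)
print(x.is_contiguous())   # True

# The transpose is a view over the same memory with swapped strides,
# so its first index is now the one whose elements are adjacent.
xt = x.t()
print(xt.stride())         # (1, 4)
print(xt.is_contiguous())  # False
```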

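In reduce operations such as torch.sum, the dim parameter means "reduce along this dimension". For example, for a four-dimensional NCHW array (n_samples, channels, height, width), torch.mean(x, dim=0) averages over all samples at each (channel, height, width) position, while torch.mean(x, dim=1) averages over the channels:

  • torch.mean(x, dim=0)[0, 1, 2] is the mean over all samples of channel \(0\) at pixel \((1, 2)\),
  • torch.mean(x, dim=1)[0, 1, 2] is the mean over all \(c\) channels of sample \(0\) at pixel \((1, 2)\).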
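The two bullet points can be checked on a small tensor. This is a sketch with arbitrary sizes (n, c, h, w are illustrative), comparing each reduced entry with the mean of the matching slice of x.

```python
import torch

n, c, h, w = 2, 3, 4, 5  # illustrative sizes
x = torch.arange(n * c * h * w, dtype=torch.float32).reshape(n, c, h, w)

mean_over_samples = torch.mean(x, dim=0)   # dim 0 is reduced away: shape (c, h, w)
mean_over_channels = torch.mean(x, dim=1)  # dim 1 is reduced away: shape (n, h, w)

# Channel 0, pixel (1, 2), averaged over all samples.
same_0 = torch.isclose(mean_over_samples[0, 1, 2], x[:, 0, 1, 2].mean())
# Sample 0, pixel (1, 2), averaged over all channels.
same_1 = torch.isclose(mean_over_channels[0, 1, 2], x[0, :, 1, 2].mean())
```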

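Basic arithmetic operations

For a tensor computation library, arithmetic is the most basic skill. PyTorch has a hundred or so arithmetic operators.

Element-wise operators

In PyTorch, element-wise Tensor operations come in functional and in-place variants, distinguished by a trailing _ in the function name. For example, tensor.abs() is a functional operation: it leaves tensor unchanged and reports the change in its return value. tensor.abs_() is an in-place operation: it modifies the original tensor and also returns the modified value.

The operations here are all applied element by element, with no interaction between elements. Almost every API has both a functional and an in-place version, and can be called either procedurally as a function of the torch module or in object-oriented style as a Tensor method; the only real difference between the procedural and object-oriented forms is where the first argument goes. Each operator therefore has at least three spellings.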
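As a small sketch of these three spellings, using abs as the example (variable names are illustrative):

```python
import torch

t = torch.tensor([-1.0, 2.0, -3.0])

r1 = torch.abs(t)  # functional, procedural style: a function in the torch module
r2 = t.abs()       # functional, object-oriented style: a Tensor method
t_after_functional = t.clone()  # t still holds [-1., 2., -3.] at this point

r3 = t.abs_()      # in-place: t itself is overwritten, and the result is returned too
```

After the last line, t holds the absolute values and r3 refers to the same storage as t.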

Operator Description
abs Computes the absolute value of each element in input.
absolute Alias for torch.abs()
acos Computes the inverse cosine (the familiar arccos) of each element in input.
arccos Alias for torch.acos().
acosh inverse hyperbolic cosine of input.
arccosh Alias for torch.acosh().
add Adds other, scaled by alpha, to input.
addcdiv Performs the element-wise division of tensor1 by tensor2, multiply the result by the scalar value and add it to input.
addcmul Performs the element-wise multiplication of tensor1 by tensor2, multiply the result by the scalar value and add it to input.
angle Computes the element-wise angle (in radians) of the given input tensor.
asin Returns a new tensor with the arcsine of the elements of input.
arcsin Alias for torch.asin().
asinh Returns a new tensor with the inverse hyperbolic sine of the elements of input.
arcsinh Alias for torch.asinh().
atan Returns a new tensor with the arctangent of the elements of input.
arctan Alias for torch.atan().
atanh Returns a new tensor with the inverse hyperbolic tangent of the elements of input.
arctanh Alias for torch.atanh().
atan2 Element-wise arctangent of \(\text{input}_i / \text{other}_i\) with consideration of the quadrant.
arctan2 Alias for torch.atan2().
bitwise_not Computes the bitwise NOT of the given input tensor.
bitwise_and Computes the bitwise AND of input and other.
bitwise_or Computes the bitwise OR of input and other.
bitwise_xor Computes the bitwise XOR of input and other.
bitwise_left_shift Computes the left arithmetic shift of input by other bits.
bitwise_right_shift Computes the right arithmetic shift of input by other bits.
ceil Returns a new tensor with the ceil of the elements of input, the smallest integer greater than or equal to each element.
clamp Clamps all elements in input into the range [ min, max ].
clip Alias for torch.clamp().
conj_physical Computes the element-wise conjugate of the given input tensor.
copysign Create a new floating-point tensor with the magnitude of input and the sign of other, elementwise.
cos Returns a new tensor with the cosine of the elements of input.
cosh Returns a new tensor with the hyperbolic cosine of the elements of input.
deg2rad Returns a new tensor with each of the elements of input converted from angles in degrees to radians.
div Divides each element of the input input by the corresponding element of other.
divide Alias for torch.div().
digamma Alias for torch.special.digamma().
erf Alias for torch.special.erf().
erfc Alias for torch.special.erfc().
erfinv Alias for torch.special.erfinv().
exp Returns a new tensor with the exponential of the elements of the input tensor input.
exp2 Alias for torch.special.exp2().
expm1 Alias for torch.special.expm1().
fake_quantize_per_channel_affine Returns a new tensor with the data in input fake quantized per channel using scale, zero_point, quant_min and quant_max, across the channel specified by axis.
fake_quantize_per_tensor_affine Returns a new tensor with the data in input fake quantized using scale, zero_point, quant_min and quant_max.
fix Alias for torch.trunc()
float_power Raises input to the power of exponent, elementwise, in double precision.
floor Returns a new tensor with the floor of the elements of input, the largest integer less than or equal to each element.
floor_divide Computes input divided by other, elementwise, and floors the result.
fmod Applies C++’s std::fmod entrywise.
frac Computes the fractional portion of each element in input.
frexp Decomposes input into mantissa and exponent tensors such that \(\text{input} = \text{mantissa} \times 2^{\text{exponent}}\).
gradient Estimates the gradient of a function \(g : \mathbb{R}^n \rightarrow \mathbb{R}\) in one or more dimensions using the second-order accurate central differences method.
imag Returns a new tensor containing imaginary values of the self tensor.
ldexp Multiplies input by \(2^\text{other}\).
lerp Does a linear interpolation of two tensors start (given by input) and end based on a scalar or tensor weight and returns the resulting out tensor.
lgamma Computes the natural logarithm of the absolute value of the gamma function on input.
log Returns a new tensor with the natural logarithm of the elements of input.
log10 Returns a new tensor with the logarithm to the base 10 of the elements of input.
log1p Returns a new tensor with the natural logarithm of (1 + input).
log2 Returns a new tensor with the logarithm to the base 2 of the elements of input.
logaddexp Logarithm of the sum of exponentiations of the inputs.
logaddexp2 Logarithm of the sum of exponentiations of the inputs in base-2.
logical_and Computes the element-wise logical AND of the given input tensors.
logical_not Computes the element-wise logical NOT of the given input tensor.
logical_or Computes the element-wise logical OR of the given input tensors.
logical_xor Computes the element-wise logical XOR of the given input tensors.
logit Alias for torch.special.logit().
hypot Given the legs of a right triangle, return its hypotenuse.
i0 Alias for torch.special.i0().
igamma Alias for torch.special.gammainc().
igammac Alias for torch.special.gammaincc().
mul Multiplies input by other.
multiply Alias for torch.mul().
mvlgamma Alias for torch.special.multigammaln().
nan_to_num Replaces NaN, positive infinity, and negative infinity values in input with the values specified by nan, posinf, and neginf, respectively.
neg Returns a new tensor with the negative of the elements of input.
negative Alias for torch.neg()
nextafter Return the next floating-point value after input towards other, elementwise.
polygamma Alias for torch.special.polygamma().
positive Returns input.
pow Takes the power of each element in input with exponent and returns a tensor with the result.
quantized_batch_norm Applies batch normalization on a 4D (NCHW) quantized tensor.
quantized_max_pool1d Applies a 1D max pooling over an input quantized tensor composed of several input planes.
quantized_max_pool2d Applies a 2D max pooling over an input quantized tensor composed of several input planes.
rad2deg Returns a new tensor with each of the elements of input converted from angles in radians to degrees.
real Returns a new tensor containing real values of the self tensor.
reciprocal Returns a new tensor with the reciprocal of the elements of input
remainder Computes Python’s modulus operation entrywise.
round Rounds elements of input to the nearest integer.
rsqrt Returns a new tensor with the reciprocal of the square-root of each of the elements of input.
sigmoid Alias for torch.special.expit().
sign Returns a new tensor with the signs of the elements of input.
sgn This function is an extension of torch.sign() to complex tensors.
signbit Tests if each element of input has its sign bit set (is less than zero) or not.
sin Returns a new tensor with the sine of the elements of input.
sinc Alias for torch.special.sinc().
sinh Returns a new tensor with the hyperbolic sine of the elements of input.
sqrt Returns a new tensor with the square-root of the elements of input.
square Returns a new tensor with the square of the elements of input.
sub Subtracts other, scaled by alpha, from input.
subtract Alias for torch.sub().
tan Returns a new tensor with the tangent of the elements of input.
tanh Returns a new tensor with the hyperbolic tangent of the elements of input.
true_divide Alias for torch.div() with rounding_mode=None.
trunc Returns a new tensor with the truncated integer values of the elements of input.
xlogy Alias for torch.special.xlogy().

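Beyond the functions above, PyTorch also overloads many operators for Tensor, all of which follow Python semantics. The PyTorch developers put considerable effort into being Pythonic, which is one reason PyTorch fits so well into the Python ecosystem.

Operator Description
+ add
- sub
* mul
/ div (true division)
// floor_divide
@ matmul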
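A short sketch of how the overloaded operators line up with the named functions. The values are illustrative and kept positive so that the results do not depend on how a particular PyTorch version rounds negative quotients.

```python
import torch

a = torch.tensor([7.0, 9.0])
b = torch.tensor([2.0, 4.0])

true_div = a / b    # same as torch.div(a, b): tensor([3.5000, 2.2500])
floor_div = a // b  # same as torch.floor_divide(a, b): tensor([3., 2.])

m1 = torch.ones(2, 3)
m2 = torch.ones(3, 4)
product = m1 @ m2   # same as torch.matmul(m1, m2), shape (2, 4)
```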

Reduction operators

Reduction is an important concept in functional programming.

Operator Description
argmax Returns the indices of the maximum value of all elements in the input tensor.
argmin Returns the indices of the minimum value(s) of the flattened tensor or along a dimension
amax Returns the maximum value of each slice of the input tensor in the given dimension(s) dim.
amin Returns the minimum value of each slice of the input tensor in the given dimension(s) dim.
aminmax Computes the minimum and maximum values of the input tensor.
all Tests if all elements in input evaluate to True.
any Tests if any element in input evaluates to True.
max Returns the maximum value of all elements in the input tensor.
min Returns the minimum value of all elements in the input tensor.
dist Returns the p-norm of (input - other)
logsumexp Returns the log of summed exponentials of each row of the input tensor in the given dimension dim.
mean Returns the mean value of all elements in the input tensor.
nanmean Computes the mean of all non-NaN elements along the specified dimensions.
median Returns the median of the values in input.
nanmedian Returns the median of the values in input, ignoring NaN values.
mode Returns a namedtuple (values, indices) where values is the mode value of each row of the input tensor in the given dimension dim, i.e. a value which appears most often in that row, and indices is the index location of each mode value found.
norm Returns the matrix norm or vector norm of a given tensor.
nansum Returns the sum of all elements, treating Not a Numbers (NaNs) as zero.
prod Returns the product of all elements in the input tensor.
quantile Computes the q-th quantiles of each row of the input tensor along the dimension dim.
nanquantile This is a variant of torch.quantile() that “ignores” NaN values, computing the quantiles q as if NaN values in input did not exist.
std If unbiased is True, Bessel’s correction will be used.
std_mean If unbiased is True, Bessel’s correction will be used to calculate the standard deviation.
sum Returns the sum of all elements in the input tensor.
unique Returns the unique elements of the input tensor.
unique_consecutive Eliminates all but the first element from every consecutive group of equivalent elements.
var If unbiased is True, Bessel’s correction will be used.
var_mean If unbiased is True, Bessel’s correction will be used to calculate the variance.
count_nonzero Counts the number of non-zero values in the tensor input along the given dim.
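To connect the table with the dim parameter discussed earlier, here is a minimal sketch of a few reductions on a \(2\times 3\) tensor (the values are illustrative). Without dim the whole tensor is reduced to one value; with dim only that dimension is reduced, and keepdim=True keeps it as a dimension of size 1.

```python
import torch

x = torch.tensor([[1.0, 5.0, 2.0],
                  [4.0, 3.0, 6.0]])

total = x.sum()                         # all elements: tensor(21.)
col_sum = x.sum(dim=0)                  # reduce the rows away: tensor([5., 8., 8.])
row_max = x.max(dim=1)                  # namedtuple (values, indices)
row_mean = x.mean(dim=1, keepdim=True)  # shape (2, 1) instead of (2,)
```

Note that max with a dim argument returns both the maximum values and their positions, so row_max.values is tensor([5., 6.]) and row_max.indices is tensor([1, 2]).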

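Comparison operators

allclose This function checks if all input and other satisfy the condition \(\lvert \text{input} - \text{other} \rvert \leq \text{atol} + \text{rtol} \times \lvert \text{other} \rvert\).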
argsort Returns the indices that sort a tensor along a given dimension in ascending order by value.
eq Computes element-wise equality
equal True if two tensors have the same size and elements, False otherwise.
ge Computes \(\text{input} \geq \text{other}\) element-wise.
greater_equal Alias for torch.ge().
gt Computes \(\text{input} > \text{other}\) element-wise.
greater Alias for torch.gt().
isclose Returns a new tensor with boolean elements representing if each element of input is “close” to the corresponding element of other.
isfinite Returns a new tensor with boolean elements representing if each element is finite or not.
isin Tests if each element of elements is in test_elements.
isinf Tests if each element of input is infinite (positive or negative infinity) or not.
isposinf Tests if each element of input is positive infinity or not.
isneginf Tests if each element of input is negative infinity or not.
isnan Returns a new tensor with boolean elements representing if each element of input is NaN or not.
isreal Returns a new tensor with boolean elements representing if each element of input is real-valued or not.
kthvalue Returns a namedtuple (values, indices) where values is the k th smallest element of each row of the input tensor in the given dimension dim.
le Computes \(\text{input} \leq \text{other}\) element-wise.
less_equal Alias for torch.le().
lt Computes \(\text{input} < \text{other}\) element-wise.
less Alias for torch.lt().
maximum Computes the element-wise maximum of input and other.
minimum Computes the element-wise minimum of input and other.
fmax Computes the element-wise maximum of input and other.
fmin Computes the element-wise minimum of input and other.
ne Computes \(\text{input} \neq \text{other}\) element-wise.
not_equal Alias for torch.ne().
sort Sorts the elements of the input tensor along a given dimension in ascending order by value.
topk Returns the k largest elements of the given input tensor along a given dimension.
msort Sorts the elements of the input tensor along its first dimension in ascending order by value.
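Several of these names are easy to mix up, so here is a hedged sketch contrasting eq (element-wise), equal (whole-tensor), and allclose (tolerance-based). The values and the loosened tolerance atol=1e-3 are illustrative choices.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0001])

elementwise = torch.eq(a, b)                   # one bool per element
same = torch.equal(a, b)                       # a single Python bool for the whole tensor
close_default = torch.allclose(a, b)           # default tolerances: rtol=1e-05, atol=1e-08
close_loose = torch.allclose(a, b, atol=1e-3)  # looser absolute tolerance
```

Here the last elements differ by about 1e-4, which is outside the default tolerances but inside the loosened one.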

Spectral operators (Spectral Ops)

stft Short-time Fourier transform (STFT).
istft Inverse short time Fourier Transform.
bartlett_window Bartlett window function.
blackman_window Blackman window function.
hamming_window Hamming window function.
hann_window Hann window function.
kaiser_window Computes the Kaiser window with window length window_length and shape parameter beta.

Miscellaneous math operators

atleast_1d Returns a 1-dimensional view of each input tensor with zero dimensions.
atleast_2d Returns a 2-dimensional view of each input tensor with zero dimensions.
atleast_3d Returns a 3-dimensional view of each input tensor with zero dimensions.
bincount Count the frequency of each value in an array of non-negative ints.
block_diag Create a block diagonal matrix from provided tensors.
broadcast_tensors Broadcasts the given tensors according to Broadcasting semantics.
broadcast_to Broadcasts input to the shape shape.
broadcast_shapes Similar to broadcast_tensors() but for shapes.
bucketize Returns the indices of the buckets to which each value in the input belongs, where the boundaries of the buckets are set by boundaries.
cartesian_prod Do cartesian product of the given sequence of tensors.
cdist Computes batched the p-norm distance between each pair of the two collections of row vectors.
clone Returns a copy of input.
combinations Compute combinations of length \(r\) of the given tensor.
corrcoef Estimates the Pearson product-moment correlation coefficient matrix of the variables given by the input matrix, where rows are the variables and columns are the observations.
cov Estimates the covariance matrix of the variables given by the input matrix, where rows are the variables and columns are the observations.
cross Returns the cross product of vectors in dimension dim of input and other.
cummax Returns a namedtuple (values, indices) where values is the cumulative maximum of elements of input in the dimension dim.
cummin Returns a namedtuple (values, indices) where values is the cumulative minimum of elements of input in the dimension dim.
cumprod Returns the cumulative product of elements of input in the dimension dim.
cumsum Returns the cumulative sum of elements of input in the dimension dim.
diag If input is a vector (1-D tensor), then returns a 2-D square tensor
diag_embed Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input.
diagflat If input is a vector (1-D tensor), then returns a 2-D square tensor
diagonal Returns a partial view of input with its diagonal elements with respect to dim1 and dim2 appended as a dimension at the end of the shape.
diff Computes the n-th forward difference along the given dimension.
einsum Sums the product of the elements of the input operands along dimensions specified using a notation based on the Einstein summation convention.
flatten Flattens input by reshaping it into a one-dimensional tensor.
flip Reverse the order of a n-D tensor along given axis in dims.
fliplr Flip tensor in the left/right direction, returning a new tensor.
flipud Flip tensor in the up/down direction, returning a new tensor.
kron Computes the Kronecker product, denoted by \(\otimes\), of input and other.
rot90 Rotate a n-D tensor by 90 degrees in the plane specified by dims axis.
gcd Computes the element-wise greatest common divisor (GCD) of input and other.
histc Computes the histogram of a tensor.
histogram Computes a histogram of the values in a tensor.
histogramdd Computes a multi-dimensional histogram of the values in a tensor.
meshgrid Creates grids of coordinates specified by the 1D inputs in tensors.
lcm Computes the element-wise least common multiple (LCM) of input and other.
logcumsumexp Returns the logarithm of the cumulative summation of the exponentiation of elements of input in the dimension dim.
ravel Return a contiguous flattened tensor.
renorm Returns a tensor where each sub-tensor of input along dimension dim is normalized such that the p-norm of the sub-tensor is lower than the value maxnorm
repeat_interleave Repeat elements of a tensor.
roll Roll the tensor input along the given dimension(s).
searchsorted Find the indices from the innermost dimension of sorted_sequence such that, if the corresponding values in values were inserted before the indices, when sorted, the order of the corresponding innermost dimension within sorted_sequence would be preserved.
tensordot Returns a contraction of a and b over multiple dimensions.
trace Returns the sum of the elements of the diagonal of the input 2-D matrix.
tril Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
tril_indices Returns the indices of the lower triangular part of a row-by- col matrix in a 2-by-N Tensor, where the first row contains row coordinates of all indices and the second row contains column coordinates.
triu Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
triu_indices Returns the indices of the upper triangular part of a row by col matrix in a 2-by-N Tensor, where the first row contains row coordinates of all indices and the second row contains column coordinates.
vander Generates a Vandermonde matrix.
view_as_real Returns a view of input as a real tensor.
view_as_complex Returns a view of input as a complex tensor.
resolve_conj Returns a new tensor with materialized conjugation if input’s conjugate bit is set to True, else returns input.
resolve_neg Returns a new tensor with materialized negation if input’s negative bit is set to True, else returns input.

Linear algebra operators (BLAS)

addbmm Performs a batch matrix-matrix product of matrices stored in batch1 and batch2, with a reduced add step (all matrix multiplications get accumulated along the first dimension).
addmm Performs a matrix multiplication of the matrices mat1 and mat2.
addmv Performs a matrix-vector product of the matrix mat and the vector vec.
addr Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix input.
baddbmm Performs a batch matrix-matrix product of matrices in batch1 and batch2.
bmm Performs a batch matrix-matrix product of matrices stored in input and mat2.
chain_matmul Returns the matrix product of the \(N\) 2-D tensors.
cholesky Computes the Cholesky decomposition of a symmetric positive-definite matrix \(A\) or for batches of symmetric positive-definite matrices.
cholesky_inverse Computes the inverse of a symmetric positive-definite matrix \(A\) using its Cholesky factor \(u\): returns matrix inv.
cholesky_solve Solves a linear system of equations with a positive semidefinite matrix to be inverted given its Cholesky factor matrix \(u\).
dot Computes the dot product of two 1D tensors.
eig Computes the eigenvalues and eigenvectors of a real square matrix.
geqrf This is a low-level function for calling LAPACK’s geqrf directly.
ger Alias of torch.outer().
inner Computes the dot product for 1D tensors.
inverse Alias for torch.linalg.inv()
det Alias for torch.linalg.det()
logdet Calculates log determinant of a square matrix or batches of square matrices.
slogdet Alias for torch.linalg.slogdet()
lstsq Computes the solution to the least squares and least norm problems for a full rank matrix \(A\) of size \((m \times n)\) and a matrix \(B\) of size \((m \times k)\).
lu Computes the LU factorization of a matrix or batches of matrices A.
lu_solve Returns the LU solve of the linear system \(Ax = b\) using the partially pivoted LU factorization of A from torch.lu().
lu_unpack Unpacks the data and pivots from a LU factorization of a tensor into tensors L and U and a permutation tensor P such that LU_data, LU_pivots = (P @ L @ U).lu().
matmul Matrix product of two tensors.
matrix_power Alias for torch.linalg.matrix_power()
matrix_rank Returns the numerical rank of a 2-D tensor.
matrix_exp Alias for torch.linalg.matrix_exp().
mm Performs a matrix multiplication of the matrices input and mat2.
mv Performs a matrix-vector product of the matrix input and the vector vec.
orgqr Alias for torch.linalg.householder_product().
ormqr Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
outer Outer product of input and vec2.
pinverse Alias for torch.linalg.pinv()
qr Computes the QR decomposition of a matrix or a batch of matrices input, and returns a namedtuple (Q, R) of tensors such that \(\text{input} = Q R\) with \(Q\) being an orthogonal matrix or batch of orthogonal matrices and \(R\) being an upper triangular matrix or batch of upper triangular matrices.
svd Computes the singular value decomposition of either a matrix or batch of matrices input.
svd_lowrank Return the singular value decomposition (U, S, V) of a matrix, batches of matrices, or a sparse matrix \(A\) such that \(A \approx U \operatorname{diag}(S) V^T\).
pca_lowrank Performs linear Principal Component Analysis (PCA) on a low-rank matrix, batches of such matrices, or sparse matrix.
symeig This function returns eigenvalues and eigenvectors of a real symmetric or complex Hermitian matrix input or a batch thereof, represented by a namedtuple (eigenvalues, eigenvectors).
lobpcg Find the k largest (or smallest) eigenvalues and the corresponding eigenvectors of a symmetric positive definite generalized eigenvalue problem using matrix-free LOBPCG methods.
trapz Alias for torch.trapezoid().
trapezoid Computes the trapezoidal rule along dim.
cumulative_trapezoid Cumulatively computes the trapezoidal rule along dim.
triangular_solve Solves a system of equations with a square upper or lower triangular invertible matrix \(A\) and multiple right-hand sides \(b\).
vdot Computes the dot product of two 1D tensors.
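The matrix-product entries differ mainly in the shapes they accept. This is a minimal sketch with illustrative shapes: mm takes two 2-D matrices, bmm takes two 3-D batches of equal batch size, and matmul additionally broadcasts a single matrix across a batch.

```python
import torch

A = torch.arange(6, dtype=torch.float32).reshape(2, 3)
B = torch.ones(3, 4)
batch = torch.ones(5, 2, 3)

p_mm = torch.mm(A, B)                          # (2, 3) x (3, 4) -> (2, 4)
p_bmm = torch.bmm(batch, torch.ones(5, 3, 4))  # (5, 2, 3) x (5, 3, 4) -> (5, 2, 4)
p_matmul = torch.matmul(batch, B)              # B is broadcast over the batch -> (5, 2, 4)
```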

Indexing operations

Random operations

Serialization and deserialization

Miscellaneous

compiled_with_cxx11_abi Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1
result_type Returns the torch.dtype that would result from performing an arithmetic operation on the provided input tensors.
can_cast Determines if a type conversion is allowed under PyTorch casting rules described in the type promotion documentation.
promote_types Returns the torch.dtype with the smallest size and scalar kind that is not smaller nor of lower kind than either type1 or type2.
use_deterministic_algorithms Sets whether PyTorch operations must use “deterministic” algorithms.
are_deterministic_algorithms_enabled Returns True if the global deterministic flag is turned on.
is_deterministic_algorithms_warn_only_enabled Returns True if the global deterministic flag is set to warn only.
set_deterministic_debug_mode Sets the debug mode for deterministic operations.
get_deterministic_debug_mode Returns the current value of the debug mode for deterministic operations.
set_float32_matmul_precision Sets the internal precision of float32 matrix multiplications.
get_float32_matmul_precision Returns the current value of float32 matrix multiplication precision.
set_warn_always When this flag is False (default) then some PyTorch warnings may only appear once per process.
is_warn_always_enabled Returns True if the global warn_always flag is turned on.
_assert A wrapper around Python’s assert which is symbolically traceable.

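Notes

Many PyTorch operators execute asynchronously, especially in the GPU versions: a CUDA operator call typically returns as soon as the kernel is queued, and the program only waits when the result is actually needed.

An operator only defines the entry point; the actual implementation depends on the backend. Common backends include cpu, cuda, mkldnn, and npu.

The point of knowing this is that you don't have to care: write your program however you like, and the PyTorch engine takes care of the rest.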


Last updated: 2023-08-28