nn.gemm
Performs general matrix multiplication (GEMM): https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
A' = transpose(A) if transA else A
B' = transpose(B) if transB else B
Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N). A will be transposed before doing the computation if attribute transA is true, same for B and transB.
Args
A (Tensor<T>) - Input tensor A. The shape of A should be (M, K) if transA is false, or (K, M) if transA is true.
B (Tensor<T>) - Input tensor B. The shape of B should be (K, N) if transB is false, or (N, K) if transB is true.
C (Option<Tensor<T>>) - Optional input tensor C. The shape of C should be unidirectionally broadcastable to (M, N).
alpha (Option<T>) - Optional scalar multiplier for the product of input tensors A * B.
beta (Option<T>) - Optional scalar multiplier for input tensor C.
transA (bool) - Whether A should be transposed.
transB (bool) - Whether B should be transposed.
Returns
A Tensor<T> of shape (M, N).
Examples
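The computation described above can be sketched in NumPy. This is an illustration of the arithmetic only, not the library's own API: the function name gemm_reference is hypothetical, and treating an omitted alpha or beta as 1 is an assumption borrowed from the ONNX Gemm operator rather than something stated on this page.

```python
import numpy as np

def gemm_reference(A, B, C=None, alpha=None, beta=None,
                   transA=False, transB=False):
    """Hypothetical reference for Y = alpha * A' * B' + beta * C."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)

    # Assumption: an omitted alpha or beta behaves like 1.
    alpha = 1.0 if alpha is None else alpha
    beta = 1.0 if beta is None else beta

    # A' and B' are the (optionally) transposed operands.
    A_p = A.T if transA else A   # shape (M, K)
    B_p = B.T if transB else B   # shape (K, N)
    if A_p.shape[1] != B_p.shape[0]:
        raise ValueError("inner dimensions K of A' and B' must match")

    Y = alpha * (A_p @ B_p)      # shape (M, N)

    if C is not None:
        # Unidirectional broadcast: C is expanded to Y's shape (M, N),
        # never the other way around.
        C_b = np.broadcast_to(np.asarray(C, dtype=float), Y.shape)
        Y = Y + beta * C_b
    return Y

# A is (M, K) = (2, 3); B is stored as (N, K) = (2, 3), so transB is set.
A = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
B = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 1.0]])
C = np.array([10.0, 20.0])       # shape (N,), broadcast across both rows

Y = gemm_reference(A, B, C, alpha=0.5, beta=2.0, transB=True)
print(Y)
# [[22.  42.5]
#  [25.  45.5]]
```

Here A' * B' is [[4, 5], [10, 11]]; scaling by alpha = 0.5 and adding beta * C = [20, 40] to every row gives the (M, N) = (2, 2) result shown.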