nn.gemm
```rust
fn gemm(
    A: Tensor<T>,
    B: Tensor<T>,
    C: Option<Tensor<T>>,
    alpha: Option<T>,
    beta: Option<T>,
    transA: bool,
    transB: bool
) -> Tensor<T>;
```

Performs general matrix multiplication (GEMM), the Level 3 BLAS routine: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
A' = transpose(A) if transA else A
B' = transpose(B) if transB else B
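The conditional transposes above can be sketched in plain Rust. This is only an illustration of the definition, not Orion's Cairo implementation: `Matrix` (a row-major `f64` matrix) and `maybe_transpose` are hypothetical names introduced here.

```rust
/// Row-major dense matrix (hypothetical illustration type, not Orion's `Tensor<T>`).
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f64>,
}

impl Matrix {
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
        Matrix { rows, cols, data }
    }

    /// Element at (i, j) in row-major storage.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        self.data[i * self.cols + j]
    }

    /// Transpose: element (i, j) of the result is element (j, i) of `self`.
    pub fn transpose(&self) -> Matrix {
        let mut data = Vec::with_capacity(self.data.len());
        for i in 0..self.cols {
            for j in 0..self.rows {
                data.push(self.get(j, i));
            }
        }
        Matrix::new(self.cols, self.rows, data)
    }
}

/// A' = transpose(A) if trans else A (the same rule applies to B with transB).
pub fn maybe_transpose(m: &Matrix, trans: bool) -> Matrix {
    if trans { m.transpose() } else { m.clone() }
}
```

For example, with `transA` set, an A stored as (K, M) = (2, 3) becomes A' of shape (3, 2).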
Computes Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is unidirectionally broadcastable to shape (M, N), and output tensor Y has shape (M, N). A is transposed before the computation if attribute transA is true; the same holds for B and transB.
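The formula can be traced with a naive triple-loop sketch in plain Rust. It is not Orion's implementation; `Matrix` and `gemm_sketch` are hypothetical names. Two assumptions are made explicit: a missing `alpha` or `beta` defaults to 1.0 (the default used by the ONNX Gemm operator), and a 1-D C of length N is represented here as a (1, N) matrix, so C may be (M, N), (1, N), (M, 1), or (1, 1).

```rust
/// Row-major dense matrix (hypothetical illustration type, not Orion's `Tensor<T>`).
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f64>,
}

impl Matrix {
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
        Matrix { rows, cols, data }
    }
}

/// Reads logical element (i, j) of `m`, or of its transpose when `trans` is true.
fn at(m: &Matrix, trans: bool, i: usize, j: usize) -> f64 {
    if trans { m.data[j * m.cols + i] } else { m.data[i * m.cols + j] }
}

/// Y = alpha * A' * B' + beta * C (assumed defaults: alpha = beta = 1.0).
pub fn gemm_sketch(
    a: &Matrix,
    b: &Matrix,
    c: Option<&Matrix>,
    alpha: Option<f64>,
    beta: Option<f64>,
    trans_a: bool,
    trans_b: bool,
) -> Matrix {
    let alpha = alpha.unwrap_or(1.0);
    let beta = beta.unwrap_or(1.0);
    // Logical shapes after the optional transposes.
    let (m, k) = if trans_a { (a.cols, a.rows) } else { (a.rows, a.cols) };
    let (k_b, n) = if trans_b { (b.cols, b.rows) } else { (b.rows, b.cols) };
    assert_eq!(k, k_b, "inner dimensions of A' and B' must match");
    if let Some(c) = c {
        // Each dimension of C must be 1 or equal to the matching output dimension.
        assert!(
            (c.rows == 1 || c.rows == m) && (c.cols == 1 || c.cols == n),
            "C is not broadcastable to (M, N)"
        );
    }

    let mut y = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += at(a, trans_a, i, p) * at(b, trans_b, p, j);
            }
            let mut v = alpha * acc;
            if let Some(c) = c {
                // A size-1 dimension of C is repeated along that axis.
                let ci = if c.rows == 1 { 0 } else { i };
                let cj = if c.cols == 1 { 0 } else { j };
                v += beta * c.data[ci * c.cols + cj];
            }
            y[i * n + j] = v;
        }
    }
    Matrix::new(m, n, y)
}
```

Reading elements through `at` applies transA and transB without materializing transposed copies; optimized BLAS kernels use the same idea with blocking and vectorization on top.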
Args
A (Tensor<T>) - Input tensor A. Its shape should be (M, K) if transA is false, or (K, M) if transA is true.
B (Tensor<T>) - Input tensor B. Its shape should be (K, N) if transB is false, or (N, K) if transB is true.
C (Option<Tensor<T>>) - Optional input tensor C. Its shape should be unidirectionally broadcastable to (M, N).
alpha (Option<T>) - Optional scalar multiplier for the product A' * B'.
beta (Option<T>) - Optional scalar multiplier for input tensor C.
transA (bool) - Whether A should be transposed.
transB (bool) - Whether B should be transposed.
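The shape rules in the argument list can be checked before any arithmetic. The plain-Rust sketch below resolves (M, K, N) from the stored shapes of A and B and the transA/transB flags, then validates C under unidirectional broadcasting: C's dimensions are right-aligned with (M, N), and each must be 1 or equal to the matching output dimension. The function name `gemm_shapes` and its signature are hypothetical, not part of Orion.

```rust
/// Resolves (M, K, N) for Y = alpha * A' * B' + beta * C, or returns an error message.
/// Hypothetical helper for illustration only.
pub fn gemm_shapes(
    a_shape: (usize, usize),
    b_shape: (usize, usize),
    c_shape: Option<&[usize]>,
    trans_a: bool,
    trans_b: bool,
) -> Result<(usize, usize, usize), String> {
    // A is (M, K), or (K, M) when trans_a is true.
    let (m, k) = if trans_a { (a_shape.1, a_shape.0) } else { a_shape };
    // B is (K, N), or (N, K) when trans_b is true.
    let (k_b, n) = if trans_b { (b_shape.1, b_shape.0) } else { b_shape };
    if k != k_b {
        return Err(format!("inner dimensions differ: A' has K = {k}, B' has K = {k_b}"));
    }
    if let Some(c) = c_shape {
        if c.len() > 2 {
            return Err(format!("C has rank {}, expected at most 2", c.len()));
        }
        // Unidirectional broadcasting: align C with the trailing dimensions of (M, N).
        let target = [m, n];
        let offset = 2 - c.len();
        for (idx, &d) in c.iter().enumerate() {
            let t = target[offset + idx];
            if d != 1 && d != t {
                return Err(format!("C dimension {d} is not broadcastable to {t}"));
            }
        }
    }
    Ok((m, k, n))
}
```

Only C is broadcast; A' and B' must agree exactly on K, and the returned (M, N) is the shape of the output tensor.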
Returns
A Tensor<T> of shape (M, N).
Examples