linear_classifier.predict
/// Returns `(labels, scores)`: one predicted class label per row of `X`, and the per-class scores for each row.
fn predict(classifier: LinearClassifier<T>, X: Tensor<T>) -> (Span<usize>, Tensor<T>);
Linear Classifier. Performs linear classification of each row of the input tensor X.
Args
classifier
: LinearClassifier<T> - A LinearClassifier object.
X
: Tensor<T> - Input 2D tensor; each row is one sample to classify.
Returns
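A tuple (labels, scores): labels is a Span<usize> holding the predicted class label for each row of X, and scores is a Tensor<T> holding the per-class scores for each row of X, after the classifier's post_transform has been applied.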
Type Constraints
The element type T of LinearClassifier and X must be a fixed-point type (e.g. FP16x16, as in the example below).
Examples
use orion::numbers::FP16x16;
use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor};
use orion::operators::ml::linear::linear_classifier::{
LinearClassifierTrait, POST_TRANSFORM, LinearClassifier
};
/// Builds a 3-class `LinearClassifier` over 2 features, together with a 3 x 2 input tensor,
/// for use by the prediction examples.
///
/// `post_transform` is passed straight through to the classifier.
fn linear_classifier_helper(
    post_transform: POST_TRANSFORM
) -> (LinearClassifier<FP16x16>, Tensor<FP16x16>) {
    // Class labels 0, 1 and 2.
    let classlabels: Span<usize> = array![0, 1, 2].span();
    let classlabels = Option::Some(classlabels);

    // Six weights. FP16x16 value = mag / 65536, so 38011 ~ 0.58, 19005 ~ 0.29, 5898 ~ 0.09.
    // NOTE(review): this example's expected output matches a row-major
    // (n_classes x n_features) = (3 x 2) layout; confirm against the implementation.
    let coefficients: Span<FP16x16> = array![
        FP16x16 { mag: 38011, sign: true },
        FP16x16 { mag: 19005, sign: true },
        FP16x16 { mag: 5898, sign: true },
        FP16x16 { mag: 38011, sign: false },
        FP16x16 { mag: 19005, sign: false },
        FP16x16 { mag: 5898, sign: false },
    ]
        .span();

    // One intercept per class: ~2.7, ~-2.7 and 0.5.
    let intercepts: Span<FP16x16> = array![
        FP16x16 { mag: 176947, sign: false },
        FP16x16 { mag: 176947, sign: true },
        FP16x16 { mag: 32768, sign: false },
    ]
        .span();
    let intercepts = Option::Some(intercepts);

    let multi_class: usize = 0;

    // The classifier is only moved out to the caller, so it does not need to be `mut`.
    let classifier: LinearClassifier<FP16x16> = LinearClassifier {
        classlabels,
        coefficients,
        intercepts,
        multi_class,
        post_transform
    };

    // Shape [3, 2]: rows [0, 1], [2, 3], [4, 5] (65536 == 1.0 in FP16x16).
    let X: Tensor<FP16x16> = TensorTrait::new(
        array![3, 2].span(),
        array![
            FP16x16 { mag: 0, sign: false },
            FP16x16 { mag: 65536, sign: false },
            FP16x16 { mag: 131072, sign: false },
            FP16x16 { mag: 196608, sign: false },
            FP16x16 { mag: 262144, sign: false },
            FP16x16 { mag: 327680, sign: false },
        ]
            .span()
    );

    (classifier, X)
}
/// Runs the example classifier with a SOFTMAX post-transform and returns `(labels, scores)`:
/// one predicted label per input row, plus that row's per-class scores.
fn linear_classifier_multi_softmax() -> (Span<usize>, Tensor<FP16x16>) {
    // `predict` takes the classifier by value, so neither binding needs `mut`.
    let (classifier, X) = linear_classifier_helper(POST_TRANSFORM::SOFTMAX);
    let (labels, scores) = LinearClassifierTrait::predict(classifier, X);
    (labels, scores)
}
>>>
([0, 2, 2],
[
[0.852656, 0.009192, 0.138152],
[0.318722, 0.05216, 0.629118],
[0.036323, 0.090237, 0.87344]
])
Last updated