LogoLogo
Actions SDKGiza CLIDatasetsAgents
main
main
  • 👋Welcome
    • Orion
    • Why Validity ML?
  • 🧱Framework
    • Get Started
    • Contribute
    • Compatibility
    • Numbers
      • Fixed Point
        • fp.new
        • fp.new_unscaled
        • fp.from_felt
        • fp.abs
        • fp.ceil
        • fp.floor
        • fp.exp
        • fp.exp2
        • fp.log
        • fp.log2
        • fp.log10
        • fp.pow
        • fp.round
        • fp.sqrt
        • fp.sin
        • fp.atan
        • fp.sign
      • Complex Number
        • complex.acos
        • complex.acosh
        • complex.arg
        • complex.asin
        • complex.asinh
        • complex.atan
        • complex.atanh
        • complex.conjugate
        • complex.cos
        • complex.cosh
        • complex.exp
        • complex.exp2
        • complex.from_polar
        • complex.img
        • complex.ln
        • complex.log2
        • complex.log10
        • complex.mag
        • complex.new
        • complex.one
        • complex.pow
        • complex.real
        • complex.reciprocal
        • complex.sin
        • complex.sinh
        • complex.sqrt
        • complex.tan
        • complex.tanh
        • complex.to_polar
        • complex.zero
    • Operators
      • Tensor
        • tensor.new
        • tensor.at
        • tensor.min_in_tensor
        • tensor.min
        • tensor.max_in_tensor
        • tensor.max
        • tensor.stride
        • tensor.ravel_index
        • tensor.unravel_index
        • tensor.reshape
        • tensor.transpose
        • tensor.reduce_sum
        • tensor.argmax
        • tensor.argmin
        • tensor.matmul
        • tensor.exp
        • tensor.log
        • tensor.equal
        • tensor.greater
        • tensor.greater_equal
        • tensor.less
        • tensor.less_equal
        • tensor.abs
        • tensor.neg
        • tensor.ceil
        • tensor.cumsum
        • tensor.sin
        • tensor.cos
        • tensor.asin
        • tensor.flatten
        • tensor.sinh
        • tensor.asinh
        • tensor.cosh
        • tensor.acosh
        • tensor.tanh
        • tensor.atan
        • tensor.acos
        • tensor.sqrt
        • tensor.or
        • tensor.xor
        • tensor.onehot
        • tensor.slice
        • tensor.concat
        • tensor.gather
        • tensor.quantize_linear
        • tensor.dequantize_linear
        • tensor.qlinear_add
        • tensor.qlinear_mul
        • tensor.qlinear_matmul
        • tensor.qlinear_concat
        • tensor.qlinear_leakyrelu
        • tensor.nonzero
        • tensor.squeeze
        • tensor.unsqueeze
        • tensor.sign
        • tensor.clip
        • tensor.identity
        • tensor.and
        • tensor.where
        • tensor.bitwise_and
        • tensor.bitwise_xor
        • tensor.bitwise_or
        • tensor.resize
        • tensor.round
        • tensor.scatter
        • tensor.array_feature_extractor
        • tensor.binarizer
        • tensor.reduce_sum_square
        • tensor.reduce_l2
        • tensor.reduce_l1
        • tensor.reduce_prod
        • tensor.gather_elements
        • tensor.gather_nd
        • tensor.reduce_min
        • tensor.shrink
        • tensor.reduce_mean
        • tensor.pow
        • tensor.is_nan
        • tensor.is_inf
        • tensor.not
        • tensor.erf
        • tensor.reduce_log_sum
        • tensor.reduce_log_sum_exp
        • tensor.unique
        • tensor.compress
        • tensor.layer_normalization
        • tensor.scatter_nd
        • tensor.dynamic_quantize_linear
        • tensor.optional
        • tensor.reverse_sequence
        • tensor.split_to_sequence
        • tensor.range
        • tensor.hann_window
        • tensor.hamming_window
        • tensor.blackman_window
        • tensor.random_uniform_like
        • tensor.label_encoder
      • Neural Network
        • nn.relu
        • nn.leaky_relu
        • nn.sigmoid
        • nn.softmax
        • nn.softmax_zero
        • nn.logsoftmax
        • nn.softsign
        • nn.softplus
        • nn.linear
        • nn.hard_sigmoid
        • nn.thresholded_relu
        • nn.gemm
        • nn.grid_sample
        • nn.col2im
        • nn.conv_transpose
        • nn.conv
        • nn.depth_to_space
        • nn.space_to_depth
      • Machine Learning
        • Tree Ensemble Classifier
          • tree_ensemble_classifier.predict
        • Tree Ensemble Regressor
          • tree_ensemble_regressor.predict
        • Linear Classifier
          • linear_classifier.predict
        • Linear Regressor
          • linear_regressor.predict
        • SVM Regressor
          • svm_regressor.predict
        • SVM Classifier
          • svm_classifier.predict
        • Sequence
          • sequence.sequence_construct
          • sequence.sequence_empty
          • sequence.sequence_length
          • sequence.sequence_at
          • sequence.sequence_empty
          • sequence.sequence_erase
          • sequence.sequence_insert
          • sequence.concat_from_sequence
        • Normalizer
          • normalize.predict
  • 🏛️Hub
    • Models
    • Spaces
  • 🧑‍🎓Academy
    • Tutorials
      • MNIST Classification with Orion
      • Implement new operators in Orion
      • Verifiable Linear Regression Model
      • Verifiable Support Vector Machine
      • Verifiable Principal Components Analysis
      • Provable MLR: Forecasting AAVE's Lifetime Repayments
Powered by GitBook
On this page
  • Args
  • Returns
  • Type Constraints
  • Examples

Was this helpful?

Edit on GitHub
  1. Framework
  2. Operators
  3. Machine Learning
  4. Tree Ensemble Classifier

tree_ensemble_classifier.predict

   fn predict(classifier: TreeEnsembleClassifier<T>, X: Tensor<T>) -> (Span<usize>, MutMatrix::<T>);

Tree Ensemble classifier. Returns the top class for each of N inputs.

Args

  • classifier: TreeEnsembleClassifier<T> - The TreeEnsembleClassifier to evaluate.

  • X: Tensor<T> - Input 2D tensor with one row per input point and one column per feature.

Returns

  • Span<usize> - The top class label for each of the N input points.

  • MutMatrix<T> - The class scores, with one row per input point and one column per class.

Type Constraints

The element type T of TreeEnsembleClassifier<T> and X must be a fixed-point type (e.g. FP16x16).

Examples

use orion::numbers::FP16x16;
use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor};
use orion::operators::ml::{NODE_MODES, TreeEnsembleAttributes, TreeEnsemble};
use orion::operators::ml::{
    TreeEnsembleClassifier, POST_TRANSFORM, TreeEnsembleClassifierTrait
};
use orion::operators::matrix::{MutMatrix, MutMatrixImpl};

/// Builds a small two-tree `TreeEnsembleClassifier<FP16x16>` and a 3x3 input tensor.
///
/// FP16x16 values below are `mag / 65536`, negated when `sign` is true.
///
/// Ensemble layout (every `nodes_*` span is indexed by "slot", 0-9):
/// * Tree 0 uses slots 0-4 (node ids 0-4). Node 0 tests feature 1 <= 81892; node 1 tests
///   feature 0 <= -19992; nodes 2, 3 and 4 are leaves.
/// * Tree 1 uses slots 5-9 (node ids 0-4). Node 0 tests feature 1 <= -110300; node 2 tests
///   feature 0 <= -44245; nodes 1, 3 and 4 are leaves.
/// * Each leaf carries one weight per class label (0, 1, 2). Entry `i` of `class_weights`
///   belongs to tree `class_treeids[i]`, node `class_nodeids[i]`, class `class_ids[i]`.
///
/// # Arguments
/// * `post_transform` - Transform applied to the raw class scores (e.g. `POST_TRANSFORM::SOFTMAX`).
///
/// # Returns
/// The classifier and an input tensor `X` of shape [3, 3]. Each row of `X` is one input
/// point. Only features 0 and 1 are referenced by the trees' branch nodes.
fn tree_ensemble_classifier_helper(
    post_transform: POST_TRANSFORM
) -> (TreeEnsembleClassifier<FP16x16>, Tensor<FP16x16>) {
    // Leaf -> class weight table: three consecutive entries (classes 0, 1, 2) per leaf.
    let class_ids: Span<usize> = array![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
        .span();

    let class_nodeids: Span<usize> = array![2, 2, 2, 3, 3, 3, 4, 4, 4, 1, 1, 1, 3, 3, 3, 4, 4, 4]
        .span();

    let class_treeids: Span<usize> = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        .span();

    // e.g. 30583 / 65536 ~= 0.4667, 32768 / 65536 = 0.5.
    let class_weights: Span<FP16x16> = array![
        FP16x16 { mag: 30583, sign: false },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 2185, sign: false },
        FP16x16 { mag: 13107, sign: false },
        FP16x16 { mag: 15729, sign: false },
        FP16x16 { mag: 3932, sign: false },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 32768, sign: false },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 32768, sign: false },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 29491, sign: false },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 3277, sign: false },
        FP16x16 { mag: 6746, sign: false },
        FP16x16 { mag: 12529, sign: false },
        FP16x16 { mag: 13493, sign: false },
    ]
        .span();

    let classlabels: Span<usize> = array![0, 1, 2].span();

    // Node id (within the same tree) of the child followed when a branch test fails; 0 for leaves.
    let nodes_falsenodeids: Span<usize> = array![4, 3, 0, 0, 0, 2, 0, 4, 0, 0].span();

    // Feature column tested by each branch node (leaf entries are unused placeholders).
    let nodes_featureids: Span<usize> = array![1, 0, 0, 0, 0, 1, 0, 0, 0, 0].span();

    // All 0 in this fixture.
    let nodes_missing_value_tracks_true: Span<usize> = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span();

    let nodes_modes: Span<NODE_MODES> = array![
        NODE_MODES::BRANCH_LEQ,
        NODE_MODES::BRANCH_LEQ,
        NODE_MODES::LEAF,
        NODE_MODES::LEAF,
        NODE_MODES::LEAF,
        NODE_MODES::BRANCH_LEQ,
        NODE_MODES::LEAF,
        NODE_MODES::BRANCH_LEQ,
        NODE_MODES::LEAF,
        NODE_MODES::LEAF,
    ]
        .span();

    let nodes_nodeids: Span<usize> = array![0, 1, 2, 3, 4, 0, 1, 2, 3, 4].span();

    let nodes_treeids: Span<usize> = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1].span();

    // Node id (within the same tree) of the child followed when a branch test passes; 0 for leaves.
    let nodes_truenodeids: Span<usize> = array![1, 2, 0, 0, 0, 1, 0, 3, 0, 0].span();

    // Branch thresholds (81892 ~= 1.25, -19992 ~= -0.305, -110300 ~= -1.683, -44245 ~= -0.675);
    // leaves hold 0.
    let nodes_values: Span<FP16x16> = array![
        FP16x16 { mag: 81892, sign: false },
        FP16x16 { mag: 19992, sign: true },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 110300, sign: true },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 44245, sign: true },
        FP16x16 { mag: 0, sign: false },
        FP16x16 { mag: 0, sign: false },
    ]
        .span();

    let tree_ids: Span<usize> = array![0, 1].span();

    // Tree id -> slot of that tree's root node in the `nodes_*` spans.
    let mut root_index: Felt252Dict<usize> = Default::default();
    root_index.insert(0, 0);
    root_index.insert(1, 5);

    // Precomputed key -> slot in the `nodes_*` spans (slots 0-9, one per node).
    // NOTE(review): the keys are presumably a hash of each (tree_id, node_id) pair; confirm
    // against how `TreeEnsemble` derives its lookup keys before editing them.
    let mut node_index: Felt252Dict<usize> = Default::default();
    node_index
        .insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0);
    node_index
        .insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1);
    node_index
        .insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2);
    node_index
        .insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3);
    node_index
        .insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4);
    node_index
        .insert(1089549915800264549621536909767699778745926517555586332772759280702396009108, 5);
    node_index
        .insert(1321142004022994845681377299801403567378503530250467610343381590909832171180, 6);
    node_index
        .insert(2592987851775965742543459319508348457290966253241455514226127639100457844774, 7);
    node_index
        .insert(2492755623019086109032247218615964389726368532160653497039005814484393419348, 8);
    node_index
        .insert(1323616023845704258113538348000047149470450086307731200728039607710316625916, 9);

    let atts = TreeEnsembleAttributes {
        nodes_falsenodeids,
        nodes_featureids,
        nodes_missing_value_tracks_true,
        nodes_modes,
        nodes_nodeids,
        nodes_treeids,
        nodes_truenodeids,
        nodes_values
    };

    // `ensemble`, `classifier` and `X` are only moved/returned, never mutated, so no `mut`.
    let ensemble: TreeEnsemble<FP16x16> = TreeEnsemble { atts, tree_ids, root_index, node_index };

    // No per-class base score offsets.
    let base_values: Option<Span<FP16x16>> = Option::None;

    let classifier: TreeEnsembleClassifier<FP16x16> = TreeEnsembleClassifier {
        ensemble,
        class_ids,
        class_nodeids,
        class_treeids,
        class_weights,
        classlabels,
        base_values,
        post_transform
    };

    // 3 points x 3 features, row-major: [-1.0, -0.8, -0.6], [-0.4, -0.2, 0.0], [0.2, 0.4, 0.6].
    let X: Tensor<FP16x16> = TensorTrait::new(
        array![3, 3].span(),
        array![
            FP16x16 { mag: 65536, sign: true },
            FP16x16 { mag: 52429, sign: true },
            FP16x16 { mag: 39322, sign: true },
            FP16x16 { mag: 26214, sign: true },
            FP16x16 { mag: 13107, sign: true },
            FP16x16 { mag: 0, sign: false },
            FP16x16 { mag: 13107, sign: false },
            FP16x16 { mag: 26214, sign: false },
            FP16x16 { mag: 39322, sign: false },
        ]
            .span()
    );

    (classifier, X)
}

/// Runs the example fixture with a softmax post-transform.
///
/// Returns the predicted class label for each input row together with the matrix of
/// per-class scores (one row per input point).
fn test_tree_ensemble_classifier_multi_pt_softmax() -> (Span<usize>, MutMatrix::<FP16x16>) {
    let (mut model, inputs) = tree_ensemble_classifier_helper(POST_TRANSFORM::SOFTMAX);

    // `predict` already yields the `(labels, scores)` pair, so hand it straight back.
    TreeEnsembleClassifierTrait::predict(model, inputs)
}

>>>
([0, 0, 1],
 [
   [0.545123, 0.217967, 0.23691],
   [0.416047, 0.284965, 0.298988],
   [0.322535, 0.366664, 0.310801],
 ])
Previous: Tree Ensemble Classifier · Next: Tree Ensemble Regressor

Last updated 1 year ago

Was this helpful?

🧱