网站的优点有哪些,网站设计网站类型,html首页,wordpress 申请表单文章目录1 概述2 Tensor的基本操作2.1 Tensor的初始化#xff08;1#xff09;通过数组创建#xff08;2#xff09;通过默认方法创建#xff08;3#xff09;通过其他的tensor创建#xff08;4#xff09;通过opencv::core::Mat创建2.2 Tensor的属性2.3 Tensor的运算1通过数组创建2通过默认方法创建3通过其他的tensor创建4通过opencv::core::Mat创建2.2 Tensor的属性2.3 Tensor的运算1改变device2获取值(indexing and slicing)3合并tensors4四则运算参考资料1 概述
在使用rust进行torch模型部署时不可避免地会用到tch-rs。但是tch-rs的文档太过简洁和没有一样网上的资料也少得可怜很多操作需要我们自己去试。这些内容虽然简单但是自己找起来很费时间。
这篇文章总结了如何使用tch-rs进行tensor的基本操作。讲述的内容参考了pytorch的tensor教程。
运行环境
[dependencies]
tch 0.7.0
opencv 0.632 Tensor的基本操作
用到的库
use std::iter;use opencv::prelude::*;
use opencv::core::{Mat, Scalar};
use opencv::core::{CV_8UC3};
use tch::IndexOp;
use tch::{Device, Tensor};2.1 Tensor的初始化
1通过数组创建
let t Tensor::of_slice::i32([1, 2, 3, 4, 5]);
t.print();
// vector也是一样的
let v vec![1,2,3];
let t Tensor::of_slice::i32(v);
t.print();
// 2d vector
let v vec![[1.5,2.0,3.9,4.4], [3.1,4.3,5.1,6.9]];
let v:Vecf32 v.iter().flat_map(|array| array.iter()).cloned().collect();
let data unsafe{std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * std::mem::size_of::f32())
};
let t Tensor::of_data_size(data, [2,4], tch::Kind::Float);
t.print();print的结果是 12345
[ CPUIntType{5} ]123
[ CPUIntType{3} ]1.5000 2.0000 3.9000 4.40003.1000 4.3000 5.1000 6.9000
[ CPUFloatType{2,4} ]2通过默认方法创建
let t Tensor::randn([2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t Tensor::ones([2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t Tensor::zeros([2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t Tensor::arange_start(0, 2 * 3, (tch::Kind::Float, Device::Cpu)).view([2, 3]);
t.print();print的结果是 1.0522 0.6981 0.92360.2324 -1.1048 -2.5820
[ CPUFloatType{2,3} ]1 1 11 1 1
[ CPUFloatType{2,3} ]0 0 00 0 0
[ CPUFloatType{2,3} ]0 1 23 4 5
[ CPUFloatType{2,3} ]3通过其他的tensor创建
let t Tensor::randn([2, 3], (tch::Kind::Float, Device::Cpu));
let t t.rand_like();
t.print();print的结果是 0.3376 0.1885 0.34150.5135 0.8321 0.4140
[ CPUFloatType{2,3} ]4通过opencv::core::Mat创建
这可以用在opencv读取图像后转为torch tensor。当然tch-rs本身也有各种读取图片的方式可见tch::vision::image。这里介绍两种方法一种通过tch::Tensor::f_of_blob一种通过tch::Tensor::of_data_size。
// 创建一个(row, col, channel)(2, 3, 3)(height, width, channel)的Mat
let mat Mat::new_rows_cols_with_default(2, 3, CV_8UC3, Scalar::from((3.0, 2.0, 1.0))
).unwrap();
// 获取mat的size这里的结果是[2, 3, 3]
let size: Vec_ mat.mat_size().iter().cloned().map(|dim| dim as i64).chain(iter::once(mat.channels() as i64)).collect();
// 获取每个dimension的stride这里的结果是[9, 3, 1]
let strides {let mut strides: Vec_ size.iter().rev().cloned().scan(1, |prev, dim| {let stride *prev;*prev * dim;Some(stride)}).collect();strides.reverse();strides
};
// 构建tensor
let t unsafe {let ptr mat.ptr(0).unwrap() as *const u8;tch::Tensor::f_of_blob(ptr, size, strides, tch::Kind::Uint8, tch::Device::Cpu).unwrap()
};
t.print();print的结果是
(1,.,.) 3 2 13 2 13 2 1(2,.,.) 3 2 13 2 13 2 1
[ CPUByteType{2,3,3} ]还有一种比较简洁的转换方法
let mut mat Mat::new_rows_cols_with_default(2, 3, CV_8UC3, Scalar::from((3.0, 2.0, 1.0))
).unwrap();
let h mat.size().unwrap().height;
let w mat.size().unwrap().width;
let data mat.data_bytes_mut().unwrap();
let t tch::Tensor::of_data_size(data, [h as i64, w as i64, 3], tch::Kind::Uint8);
t.print();print的结果也是
(1,.,.) 3 2 13 2 13 2 1(2,.,.) 3 2 13 2 13 2 1
[ CPUByteType{2,3,3} ]
test tensor_ops::init_ops ... ok2.2 Tensor的属性
用tch::Tensor的print()方法可打印出数据的所有属性但是想要获取到这些属性需要用其他的方法。
let t Tensor::randn([2, 3], (tch::Kind::Float, Device::Cpu));
println!(size of the tensor: {:?}, t.size());
println!(kind of the tensor: {:?}, t.kind());
println!(device on which the tensor is located: {:?}, t.device());打印的结果是
size of the tensor: [2, 3]
kind of the tensor: Float
device on which the tensor is located: Cpu2.3 Tensor的运算
1改变device
.to()和.to_device()这两个方法都可以。
let mut t Tensor::randn([2, 3], (tch::Kind::Float, Device::Cpu));
if tch::Cuda::is_available(){t t.to(Device::Cuda(0));println!(change device to {:?}, t.device());
}
t t.to_device(Device::Cpu);
println!(change device to {:?}, t.device());如果是有cuda且安装了cuda版本的tch-rs的话就会打印出
change device to Cuda(0)
change device to Cpu2获取值(indexing and slicing)
这个在tch-rs的例子中有很多详见tests/tensor_indexing.rs。这里列几种常用的。
通过.i()进行索引
let tensor Tensor::arange_start(0, 2 * 3, (tch::Kind::Float, Device::Cpu)).view([2, 3]);
println!(original tensor:);
tensor.print();
println!(tensor.i(0):);
tensor.i(0).print();
println!(tensor.i((1, 1)):);
tensor.i((1, 1)).print();
println!(tensor.i((.., 2)):);
tensor.i((.., 2)).print();
println!(tensor.i((.., -1)):);
tensor.i((.., -1)).print();
println!(tensor.i((.., [2, 0])):);
let index: [_] [2, 0];
tensor.i((.., index)).print();打印的结果是
original tensor:0 1 23 4 5
[ CPUFloatType{2,3} ]
tensor.i(0):012
[ CPUFloatType{3} ]
tensor.i((1, 1)):
4
[ CPUFloatType{} ]
tensor.i((.., 2)):25
[ CPUFloatType{2} ]
tensor.i((.., -1)):25
[ CPUFloatType{2} ]
tensor.i((.., [2, 0])):2 05 3
[ CPUFloatType{2,2} ]通过.index()进行索引
let tensor Tensor::arange(6, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
println!(original tensor:);
tensor.print();
let rows_select Tensor::of_slice([0i64, 1, 0]);
let column_select Tensor::of_slice([1i64, 2, 2]);
let selected tensor.index([Some(rows_select), Some(column_select)]);
println!(selecte by row and column:);
selected.print();打印的结果是
original tensor:0 1 23 4 5
[ CPULongType{2,3} ]
selecte by row and column:152
[ CPULongType{3} ]3合并tensors
Tensor::f_cat不会生成新的axis而Tensor::stack会生成新的axis。
let t1 Tensor::arange(6, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
let t2 Tensor::arange_start(6, 12, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
let tensor Tensor::f_cat([t1.copy(), t2.copy()], 1).unwrap();
println!(using Tensor::f_cat);
tensor.print();
let tensor Tensor::stack([t1.copy(), t2.copy()], 1);
println!(using Tensor::stack);
tensor.print();打印的结果是
using Tensor::f_cat0 1 2 6 7 83 4 5 9 10 11
[ CPULongType{2,6} ]
using Tensor::stack
(1,.,.) 0 1 26 7 8(2,.,.) 3 4 59 10 11
[ CPULongType{2,2,3} ]4四则运算
tch-rs对[, -, *, /]都进行了重载可以实现和标量的直接运算。涉及到dim的复杂运算可以用tensor来处理。下面以加法为例其他与f_add对应的分别是f_subf_mul和f_div。
let tensor Tensor::ones([2, 4, 3], (tch::Kind::Float, Device::Cpu));
tensor.print();
// add with scalar
let add_tensor tensor 0.5;
add_tensor.print();
// add with tensor
let add_tensor Tensor::of_slice::f32([1.0,2.0,3.0]).view((1,1,3));
let add_tensor tensor.f_add(add_tensor).unwrap();
add_tensor.print();打印的结果为
original tensor:
(1,.,.) 1 1 11 1 11 1 11 1 1(2,.,.) 1 1 11 1 11 1 11 1 1
[ CPUFloatType{2,4,3} ]
add with scalar:
(1,.,.) 1.5000 1.5000 1.50001.5000 1.5000 1.50001.5000 1.5000 1.50001.5000 1.5000 1.5000(2,.,.) 1.5000 1.5000 1.50001.5000 1.5000 1.50001.5000 1.5000 1.50001.5000 1.5000 1.5000
[ CPUFloatType{2,4,3} ]
add with tensor:
(1,.,.) 2 3 42 3 42 3 42 3 4(2,.,.) 2 3 42 3 42 3 42 3 4
[ CPUFloatType{2,4,3} ]参考资料
[1] https://github.com/LaurentMazare/tch-rs [2] https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html#