如何使用C++调用Pytorch模型进行推理测试:使用libtorch库

如何使用C++调用Pytorch模型进行推理测试:使用libtorch库

目录

      • 如何使用C++调用Pytorch模型进行推理测试:使用libtorch库
        • 一、环境准备
          • 1,linux:以ubuntu 22.04系统为例
            • 1. 准备CUDA和CUDNN
            • 2. 准备C++环境
            • 3, 下载libtorch文件
            • 4, 编写测试libtorch是否安装成功
          • 2, windows: 以win10系统为例
            • 1, 准备CUDA和CUDNN
            • 2,准备C++编译环境
            • 3,下载安装libtorch
            • 4. 注意事项
          • 二、C++代码封装Pytorch模型测试:以resnet-18分类为例
          • 1, 安装opencv用于读取图像
          • 2,用python导出训练好的pytorch模型
          • 3,编写C++代码测试

一、环境准备
1,linux:以ubuntu 22.04系统为例
1. 准备CUDA和CUDNN

有两种方式配置cuda和cudnn,一种是在系统环境安装,可以参考:深度学习环境配置——ubuntu安装CUDA与CUDNN

还有一种是在conda虚拟环境使用cudatoolkit-dev包,具体可以参考:Installing-and-Test-PyTorch-C-API-on-Ubuntu-with-GPU-enabled

我选择的方式是在系统环境安装cuda12.1和cudnn8.9.2。

可使用如下命令查看是否安装成功:

NVCC -V
cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2

image-20240625103837610

2. 准备C++环境

安装gcc, cmake和GLIBC,用apt install即可

可使用如下命令是否查看是否安装成功:

gcc --version
cmake --version
ldd --version

image-20240625103749911

3, 下载libtorch文件

去pytoch官网https://pytorch.org/下载即可:

image-20240625103946244

可使用如下命令下载并解压:

wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.3.1%2Bcu121.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.3.1+cu121.zip

将libtorch路径配置到path变量:

vim ~/.bashrc

最后一行加入:

export LD_LIBRARY_PATH=/path/to/libtorch/lib:$LD_LIBRARY_PATH

注意将/path/to/libtorch替换为实际的path,我这里是/mnt/data1/zq/libtorch

查看是否成功:

source ~/.bashrc
echo $LD_LIBRARY_PATH

image-20240625110447696

4, 编写测试libtorch是否安装成功

创建main.cpp文件,内容如下:

#include <torch/torch.h>
#include <iostream>

int main() {
    if (torch::cuda::is_available()) {
        std::cout << "CUDA is available! Running on GPU." << std::endl;
        // 创建一个随机张量并将其移到GPU上
        torch::Tensor tensor_gpu = torch::rand({2, 3}).cuda();
        std::cout << "Tensor on GPU:\n" << tensor_gpu << std::endl;
    } else {
        std::cout << "CUDA not available! Running on CPU." << std::endl;
        // 创建一个随机张量并保持在CPU上
        torch::Tensor tensor_cpu = torch::rand({2, 3});
        std::cout << "Tensor on CPU:\n" << tensor_cpu << std::endl;
    }
    return 0;
}

编译和运行

创建CMakeLists.txt文件,内容如下:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(test_project)

# Setting the C++ standard to C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# If additional compiler flags are needed
add_compile_options(-Wall -Wextra -pedantic)

# Setting the location of LibTorch
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)

# Specify the name of the executable and the corresponding source file
add_executable(test_project main.cpp)

# Linking LibTorch libraries
target_link_libraries(test_project "${TORCH_LIBRARIES}")

# Set the output directory for the executable
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)

/path/to/libtorch替换为实际的path

编译并测试:

mkdir build
cd build
cmake ..
make 

编译完成之后,应该会出现一个bin目录,其中有一个test_project文件,直接运行即可看到输出。

image-20240625111448917

出现CUDAFloatType说明,libtorch的GPU版本安装成功。

2, windows: 以win10系统为例
1, 准备CUDA和CUDNN

可参考:Windows10下CUDA与cuDNN的安装

2,准备C++编译环境

这一步需要配置cmake, mingw。可参考:Windows 配置 C/C++ 开发环境

建议直接安装Visual Studio这个IDE,可参考:Windows libtorch C++部署GPU版

3,下载安装libtorch

参考这个视频:

win10系统上LibTorch的安装和使用(cuda10.1版本)

一个很水的LibTorch教程(1)

4. 注意事项

windows环境我没有做测试,不保证一定可以成功。linux环境是亲自测试的,保证可以复现

二、C++代码封装Pytorch模型测试:以resnet-18分类为例
1, 安装opencv用于读取图像

需要使用opencv来读取图像数据,可通过如下命令安装:

sudo apt install libopencv-dev
dpkg -l | grep libopencv # 查看是否安装成功
2,用python导出训练好的pytorch模型

在将PyTorch模型应用于C++环境之前,需要将其转换为TorchScript。这可以通过两种方式实现:tracingscripting。可以通过如下代码导出训练好的ResNet-18模型:

import torch
import torchvision

# 加载预训练的模型
model = torchvision.models.resnet18(pretrained=True)

# 将模型设置为评估模式
model.eval()

# 创建一个示例输入
example_input = torch.rand(1, 3, 224, 224)  # 模型输入的大小

# 使用tracing导出模型
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("resnet18.pt")
3,编写C++代码测试

创建main.cpp文件,内容如下:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <filesystem>

// Function to transform image to tensor
torch::Tensor transform_image(const cv::Mat& image) {
    cv::Mat img_transformed;
    cv::cvtColor(image, img_transformed, cv::COLOR_BGR2RGB);
    cv::resize(img_transformed, img_transformed, cv::Size(224, 224));
    img_transformed.convertTo(img_transformed, CV_32FC3, 1.0/255);
    auto img_tensor = torch::from_blob(img_transformed.data, {img_transformed.rows, img_transformed.cols, 3}, torch::kFloat);
    img_tensor = img_tensor.permute({2, 0, 1});
    img_tensor = torch::data::transforms::Normalize<torch::Tensor>({0.485, 0.456, 0.406}, {0.229, 0.224, 0.225})(img_tensor);
    img_tensor = img_tensor.unsqueeze(0);
    return img_tensor;
}

// Load the model and classify an image
void classify_image(const std::string& model_path, const std::string& image_path) {
    // Load the model
    torch::jit::script::Module model = torch::jit::load(model_path);
    model.eval(); // Switch to evaluation mode

    // Load and transform the image
    cv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);
    if (image.empty()) {
        std::cerr << "Could not read the image: " << image_path << std::endl;
        return;
    }
    torch::Tensor tensor_image = transform_image(image);

    // Perform inference
    torch::Tensor output = model.forward({tensor_image}).toTensor();
    int64_t pred = output.argmax(1).item<int64_t>();

    std::cout << "The image is classified as class index: " << pred << std::endl;
}

int main(int argc, char* argv[]) {
    std::string model_path = "resnet18.pt"; // Default model path
    std::string image_path = "default_image.jpg"; // Default image path
	
    // 从命令行接受两个参数, 分别作为model_path和image_path
    if (argc >= 3) {
        model_path = argv[1];
        image_path = argv[2];
    } else {
        std::cout << "Using default model and image paths." << std::endl;
    }

    classify_image(model_path, image_path);
    return 0;
}

创建CMakeLists.txt,内容如下:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(ImageClassification)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# 设置LibTorch的位置, /path/to/libtorch替换为实际路径
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)

find_package(OpenCV REQUIRED)

add_executable(ImageClassification main.cpp)
target_link_libraries(ImageClassification "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")

编译并运行:

mkdir build && cd build
cmake ..
make

在build目录下会出现ImageClassification这个可执行文件,直接运行传入model_path和image_path即可。

image-20240625114911739

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/773899.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

uniapp中实现跳转链接到游览器(安卓-h5)

uniapp中实现跳转链接到游览器&#xff08;安卓-h5&#xff09; 项目中需要做到跳转到外部链接&#xff0c;网上找了很多都不是很符合自己的要求&#xff0c;需要编译成app后是跳转到游览器打开链接&#xff0c;编译成web是在新窗口打开链接。实现的代码如下&#xff1a; 效果&…

“谋士三国”诸葛亮的锦囊妙计 - 策略模式

“当代码如三国&#xff0c;智慧如孔明&#xff0c;何愁天下设计不归一统&#xff1f;” 乱世之中&#xff0c;英雄辈出。三国的战场上&#xff0c;不仅刀光剑影&#xff0c;更有智慧的较量。诸葛亮的锦囊妙计&#xff0c;不正是今日软件设计中策略模式的完美写照吗&#xff1…

Python酷库之旅-第三方库Pandas(003)

目录 一、用法精讲 4、pandas.read_csv函数 4-1、语法 4-2、参数 4-3、功能 4-4、返回值 4-5、说明 4-6、用法 4-6-1、创建csv文件 4-6-2、代码示例 4-6-3、结果输出 二、推荐阅读 1、Python筑基之旅 2、Python函数之旅 3、Python算法之旅 4、Python魔法之旅 …

五.核心动画 - 图层的变换(平移,缩放,旋转,3D变化)

引言 在上一篇博客中&#xff0c;我们研究了一些视觉效果&#xff0c;在本篇博客中我们将要来讨论一下图层的旋转&#xff0c;平移&#xff0c;缩放&#xff0c;以及可以将扁平物体转换成三维空间对象的CATransform3D。 图层变换 图层的仿射变换 在视图中有一个transform属…

海外发稿: 秘鲁-区块链新闻媒体通稿宣发

秘鲁媒体单发 随着全球化的不断深入&#xff0c;海外发稿已经成为众多企业宣传推广的重要方式之一。而在海外发稿的选择中&#xff0c;秘鲁媒体的地位尤为重要。秘鲁作为南美洲的重要国家之一&#xff0c;拥有众多知名媒体平台&#xff0c;包括diariodelcusco、serperuano、el…

全网视频下载之IDM下载安装,软破解

全网视频下载之IDM下载安装&#xff0c;软破解 介绍![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/c94f612f7a8845c8a649f74f6b18fd70.png)下载安装配置浏览器Google浏览器Ddge浏览器 界面如何下载不破解如何重复使用总结 介绍 今天给大家分享一个更加简便的全网视…

nftables(1)基本原理

简介 nftables 是 Linux 内核中用于数据包分类的现代框架&#xff0c;用来替代旧的 iptables&#xff08;包括 ip6tables, arptables, ebtables 等&#xff0c;统称为 xtables&#xff09;架构。nftables 提供了更强大、更灵活以及更易于管理的规则集配置方式&#xff0c;使得…

【matlab】智能优化算法——求解目标函数

智能优化算法在求解目标函数方面发挥着重要作用&#xff0c;它通过迭代、筛选等方法来寻找目标函数的最优值&#xff08;极值&#xff09;。以下是关于智能优化算法求解目标函数的详细介绍&#xff1a; 一、智能优化算法概述 智能优化算法是一种搜索算法&#xff0c;旨在通过…

0/1背包问题总结

文章目录 &#x1f347;什么是0/1背包问题&#xff1f;&#x1f348;例题&#x1f349;1.分割等和子集&#x1f349;2.目标和&#x1f349;3.最后一块石头的重量Ⅱ &#x1f34a;总结 博客主页&#xff1a;lyyyyrics &#x1f347;什么是0/1背包问题&#xff1f; 0/1背包问题是…

《简历宝典》02 - 如果你是HR,你会优先打开哪份简历?

现在的求职环境不必多说&#xff0c;其实我们大家都还是很清楚的。所以&#xff0c;在这个环境下&#xff0c;写一份优秀的简历&#xff0c;目的与作用也不必多说。那么&#xff0c;这一小节呢&#xff0c;我们先从简历这份文档的文档名开始说起。 目录 1 你觉得HR们刷简历的时…

【SVN的使用-源代码管理工具-命令行的使用 Objective-C语言】

一、接下来,我们来说一个终端的命令行的使用, 1.我们说,你的电脑里边呢,有终端, 在Mac里边,你想新建一个txt,应该怎么写,对,打开文本编辑, 打开这个东西,写点儿东西,然后保存一下,保存的时候,你还要去选择格式, 现在,如果我们用命令行,可以更方便一些, 2.首…

企业用私户发工资算不算偷税?

一般来说&#xff0c;给员工发工资都是用企业的对公账户去发&#xff0c;但是&#xff0c;有的企业会用私户去发工资&#xff0c;早前就有蜜雪冰城股东用私户给员工发奖金被税局稽查&#xff0c;最终补缴个税近800万的新闻&#xff0c;可见&#xff0c;私户发工资是具有很大风险…

上海时尚新品发布会,可以邀请哪些媒体

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 在上海举办时尚新品发布会时&#xff0c;可以邀请的媒体类型多样&#xff0c;以下是一些建议的媒体类型及其特点&#xff1a; 一、平面媒体 报纸&#xff1a; 《文汇报》&#xff1a;上…

底层软件 | 十分详细,为了学习设备树,我写了5w字笔记!

0、设备树是什么&#xff1f;1、DTS 1.1 dts简介1.2 dts例子 2、DTC&#xff08;Device Tree Compiler&#xff09;3、DTB&#xff08;Device Tree Blob&#xff09;4、绑定&#xff08;Binding&#xff09;5、Bootloader compatible属性 7、 #address-cells和#size-cells属性8…

Qt源码解析之QObject

省去大部分virtual和public方法后&#xff0c;Qobject主要剩下以下成员&#xff1a; //qobject.h class Q_CORE_EXPORT Qobject{Q_OBJECTQ_PROPERTY(QString objectName READ objectName WRITE setObjectName NOTIFY objectNameChanged)Q_DECLARE_PRIVATE(QObject) public:Q_I…

印章谁在管、谁用了、用在哪?契约锁让您打开手机一看便知

“印章都交给谁在管”、“哪些人能用”、“都有哪些业务在用”…这些既是管理者最关心的印章问题也是影响印章安全的关键要素。但是公司旗下分子公司那么多&#xff0c;各类公章、法人章、财务章、合同章一大堆&#xff0c;想“问”明白很难。 契约锁电子签及印控平台推出“印章…

OpenLayers使用2

接着上一篇https://blog.csdn.net/weixin_51416826/article/details/140161160?spm1001.2014.3001.5502 本篇主要内容是基于高德API逆向地址解析获取城市中心点&#xff0c;并且设置了输入框&#xff0c;可以输入城市执行飞行&#xff0c;同时基于高德API获取城市天气信息&am…

【漏洞复现】万户协同办公平台——反序列化

声明&#xff1a;本文档或演示材料仅供教育和教学目的使用&#xff0c;任何个人或组织使用本文档中的信息进行非法活动&#xff0c;均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 万户协同办公平台ezEIP是一个综合信息基础应用平台&#xff0c;…

Leaflet【六】绘制交互图形、测量、经纬度展示

本文主要探讨了如何利用leaflet-draw插件在地图上绘制图形&#xff0c;以及通过leaflet-measure测量距离和面积&#xff0c;并将经纬度绘制到地图上。首先&#xff0c;我们使用leaflet-draw插件&#xff0c;该插件提供了一种简单而直观的方式来绘制各种形状&#xff08;如点、线…

配置基于不同IP地址的虚拟主机

定义配置文件vhost.conf <directory /www> allowoverride none require all granted </directory> <virtualhost 192.168.209.136:80> documentroot /www servername 192.168.209.136 </virtualhost><virtualhost 192.168.209.138:80> document…