阿里云-云小站(无限量代金券发放中)
【腾讯云】云服务器、云数据库、COS、CDN、短信等热卖云产品特惠抢购

解说pytorch中的model=model.to(device)

261次阅读
没有评论

共计 2868 个字符,预计需要花费 8 分钟才能阅读完成。

导读 这篇文章主要介绍了 pytorch 中的 model=model.to(device) 使用说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

这代表将模型加载到指定设备上。

其中,device=torch.device(“cpu”) 代表的使用 cpu,而 device=torch.device(“cuda”) 则代表的使用 GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用 model=model.to(device),将模型加载到相应的设备中。

将由 GPU 保存的模型加载到 CPU 上。

将 torch.load() 函数中的 map_location 参数设置为 torch.device(‘cpu’)

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由 GPU 保存的模型加载到 GPU 上。确保对输入的 tensors 调用 input = input.to(device) 方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由 CPU 保存的模型加载到 GPU 上。

确保对输入的 tensors 调用 input = input.to(device) 方法。map_location 是将模型加载到 GPU 上,model.to(torch.device(‘cuda’)) 是将模型参数加载为 CUDA 的 tensor。

最后保证使用.to(torch.device(‘cuda’)) 方法将需要使用的参数放入 CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch 中 model.to(device) 和 map_location=device 的区别

一、简介

在已训练并保存在 CPU 上的 GPU 上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是 model.to(device) 和 map_location=devicel 两个参数,简介一下两者的不同。

将 map_location 函数中的参数设置 torch.load() 为 cuda:device_id。这会将模型加载到给定的 GPU 设备。

调用 model.to(torch.device(‘cuda’)) 将模型的参数张量转换为 CUDA 张量,无论在 cpu 上训练还是 gpu 上训练,保存的模型参数都是参数张量不是 cuda 张量,因此,cpu 设备上不需要使用 torch.to(torch.device(“cpu”))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在 GPU 上,在 CPU 上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用 GPU 训练的 CPU 上加载模型时,请传递 torch.device(‘cpu’) 给 map_location 函数中的 torch.load() 参数,使用 map_location 参数将张量下面的存储器动态地重新映射到 CPU 设备。

2、保存在 GPU 上,在 GPU 上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在 GPU 上训练并保存在 GPU 上的模型时,只需将初始化 model 模型转换为 CUDA 优化模型即可 model.to(torch.device(‘cuda’))。

此外,请务必.to(torch.device(‘cuda’)) 在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用 my_tensor.to(device) 返回 my_tensorGPU 上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device(‘cuda’))

3、保存在 CPU,在 GPU 上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在 CPU 上的 GPU 上加载模型时,请将 map_location 函数中的参数设置 torch.load() 为 cuda:device_id。

这会将模型加载到给定的 GPU 设备。

接下来,请务必调用 model.to(torch.device(‘cuda’)) 将模型的参数张量转换为 CUDA 张量。

最后,确保.to(torch.device(‘cuda’)) 在所有模型输入上使用该 函数来为 CUDA 优化模型准备数据。

请注意,调用 my_tensor.to(device) 返回 my_tensorGPU 上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device(‘cuda’))

阿里云 2 核 2G 服务器 3M 带宽 61 元 1 年,有高配

腾讯云新客低至 82 元 / 年,老客户 99 元 / 年

代金券:在阿里云专用满减优惠券

正文完
星哥玩云-微信公众号
post-qrcode
 0
星锅
版权声明:本站原创文章,由 星锅 于2024-07-25发表,共计2868字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
【腾讯云】推广者专属福利,新客户无门槛领取总价值高达2860元代金券,每种代金券限量500张,先到先得。
阿里云-最新活动爆款每日限量供应
评论(没有评论)
验证码
【腾讯云】云服务器、云数据库、COS、CDN、短信等云产品特惠热卖中

星哥玩云

星哥玩云
星哥玩云
分享互联网知识
用户数
4
文章数
19348
评论数
4
阅读量
7798677
文章搜索
热门文章
开发者必备神器:阿里云 Qoder CLI 全面解析与上手指南

开发者必备神器:阿里云 Qoder CLI 全面解析与上手指南

开发者必备神器:阿里云 Qoder CLI 全面解析与上手指南 大家好,我是星哥。之前介绍了腾讯云的 Code...
星哥带你玩飞牛NAS-6:抖音视频同步工具,视频下载自动下载保存

星哥带你玩飞牛NAS-6:抖音视频同步工具,视频下载自动下载保存

星哥带你玩飞牛 NAS-6:抖音视频同步工具,视频下载自动下载保存 前言 各位玩 NAS 的朋友好,我是星哥!...
云服务器部署服务器面板1Panel:小白轻松构建Web服务与面板加固指南

云服务器部署服务器面板1Panel:小白轻松构建Web服务与面板加固指南

云服务器部署服务器面板 1Panel:小白轻松构建 Web 服务与面板加固指南 哈喽,我是星哥,经常有人问我不...
我把用了20年的360安全卫士卸载了

我把用了20年的360安全卫士卸载了

我把用了 20 年的 360 安全卫士卸载了 是的,正如标题你看到的。 原因 偷摸安装自家的软件 莫名其妙安装...
星哥带你玩飞牛NAS-3:安装飞牛NAS后的很有必要的操作

星哥带你玩飞牛NAS-3:安装飞牛NAS后的很有必要的操作

星哥带你玩飞牛 NAS-3:安装飞牛 NAS 后的很有必要的操作 前言 如果你已经有了飞牛 NAS 系统,之前...
阿里云CDN
阿里云CDN-提高用户访问的响应速度和成功率
随机文章
飞牛NAS中安装Navidrome音乐文件中文标签乱码问题解决、安装FntermX终端

飞牛NAS中安装Navidrome音乐文件中文标签乱码问题解决、安装FntermX终端

飞牛 NAS 中安装 Navidrome 音乐文件中文标签乱码问题解决、安装 FntermX 终端 问题背景 ...
优雅、强大、轻量开源的多服务器监控神器

优雅、强大、轻量开源的多服务器监控神器

优雅、强大、轻量开源的多服务器监控神器 在多台服务器同时运行的环境中,性能监控、状态告警、资源可视化 是运维人...
星哥带你玩飞牛NAS-8:有了NAS你可以干什么?软件汇总篇

星哥带你玩飞牛NAS-8:有了NAS你可以干什么?软件汇总篇

星哥带你玩飞牛 NAS-8:有了 NAS 你可以干什么?软件汇总篇 前言 哈喽各位玩友!我是是星哥,不少朋友私...
国产开源公众号AI知识库 Agent:突破未认证号限制,一键搞定自动回复,重构运营效率

国产开源公众号AI知识库 Agent:突破未认证号限制,一键搞定自动回复,重构运营效率

国产开源公众号 AI 知识库 Agent:突破未认证号限制,一键搞定自动回复,重构运营效率 大家好,我是星哥,...
星哥带你玩飞牛NAS硬件03:五盘位+N5105+双网口的成品NAS值得入手吗

星哥带你玩飞牛NAS硬件03:五盘位+N5105+双网口的成品NAS值得入手吗

星哥带你玩飞牛 NAS 硬件 03:五盘位 +N5105+ 双网口的成品 NAS 值得入手吗 前言 大家好,我...

免费图片视频管理工具让灵感库告别混乱

一言一句话
-「
手气不错
12.2K Star 爆火!开源免费的 FileConverter:右键一键搞定音视频 / 图片 / 文档转换,告别多工具切换

12.2K Star 爆火!开源免费的 FileConverter:右键一键搞定音视频 / 图片 / 文档转换,告别多工具切换

12.2K Star 爆火!开源免费的 FileConverter:右键一键搞定音视频 / 图片 / 文档转换...
星哥带你玩飞牛NAS-14:解锁公网自由!Lucky功能工具安装使用保姆级教程

星哥带你玩飞牛NAS-14:解锁公网自由!Lucky功能工具安装使用保姆级教程

星哥带你玩飞牛 NAS-14:解锁公网自由!Lucky 功能工具安装使用保姆级教程 作为 NAS 玩家,咱们最...
仅2MB大小!开源硬件监控工具:Win11 无缝适配,CPU、GPU、网速全维度掌控

仅2MB大小!开源硬件监控工具:Win11 无缝适配,CPU、GPU、网速全维度掌控

还在忍受动辄数百兆的“全家桶”监控软件?后台偷占资源、界面杂乱冗余,想查个 CPU 温度都要层层点选? 今天给...
星哥带你玩飞牛NAS-5:飞牛NAS中的Docker功能介绍

星哥带你玩飞牛NAS-5:飞牛NAS中的Docker功能介绍

星哥带你玩飞牛 NAS-5:飞牛 NAS 中的 Docker 功能介绍 大家好,我是星哥,今天给大家带来如何在...
星哥带你玩飞牛NAS-7:手把手教你免费内网穿透-Cloudflare tunnel

星哥带你玩飞牛NAS-7:手把手教你免费内网穿透-Cloudflare tunnel

星哥带你玩飞牛 NAS-7:手把手教你免费内网穿透 -Cloudflare tunnel 前言 大家好,我是星...