当前位置: 首页 > news >正文

手机网站开发服务网页关键词排名优化

手机网站开发服务,网页关键词排名优化,静态网站建设流程怎么写,安庆网站建设aqwzjs文章目录 1.保存、加载模型2.torch.nn.Module.state_dict()2.1基本使用2.2保存和加载状态字典 3.创建Checkpoint3.1基本使用3.2完整案例 1.保存、加载模型 torch.save()用于保存一个序列化对象到磁盘上,该序列化对象可以是任何类型的对象,包括模型、张量…

文章目录

  • 1.保存、加载模型
  • 2.torch.nn.Module.state_dict()
    • 2.1基本使用
    • 2.2保存和加载状态字典
  • 3.创建Checkpoint
    • 3.1基本使用
    • 3.2完整案例


1.保存、加载模型

  torch.save()用于保存一个序列化对象到磁盘上,该序列化对象可以是任何类型的对象,包括模型、张量和字典等(内部使用pickle模块实现对象的序列化)。数据会被保存为.pt.pth格式,可通过torch.load()从磁盘加载被保存的序列化对象,加载时会重新构造出原来的对象。
  torch.save()有两种保存模型的方式:

  • 1.保存整个模型(继承了torch.nn.Module的类),不推荐使用。
    • torch.load():利用pickle将保存的序列化对象反序列化,得到原始数据。可用于加载完整模型或状态字典。
#保存整个模型
torch.save(model, PATH)
#加载模型
model = torch.load(PATH)
  • 2.仅保存模型的参数(状态字典state_dict),推荐使用。
    • torch.nn.Module.load_state_dict():通过反序列化得到模型的state_dict()(状态字典)来加载模型,传入的参数是状态字典,而非.pt.pth文件。
#只保存模型参数
torch.save(model.state_dict(), PATH)
#加载模型
model=Model()
model.load_state_dict(torch.load(PATH))

  在实际使用中推荐第二种方式,第一种方式往往容易产生各种错误:

  • 设备错误。若在cuda:0上训练好一个模型并保存,则读取出来的模型也是默认在cuda:0上,如果训练过程的其他数据被放到了cuda:1上,那么就会发生错误:
RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1503966894950/work/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215

此时需要将其他其他数据都保存在cuda:0上,或加载模型时指定使用cuda:1

device = torch.device("cuda:1")
model = torch.load(PATH, map_location=device)
  • 版本错误:比如使用pytorch1.0训练并保存CNN模型,再用pytorch1.1读取模型,则会出现错误:
AttributeError: 'Conv2d' object has no attribute 'padding_mode'

此时只能通过获取该模型的参数来加载新的模型:

#加载模型参数
model_state = torch.load(model_path).state_dict()
#初始化新模型并加载参数
model = Model()
model.load_state_dict(model_state)

2.torch.nn.Module.state_dict()

2.1基本使用

  torch.nn.Module.state_dict()用于返回模型的状态字典,其中保存了模型的可学习参数。其中,只有可学习参数的层(卷积层、全连接层等)和注册缓冲区(batchnorm’s running_mean)才会作为模型参数保存(优化器也有状态字典,也可进行保存)。
【例子】

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 初始化模型
model = TheModelClass()# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])

  查看模型与优化器的状态字典:
在这里插入图片描述

2.2保存和加载状态字典

  通过torch.save()来保存模型的状态字典(state_dict),即只保存学习到的模型参数,并通过torch.nn.Module.load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为.pt.pth

#保存模型状态字典
PATH = './test_state_dict.pth'
torch.save(model.state_dict(), PATH)
#根据状态字典加载模型
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()
#打印新模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())

  注意,模型推理之前,需要调用model.eval()函数将dropoutbatch normalization层设置为评估模式,否则会导致模型推理结果不一致。
在这里插入图片描述

3.创建Checkpoint

3.1基本使用

  模型检查点(checkpoint)是指模型训练过程中保存的模型状态,包括模型参数(权重与偏置)、优化器状态等其他相关的训练信息。通过保存检查点,可以实现在训练过程中定期保存模型的当前状态,以便在需要时恢复训练或用于模型评估和推理。模型检查点常见的保存信息如下:

  • 1.模型权重:模型的状态字典。
  • 2.优化器状态:优化器的状态字典。
  • 3.训练状态:当前的训练轮数(epoch)、批次(batch)等。
  • 4.其他数据:如学习率调度器的状态、自定义指标等。

例如:
【保存检查点】

#将模型参数和优化器状态的状态字典保存到检查点中
checkpoint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss.item(),'epoch':epoch
}#保存检查点
torch.save(checkpoint, 'checkpoint.pth')

【加载检查点】

# 加载检查点
checkpoint = torch.load('checkpoint.pth')# 恢复模型和优化器状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])# 恢复训练状态
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果是恢复训练,可以从保存的epoch继续
for epoch in range(epoch, num_epochs):# 继续训练

3.2完整案例

import torch
import torch.nn as nn
import torch.optim as optim# 假设有一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()# 训练循环
num_epochs = 100
for epoch in range(num_epochs):# 假设有输入x和目标yx = torch.randn(64, 10)y = torch.randn(64, 1)optimizer.zero_grad()output = model(x)loss = loss_fn(output, y)loss.backward()optimizer.step()# 每10个epoch保存一次检查点if epoch % 10 == 0:checkpoint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'epoch': epoch,'loss': loss.item()}torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')# 加载检查点并继续训练
checkpoint = torch.load('checkpoint_epoch_10.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']# 从第11个epoch开始继续训练
for epoch in range(start_epoch + 1, num_epochs):# 继续训练pass
http://www.yayakq.cn/news/306400/

相关文章:

  • 做推送的网站有哪些wordpress做门户
  • 做网站时需要注意什么自己名字怎么设计logo
  • 高端网站建设 选择磐石网络爱你社区
  • 请问聊城网站建设网页制作的公司为什么瓯北没有
  • 金山专业网站建设保定公司网站建设
  • 网站网站建设设计公司网站域名正在维护中
  • 徐州专业做网站顺企网宁波网站建设
  • 温州做网站整站优化知识营销案例有哪些
  • 深圳网站建设服务找哪家昆明室内设计公司排名
  • 建设网站哪个便宜菜鸟如何做网站
  • 本地做织梦网站企业软件项目管理系统
  • 免费的网站域名和空间wordpress二维码活码
  • 12389举报网站建设项目域名申请好了 怎么做网站
  • 猪八戒里面做网站骗子很多家具网站建设策划方案
  • 分类门户网站系统163网易企业邮箱格式
  • 做船公司网站asp网站开发工具神器
  • 英文专业的网站设计学校网站制作
  • wordpress 设置角色seo网址查询
  • 衡阳商城网站制作中国大连网站
  • 网站改版总结酒业网站模板下载
  • 做网站什么分类流量多wordpress原生html5播放器
  • 内外网网站栏目建设方案长春网站建设流程
  • 查看网站有多少空间好发信息网网站建设
  • 做网站会遇到哪些问题宁波seo怎么推广
  • 山西网站建设开发团队带登录网站模板
  • 桂林生活网官方网站什么时候网站建设
  • 做环保是跑还是网站卖网站建设需要哪些功能
  • 太原招聘网站开发网站界面设计稿
  • 做影视网站赚钱wordpress扁平主题
  • wordpress网站建小程序郑州建设局官网