当前位置: 首页 > news >正文

wordpress做游戏网站梧州网站设计公司

wordpress做游戏网站,梧州网站设计公司,seo实战培训费用,山西网络公司公司关于Checkpoints的内容在教程2里已经有了详细的说明,在本节,需要用它来利用模型进行预测 加载checkpoint并预测 使用模型进行预测的最简单方法是使用LightningModule中的load_from_checkpoint加载权重。 model LitModel.load_from_checkpoint("b…

关于Checkpoints的内容在教程2里已经有了详细的说明,在本节,需要用它来利用模型进行预测

加载checkpoint并预测

使用模型进行预测的最简单方法是使用LightningModule中的load_from_checkpoint加载权重。

model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)with torch.no_grad():y_hat = model(x)

predict_step方法

加载检查点并进行预测仍然会在预测阶段的epoch留下许多boilerplate,LightningModule中的预测步骤删除了这个boilerplate 。

class MyModel(LightningModule):def predict_step(self, batch, batch_idx, dataloader_idx=0):return self(batch)

并将任何dataloader传递给Lightning Trainer

data_loader = DataLoader(...)
model = MyModel()
trainer = Trainer()
predictions = trainer.predict(model, data_loader)

预测逻辑

当需要向数据添加复杂的预处理或后处理时,使用predict_step方法。例如,这里我们使用Monte Carlo Dropout 进行预测

class LitMCdropoutModel(pl.LightningModule):def __init__(self, model, mc_iteration):super().__init__()self.model = modelself.dropout = nn.Dropout()self.mc_iteration = mc_iterationdef predict_step(self, batch, batch_idx):# enable Monte Carlo Dropoutself.dropout.train()# take average of `self.mc_iteration` iterationspred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]pred = torch.vstack(pred).mean(dim=0)return pred

启用分布式推理

通过使用Lightning中的predict_step,可以使用BasePredictionWriter进行分布式推理。

import torch
from lightning.pytorch.callbacks import BasePredictionWriterclass CustomWriter(BasePredictionWriter):def __init__(self, output_dir, write_interval):super().__init__(write_interval)self.output_dir = output_dirdef write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):# 在'output_dir'中创建N (num进程)个文件,每个文件都包含对其各自rank的预测torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))# 可以保存'batch_indices',以便从预测数据中获取有关数据索引的信息torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))# 可以设置writer_interval="batch"
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)

也可以加载保存的checkpoint,把它当作一个普通的torch.nn.Module来使用。可以提取所有的torch.nn.Module,并在训练后使用LightningModule保存的checkpoint加载权重。建议从LightningModule的init和forward方法中复制明确的实现。

class Encoder(nn.Module):...class Decoder(nn.Module):...class AutoEncoderProd(nn.Module):def __init__(self):super().__init__()self.encoder = Encoder()self.decoder = Decoder()def forward(self, x):return self.encoder(x)class AutoEncoderSystem(LightningModule):def __init__(self):super().__init__()self.auto_encoder = AutoEncoderProd()def forward(self, x):return self.auto_encoder.encoder(x)def training_step(self, batch, batch_idx):x, y = batchy_hat = self.auto_encoder.encoder(x)y_hat = self.auto_encoder.decoder(y_hat)loss = ...return loss# 训练
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")# 创建PyTorch模型并加载checkpoint权重
model = AutoEncoderProd()
checkpoint = torch.load("best_model.ckpt")
hyper_parameters = checkpoint["hyper_parameters"]# 恢复超参数
model = AutoEncoderProd(**hyper_parameters)model_weights = checkpoint["state_dict"]# 通过 dropping `auto_encoder.` 更新key值
for key in list(model_weights):model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)model.load_state_dict(model_weights)
model.eval()
x = torch.randn(1, 64)with torch.no_grad():y_hat = model(x)
http://www.yayakq.cn/news/733903/

相关文章:

  • 织梦系统网站wordpress短信验证码
  • 网络科技公司 网站建设哪个网站可以做笔译兼职
  • 广东东莞邮编新闻类的网站如何做优化
  • 网站开发需要投入多少时间山东网站seo公司
  • 千山科技做网站好不好定制家居软件app哪个好
  • 财政局网站建设方案有没有免费的推广平台
  • 做网站网站犯法吗优质的邵阳网站建设
  • 怎样设计自己的网站通常做网站要多久
  • 网站seo系统做试玩网站
  • 做美食网站的图片大全wordpress增加导航栏
  • 婚庆公司网站建设docin什么 wordpress
  • 用ps做商城网站好做吗在哪个网站可以学做甜点
  • 官方网站举例给特宝网站商家网址怎样做
  • 网站建设人员配备温州seo教程
  • 用php开发wap网站重庆最大本地论坛
  • 国外网站模板下载网站建设项目合同
  • 网站域名地址公司为什么要建立网站
  • 滁州seo网站推广网站换服务器百度不收录
  • 杭州cms建站模板网站模型怎么做
  • 北京首钢建设有限公司网站模板网站的弊端在哪
  • 网站建设与管理实践收获网站建设上的新闻
  • 小型深圳网站定制开发投票网页怎么制作
  • 服饰类电商网站建设策划火车头采集器wordpress3发布模块
  • 湖南网站优化公司云一网站设计
  • 鞍山网站建设制作技术太差 不想干程序员
  • 毕设做网站企业官方网站需要备案吗
  • 太原建站模板centos wordpress环境
  • 网站制作过程合理的步骤是胶州网站优化价格
  • 免费h5网站制作平台东莞清洁服务网站建设
  • 做外贸现在一般都通过哪些网站外贸建站推广哪家好