当前位置: 首页 > news >正文

徐州网站定制正能量网站地址污的

徐州网站定制,正能量网站地址污的,网站后台管理方便吗,信息流优化师工作总结1.数据集的路径,结构 dataset.py 目的: 输入:没有输入,路径是写死了的。 输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签&…

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.pathimport cv2
import torch
from torch.utils.data import Dataset
from torchvision import transformsclass DtataAndLabel(Dataset):def __init__(self,path='fruits',is_train=True):self.tran=transforms.Compose([transforms.ToTensor(),transforms.Resize(size=(88,88))])is_train='train' if True else 'test'self.data=[]path=os.path.join(path,is_train)print('path=',path)print(os.path.join(path, '*', '*'))img_paths=glob.glob(os.path.join(path,'*','*'))for img_path in img_paths:label=0 if img_path.split('\\')[-2]=='blueberry' else 1self.data.append((img_path,label))def __getitem__(self, idx):#每一张图片返回一个img_tensor,one_hotimg_path,label =self.data[idx]img=cv2.imread(img_path)# img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img_tensor=self.tran(img)img_tensor=img_tensor/255img_tensor=torch.flatten(img_tensor)one_hot=torch.zeros(2)one_hot[label]=1return img_tensor,one_hotdef __len__(self):return len(self.data)if __name__ == '__main__':# 测试data=DtataAndLabel()print(data[1][0].shape)print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nnclass Net(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(88*88*3, 800),nn.ReLU(),nn.Linear(800, 500),nn.ReLU(),nn.Linear(500, 800),nn.ReLU(),nn.Linear(800, 200),nn.ReLU(),nn.Linear(200, 2),)self.softmax=nn.Softmax(dim=1)def forward(self,x):x=self.model(x)x=self.softmax(x)return x
if __name__ == '__main__':net=Net()#测试一下x=torch.randn(1,100*100)out=net(x)print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdmfrom net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():def __init__(self):self.writer = SummaryWriter("logs")self.train_data=DtataAndLabel(is_train=True)self.test_data=DtataAndLabel(is_train=False)#使用加载器分批加载self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)#初始化网络#损失函数#优化器net=Net()self.net=netself.loss=nn.MSELoss()self.opt=torch.optim.Adam(net.parameters(),lr=0.001)self.min_loss=100.0self.weight_path='weight/best.pt'def train(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)self.opt.zero_grad()loss.backward()self.opt.step()sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.train_loader)avg_acc = sum_acc / len(self.train_loader)print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)def test(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.test_loader)avg_acc = sum_acc / len(self.test_loader)print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)if avg_loss<self.min_loss:self.min_loss=min(self.min_loss,avg_loss)torch.save(self.net.state_dict(), self.weight_path)def run(self):for epo in range(100):self.train(epo)self.test(epo)if __name__ == '__main__':trainer=TrainAndTest()trainer.run()

精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

http://www.yayakq.cn/news/801323/

相关文章:

  • mvc网站建设的实验报告株洲网站开发公司
  • 网站空间200m连花清瘟为什么不能随便吃
  • 君临天下游戏网站开发者个人主页设计html代码
  • 网站开发成本图书馆网站建设需求方案
  • 设计网站推荐素材网站扫码员在哪个网站可以做
  • 用dw如何做网站首页开发公司成本部职责
  • 无锡网站制作推广公司网站文章来源seo
  • 哈尔滨网站制作哪儿好薇温州网站维护工作
  • 文字做图网站做网站需要懂那些软件
  • 网站灰色做网站有哪些语言
  • 徐州 网站 备案 哪个公司做的好网站优化什么意思
  • 杭州职工业能力建设网站福州网站建设网络公司排名
  • 目前最新的网站后台架构技术综述成全视频在线看
  • 建云购网站网站突然在百度消失了
  • 南京高端网站定制工业和信息化部发短信
  • 企业网站建设与管理试题简述网站开发设计流程图
  • 怎么做网站优化为什么要用wordpress
  • 学校网站建设报价市场推广
  • 百度竞价网站源码外贸网站建设定制
  • 如何建设教师网上授课网站广州市网站制作服务公司
  • 综合网站模板广西住房城乡建设领域
  • 企业营销网站开发建设专家石景山网站seo优化排名
  • 网站是asp还是php网站先做前台还是后台
  • 影楼网站模板网站开发实训目的
  • 手机网站 分享按钮做暖dnf动态ufo网站
  • 帝国网站管理系统 数据库设计一个企业网站报价
  • 枣阳网站开发wordpress rclean
  • windows10前段网站建设平价建网站
  • 区块链网站开发专做女鞋的网站
  • 外贸汽车网站创维爱内购网站