Pytorch 训练框架,日志管理,可视化

时间:2020-03-24
本文章向大家介绍Pytorch 训练框架,日志管理,可视化,主要包括Pytorch 训练框架,日志管理,可视化使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

torchfurnace

torchfurnace 是一个集快速训练模型,日志管理,模型checkpoints管理,tensorboard可视化, I/O 加速,模型大小统计于一身的工具包。
使用这个工具包可以快速构建一个深度学习训练,不需要自己写各种训练逻辑,对于已经定义好的模型也不需要修改,
可以说是拿来即用

使用: pip install torchfurnace

github: https://github.com/tianyu-su/torchfurnace

下面的例子是快速搭建训练,使用 VGG16 训练 CIFIAR10


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR
from torchfurnace import Engine, Parser

# define training process of your model
class VGGNetEngine(Engine):
    @staticmethod
    def _on_forward(training, model, inp, target, optimizer=None) -> dict:
        ret = {'loss': object, 'preds': object}
        output = model(inp)
        loss = F.cross_entropy(output, target)
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        ret['loss'] = loss
        ret['preds'] = output
        return ret

    @staticmethod
    def _get_lr_scheduler(optim) -> list:
        return [MultiStepLR(optim, milestones=[150, 250, 350], gamma=0.1)]

def main():
    # define experiment name
    parser = Parser('TVGG16')
    args = parser.parse_args()
    experiment_name = '_'.join([args.dataset, args.exp_suffix])

    # Data
    ts = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = CIFAR10(root='data', train=True, download=True, transform=ts)
    testset = CIFAR10(root='data', train=False, download=True, transform=ts)

    # define model and optimizer
    net = torchvision.models.vgg16(pretrained=False, num_classes=10)
    net.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
    net.classifier = nn.Linear(512, 10)
    optimizer = torch.optim.Adam(net.parameters())

    # new engine instance
    eng = VGGNetEngine(parser).experiment_name(experiment_name)
    acc1 = eng.learning(net, optimizer, trainset, testset)
    print('Acc1:', acc1)

if __name__ == '__main__':
    import sys
    run_params='--dataset CIFAR10 -lr 0.1 -bs 128 -j 2 --epochs 400 --adjust_lr'
    sys.argv.extend(run_params.split())
    main()

原文地址:https://www.cnblogs.com/TianyuSu/p/12560193.html