import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms

print("Torchvision Version:", torchvision.__version__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Root folder; the code below reads its "train" and "val" sub-folders with ImageFolder.
data_dir = "./hymenoptera_data"
batch_size = 32
input_size = 224
model_name = "resnet"
num_classes = 2
num_epochs = 15
# True: freeze the pretrained backbone and train only the new classifier head.
feature_extract = True

# ImageNet per-channel statistics that torchvision's pretrained weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = {
    # Random augmentation is applied to the training split only.
    "train": transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    # Validation must be deterministic. Its accuracy is compared across epochs
    # to pick the best checkpoint, so it uses resize + center crop instead of
    # random crops and flips.
    "val": transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}

image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ["train", "val"]
}
# Only the training split is shuffled; the evaluation order does not affect metrics.
dataloader_dict = {
    x: torch.utils.data.DataLoader(
        image_datasets[x], batch_size=batch_size, shuffle=(x == "train")
    )
    for x in ["train", "val"]
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------------------------------------------------------
# Model construction: load ResNet and replace its fully connected layer
# ---------------------------------------------------------------------------
def set_parameter_requires_grad(model, feature_extract):
    """Freeze every parameter currently in *model* when *feature_extract* is true.

    Layers attached to the model afterwards (such as the replacement head in
    ``initialize_model``) are new modules and stay trainable.
    """
    if feature_extract:
        for param in model.parameters():
            param.requires_grad = False


def _load_resnet18(use_pretrained):
    """Build a ResNet-18, preferring the ``weights=`` API when torchvision has it.

    ``ResNet18_Weights.DEFAULT`` is the ImageNet weight set that the legacy
    ``pretrained=True`` flag selected. Older torchvision releases lack the
    enum and fall back to the boolean flag.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is None:
        return models.resnet18(pretrained=use_pretrained)
    return models.resnet18(weights=weights_enum.DEFAULT if use_pretrained else None)


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Create a classification model with a fresh ``num_classes``-way head.

    Args:
        model_name: architecture name; only ``"resnet"`` (ResNet-18) is supported.
        num_classes: number of output classes for the new final layer.
        feature_extract: if true, freeze the backbone so that only the new
            head has ``requires_grad=True``.
        use_pretrained: load ImageNet-pretrained weights.

    Returns:
        ``(model, input_size)``, where ``input_size`` is the square image size
        the model expects. For an unsupported ``model_name`` it prints a message
        and returns ``(None, None)``.
    """
    if model_name != "resnet":
        print("model not implemented")
        return None, None

    model_ft = _load_resnet18(use_pretrained)
    set_parameter_requires_grad(model_ft, feature_extract)
    # Replace the ImageNet classifier. It is created after freezing, so it stays trainable.
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    model_input_size = 224
    return model_ft, model_input_size
model_ft, input_size = initialize_model(
    model_name, num_classes, feature_extract, use_pretrained=True
)
if model_ft is None:
    # Without this check the script would fail later with an opaque AttributeError.
    raise SystemExit("Unsupported model_name: {!r}".format(model_name))
print("-" * 200)


def train_model(model, dataloaders, loss_fn, optimizer, num_epochs=5, device=None):
    """Train *model* and restore the weights with the best validation accuracy.

    Each epoch runs a "train" phase (gradients on, parameters updated), then a
    "val" phase (gradients off, model in eval mode).

    Args:
        model: network to train; it should already be on ``device``.
        dataloaders: mapping with "train" and "val" DataLoaders yielding
            ``(inputs, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        optimizer: optimizer over the trainable parameters of *model*.
        num_epochs: number of train+val passes.
        device: device that batches are moved to. Defaults to the device of
            the model's first parameter.

    Returns:
        ``(model, val_acc_history)``. ``model`` holds the weights from the epoch
        with the highest validation accuracy, and ``val_acc_history`` lists one
        validation accuracy per epoch.
    """
    if device is None:
        device = next(model.parameters()).device

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    val_acc_history = []

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        for phase in ["train", "val"]:
            is_train = phase == "train"
            model.train(is_train)  # same as model.train() / model.eval()
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                # Build the autograd graph only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)  # (batch, num_classes)
                    loss = loss_fn(outputs, labels)
                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                preds = outputs.argmax(dim=1)
                # Assumes loss_fn averages over the batch (the CrossEntropyLoss
                # default). Weighting by batch size makes the epoch loss a
                # per-sample mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds.view(-1) == labels.view(-1)).sum().item()

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size
            print("Phase {} loss: {}, acc: {}".format(phase, epoch_loss, epoch_acc))

            if not is_train:
                val_acc_history.append(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

    elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s, best val acc: {:.4f}".format(
        elapsed // 60, elapsed % 60, best_acc))
    model.load_state_dict(best_model_wts)
    return model, val_acc_history


model_ft = model_ft.to(device)
# When feature_extract is True, initialize_model froze the backbone, so only
# the new head's parameters pass this filter and get optimized.
params_to_update = [p for p in model_ft.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params_to_update, lr=0.001, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()
_, ohist = train_model(
    model_ft, dataloader_dict, loss_fn, optimizer, num_epochs=num_epochs, device=device
)

# Plot validation accuracy per epoch (epochs numbered from 1).
plt.title("Validation Accuracy vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
plt.plot(range(1, num_epochs + 1), ohist, label="Pretrained")
plt.ylim((0, 1.0))
plt.xticks(np.arange(1, num_epochs + 1, 1.0))
plt.legend()
plt.show()
# Original article: https://www.cnblogs.com/-xuewuzhijing-/p/12987581.html