first commit
This commit is contained in:
@@ -0,0 +1,211 @@
|
||||
#!/usr/local/bin/python3
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import matplotlib
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datafolder.folder import Train_Dataset
|
||||
from net import get_model
|
||||
|
||||
######################################################################
|
||||
# Settings
|
||||
# --------
|
||||
# Use the GPU only when CUDA is actually available, so the script falls back
# to the CPU instead of crashing at the first .cuda() call.
use_gpu = torch.cuda.is_available()

# Maps the --dataset command-line key to the dataset name that is handed to
# Train_Dataset as ``dataset_name``.
dataset_dict = {
    'market': 'Market-1501',
    'duke': 'DukeMTMC-reID',
}
|
||||
|
||||
|
||||
######################################################################
|
||||
# Argument
|
||||
# --------
|
||||
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--data-path', default='/path/to/dataset', type=str, help='path to the dataset')
# ``choices`` makes argparse reject unsupported values with a usage message;
# unlike ``assert`` this is not stripped when Python runs with -O.
parser.add_argument('--dataset', default='market', type=str, choices=['market', 'duke'],
                    help='dataset: market, duke')
parser.add_argument('--backbone', default='resnet50', type=str,
                    choices=['resnet50', 'resnet34', 'resnet18', 'densenet121'],
                    help='backbone: resnet50, resnet34, resnet18, densenet121')
parser.add_argument('--batch-size', default=32, type=int, help='batch size')
parser.add_argument('--num-epoch', default=60, type=int, help='num of epoch')
parser.add_argument('--num-workers', default=2, type=int, help='num_workers')
parser.add_argument('--use-id', action='store_true', help='use identity loss')
parser.add_argument('--lamba', default=1.0, type=float, help='weight of id loss')
args = parser.parse_args()

dataset_name = dataset_dict[args.dataset]
# The '_id' suffix selects the model variant trained with the extra identity loss.
model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone)
data_dir = args.data_path
# Checkpoints and the training curve are written to this directory.
model_dir = os.path.join('./checkpoints', args.dataset, model_name)

# exist_ok avoids the check-then-create race of a separate isdir() test.
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
|
||||
######################################################################
|
||||
# Function
|
||||
# --------
|
||||
def save_network(network, epoch_label):
    """Save the network's weights as ``net_<epoch_label>.pth`` under ``model_dir``.

    The network is moved to the CPU before its state dict is written, and
    moved back to the GPU afterwards when ``use_gpu`` is set.

    Args:
        network: module whose ``state_dict()`` is saved.
        epoch_label: value interpolated into the checkpoint file name
            (the training loop passes an epoch number or ``'last'``).
    """
    checkpoint_path = os.path.join(model_dir, 'net_%s.pth' % epoch_label)
    cpu_weights = network.cpu().state_dict()
    torch.save(cpu_weights, checkpoint_path)
    if use_gpu:
        # Undo the .cpu() move above so training can continue on the GPU.
        network.cuda()
    print('Save model to {}'.format(checkpoint_path))
|
||||
|
||||
|
||||
######################################################################
|
||||
# Draw Curve
|
||||
#-----------
|
||||
# Per-epoch histories: the training loop appends to y_loss / y_err and
# draw_curve appends the epoch number to x_epoch before plotting them.
x_epoch = []
y_loss = {'train': [], 'val': []}  # loss history
y_err = {'train': [], 'val': []}  # error history (1 - accuracy)

# One figure with two side-by-side panels: loss on the left, error on the right.
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")
|
||||
def draw_curve(current_epoch):
    """Plot the loss and error histories and save them to ``train.jpg``.

    Appends ``current_epoch`` to ``x_epoch``, plots the 'train' and 'val'
    series of ``y_loss`` and ``y_err`` against it, and writes the figure into
    ``model_dir``. The histories must already contain one value per recorded
    epoch when this is called.

    Args:
        current_epoch: epoch number used as the x coordinate of the newest
            data point.
    """
    x_epoch.append(current_epoch)
    ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
    ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
    ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
    ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
    # Add the legends on the first call only: every call plots new lines with
    # the same labels, so repeating legend() would list duplicates. Keying on
    # the first call (rather than on epoch == 0) works whatever number the
    # caller starts counting epochs from.
    if len(x_epoch) == 1:
        ax0.legend()
        ax1.legend()
    fig.savefig(os.path.join(model_dir, 'train.jpg'))
|
||||
|
||||
|
||||
######################################################################
|
||||
# DataLoader
|
||||
# ---------
|
||||
# The 'val' split is read from the dataset's 'query' partition.
image_datasets = {
    'train': Train_Dataset(data_dir, dataset_name=dataset_name, train_val='train'),
    'val': Train_Dataset(data_dir, dataset_name=dataset_name, train_val='query'),
}

# Only the training loader shuffles and drops the last incomplete batch.
# Validation visits every sample exactly once, so its metrics cover the
# whole split instead of silently skipping up to batch_size - 1 samples.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=args.batch_size,
        shuffle=(split == 'train'),
        num_workers=args.num_workers,
        drop_last=(split == 'train'),
    )
    for split in ['train', 'val']
}
dataset_sizes = {split: len(image_datasets[split]) for split in ['train', 'val']}

# Debug aid: fetch a single batch to inspect what the loader yields.
# images, indices, labels, ids, cams, names = next(iter(dataloaders['train']))

# Sizes of the attribute-label and identity outputs, plus the label names,
# all taken from the training split.
num_label = image_datasets['train'].num_label()
num_id = image_datasets['train'].num_id()
labels_list = image_datasets['train'].labels()
|
||||
|
||||
|
||||
######################################################################
|
||||
# Model and Optimizer
|
||||
# ------------------
|
||||
model = get_model(model_name, num_label, args.use_id, num_id=num_id)
if use_gpu:
    model = model.cuda()

# Losses: binary cross-entropy for the attribute labels, cross-entropy for
# the identity prediction.
criterion_bce = nn.BCELoss()
criterion_ce = nn.CrossEntropyLoss()

# Optimizer: parameters under ``model.features`` get a learning rate of 0.01,
# every other parameter of the model gets 0.1.
ignored_params = [id(param) for param in model.features.parameters()]
classifier_params = (
    param for param in model.parameters() if id(param) not in ignored_params
)
param_groups = [
    {'params': model.features.parameters(), 'lr': 0.01},
    {'params': classifier_params, 'lr': 0.1},
]
optimizer = torch.optim.SGD(
    param_groups,
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
)

# Multiply every learning rate by 0.1 once per 40 scheduler steps.
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
|
||||
|
||||
|
||||
######################################################################
|
||||
# Training the model
|
||||
# ------------------
|
||||
def train_model(model, optimizer, scheduler, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Every epoch runs a 'train' pass followed by a 'val' pass over the
    module-level ``dataloaders``. The per-epoch loss and error are appended
    to ``y_loss`` / ``y_err``, the curves are redrawn after each validation
    pass, a checkpoint is saved every 10th epoch, and the final weights are
    saved under the label 'last'.

    Args:
        model: network to train. Called on an image batch it returns the
            attribute predictions, or a ``(attributes, identities)`` pair
            when ``args.use_id`` is set.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        num_epochs: number of epochs to run.
    """
    since = time.time()

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # training mode for 'train', eval mode for 'val'

            running_loss = 0.0
            running_corrects = 0
            num_batches = 0
            num_seen = 0

            # Iterate over data.
            for count, (images, indices, labels, _ids, _cams, _names) in enumerate(dataloaders[phase]):
                # BCELoss needs floating-point targets.
                labels = labels.float()
                if use_gpu:
                    images = images.cuda()
                    labels = labels.cuda()
                    indices = indices.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                # Track gradients only while training; validation does not
                # need the autograd graph, which saves memory and time.
                with torch.set_grad_enabled(is_train):
                    if not args.use_id:
                        pred_label = model(images)
                        total_loss = criterion_bce(pred_label, labels)
                    else:
                        pred_label, pred_id = model(images)
                        label_loss = criterion_bce(pred_label, labels)
                        id_loss = criterion_ce(pred_id, indices)
                        total_loss = label_loss + args.lamba * id_loss

                    # backward + optimize only if in training phase
                    if is_train:
                        total_loss.backward()
                        optimizer.step()

                # An attribute counts as predicted when its score exceeds 0.5.
                preds = pred_label.detach() > 0.5

                # statistics
                running_loss += total_loss.item()
                # Correct attribute predictions divided by the number of
                # attributes gives "fully correct sample" equivalents.
                running_corrects += torch.sum(preds.byte() == labels.byte()).item() / num_label
                num_batches += 1
                num_seen += images.size(0)

                if count % 100 == 0:
                    if not args.use_id:
                        print('step: ({}/{}) | label loss: {:.4f}'.format(
                            count*args.batch_size, dataset_sizes[phase], total_loss.item()))
                    else:
                        print('step: ({}/{}) | label loss: {:.4f} | id loss: {:.4f}'.format(
                            count*args.batch_size, dataset_sizes[phase], label_loss.item(), id_loss.item()))

            if is_train:
                # Step the schedule after the epoch's optimizer updates;
                # stepping before the first optimizer.step() skips the first
                # value of the learning-rate schedule.
                scheduler.step()

            # Average over what was actually processed: a loader created with
            # drop_last=True yields fewer samples than the dataset holds. The
            # max() guards against a loader that yields no batch at all.
            epoch_loss = running_loss / max(num_batches, 1)
            epoch_acc = running_corrects / max(num_seen, 1)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0-epoch_acc)

            if phase == 'val':
                if epoch % 10 == 0:
                    save_network(model, epoch)
                draw_curve(epoch)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # The model already holds the weights of the last epoch, so they are
    # saved directly; reloading a reference to its own state dict would be a
    # no-op (and would fail when num_epochs is 0).
    save_network(model, 'last')
|
||||
|
||||
|
||||
######################################################################
|
||||
# Main
|
||||
# -----
|
||||
# Start training only when the file is executed as a script. DataLoader worker
# processes started with the "spawn" method re-import the main module, and
# without this guard each of them would launch training again.
if __name__ == '__main__':
    train_model(model, optimizer, exp_lr_scheduler, num_epochs=args.num_epoch)
|
||||
Reference in New Issue
Block a user