Files
model-yolo-person-classify/train.py
T
2026-06-29 11:21:57 +09:00

212 lines
7.7 KiB
Python

#!/usr/local/bin/python3
import os
import time
import argparse
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from datafolder.folder import Train_Dataset
from net import get_model
######################################################################
# Settings
# --------
# Use CUDA only when a device is actually available, so the script can also
# run on CPU-only machines instead of failing at the first ``.cuda()`` call.
use_gpu = torch.cuda.is_available()
# Maps the ``--dataset`` command-line key to the dataset's canonical name.
dataset_dict = {
    'market': 'Market-1501',
    'duke': 'DukeMTMC-reID',
}
######################################################################
# Argument
# --------
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--data-path', default='/path/to/dataset', type=str, help='path to the dataset')
# ``choices`` makes argparse reject unknown values with a usage message.
# Unlike ``assert``, this validation is not stripped when Python runs with -O.
parser.add_argument('--dataset', default='market', type=str, choices=list(dataset_dict),
                    help='dataset: market, duke')
parser.add_argument('--backbone', default='resnet50', type=str,
                    choices=['resnet50', 'resnet34', 'resnet18', 'densenet121'],
                    help='backbone: resnet50, resnet34, resnet18, densenet121')
parser.add_argument('--batch-size', default=32, type=int, help='batch size')
parser.add_argument('--num-epoch', default=60, type=int, help='num of epoch')
parser.add_argument('--num-workers', default=2, type=int, help='num_workers')
parser.add_argument('--use-id', action='store_true', help='use identity loss')
parser.add_argument('--lamba', default=1.0, type=float, help='weight of id loss')
args = parser.parse_args()

dataset_name = dataset_dict[args.dataset]
# The '_id' suffix marks models trained with the additional identity loss.
model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone)
data_dir = args.data_path
# Checkpoints and the training curve are written below this directory.
model_dir = os.path.join('./checkpoints', args.dataset, model_name)
# exist_ok avoids the check-then-create race of isdir() followed by makedirs().
os.makedirs(model_dir, exist_ok=True)
######################################################################
# Function
# --------
def save_network(network, epoch_label):
    """Write the weights of ``network`` to ``net_<epoch_label>.pth`` in ``model_dir``.

    The network is moved to the CPU before its state dict is serialised and
    is moved back to the GPU afterwards when ``use_gpu`` is set.

    Args:
        network: the module whose ``state_dict()`` is saved.
        epoch_label: value interpolated into the checkpoint file name.
    """
    save_path = os.path.join(model_dir, 'net_%s.pth' % epoch_label)
    cpu_weights = network.cpu().state_dict()
    torch.save(cpu_weights, save_path)
    if use_gpu:
        # Training continues after a checkpoint, so restore the device.
        network.cuda()
    print('Save model to {}'.format(save_path))
######################################################################
# Draw Curve
#-----------
# Epoch numbers already plotted (shared x axis of both curves).
x_epoch = []
# Per-phase loss history, one entry per epoch.
y_loss = {'train': [], 'val': []}
# Per-phase error history, one entry per epoch.
y_err = {'train': [], 'val': []}
# One figure with two side-by-side axes: loss on the left, error on the right.
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")
def draw_curve(current_epoch):
    """Record ``current_epoch`` and redraw the loss / error curves.

    Plots the full ``y_loss`` and ``y_err`` histories for the 'train' and
    'val' phases against ``x_epoch`` and saves the figure as ``train.jpg``
    inside ``model_dir``.

    Args:
        current_epoch: epoch number to append to the x axis.
    """
    x_epoch.append(current_epoch)
    ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
    ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
    ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
    ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
    # Add the legend on the first call only. Keying on the number of recorded
    # epochs (rather than ``current_epoch == 0``) works no matter whether the
    # caller numbers epochs from 0 or from 1.
    if len(x_epoch) == 1:
        ax0.legend()
        ax1.legend()
    fig.savefig(os.path.join(model_dir, 'train.jpg'))
######################################################################
# DataLoader
# ---------
# The 'query' split of the dataset is used as the validation set.
image_datasets = {
    'train': Train_Dataset(data_dir, dataset_name=dataset_name, train_val='train'),
    'val': Train_Dataset(data_dir, dataset_name=dataset_name, train_val='query'),
}
# Only the training loader shuffles and drops the final partial batch.
# Validation must see every sample exactly once; otherwise the reported
# metrics cover a varying subset, and a validation set smaller than one
# batch would yield no batches at all.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=args.batch_size,
        shuffle=(split == 'train'),
        num_workers=args.num_workers,
        drop_last=(split == 'train'),
    )
    for split in ['train', 'val']
}
dataset_sizes = {split: len(image_datasets[split]) for split in ['train', 'val']}
# images, indices, labels, ids, cams, names = next(iter(dataloaders['train']))
num_label = image_datasets['train'].num_label()
num_id = image_datasets['train'].num_id()
labels_list = image_datasets['train'].labels()
######################################################################
# Model and Optimizer
# ------------------
model = get_model(model_name, num_label, args.use_id, num_id=num_id)
if use_gpu:
    model = model.cuda()

# Losses: binary cross-entropy for the attribute labels, cross-entropy for
# the identity prediction.
criterion_bce = nn.BCELoss()
criterion_ce = nn.CrossEntropyLoss()

# Optimizer: parameters under ``model.features`` get a 10x smaller learning
# rate than all remaining parameters.
ignored_params = [id(param) for param in model.features.parameters()]
classifier_params = (param for param in model.parameters() if id(param) not in ignored_params)
param_groups = [
    {'params': model.features.parameters(), 'lr': 0.01},
    {'params': classifier_params, 'lr': 0.1},
]
optimizer = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=5e-4, nesterov=True)
# Multiply every learning rate by 0.1 after each 40 scheduler steps.
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
######################################################################
# Training the model
# ------------------
def train_model(model, optimizer, scheduler, num_epochs):
    """Run the train / validation loop and checkpoint the model.

    For every epoch (numbered from 1) the model is trained on
    ``dataloaders['train']`` and then evaluated on ``dataloaders['val']``.
    Per-phase loss and error are appended to ``y_loss`` / ``y_err``, the
    curves are redrawn after each validation phase, a checkpoint is written
    every 10th epoch, and a final checkpoint labelled ``'last'`` is written
    when training finishes.

    Relies on module-level state: ``dataloaders``, ``dataset_sizes``,
    ``args``, ``use_gpu``, ``num_label``, ``criterion_bce``,
    ``criterion_ce``, ``y_loss``, ``y_err``, ``save_network`` and
    ``draw_curve``.

    Args:
        model: network to train; ``model(images)`` returns the label
            predictions, or a ``(label, id)`` prediction pair when
            ``args.use_id`` is set.
        optimizer: optimizer updating the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.
    """
    since = time.time()

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # training mode vs. evaluation mode

            running_loss = 0.0
            running_corrects = 0.0
            num_batches = 0
            num_samples = 0

            for count, (images, indices, labels, ids, cams, names) in enumerate(
                    dataloaders[phase]):
                labels = labels.float()
                if use_gpu:
                    images = images.cuda()
                    labels = labels.cuda()
                    indices = indices.cuda()

                optimizer.zero_grad()

                # Autograd is only needed while training; disabling it during
                # validation avoids building a graph that is never used.
                with torch.set_grad_enabled(is_train):
                    if not args.use_id:
                        pred_label = model(images)
                        total_loss = criterion_bce(pred_label, labels)
                    else:
                        pred_label, pred_id = model(images)
                        label_loss = criterion_bce(pred_label, labels)
                        id_loss = criterion_ce(pred_id, indices)
                        total_loss = label_loss + args.lamba * id_loss

                    if is_train:
                        total_loss.backward()
                        optimizer.step()

                # A label counts as predicted when its score exceeds 0.5.
                preds = pred_label > 0.5

                running_loss += total_loss.item()
                # Number of correct labels, normalised by labels per sample.
                running_corrects += torch.sum(preds == labels.byte()).item() / num_label
                num_batches += 1
                num_samples += labels.size(0)

                if count % 100 == 0:
                    if not args.use_id:
                        print('step: ({}/{}) | label loss: {:.4f}'.format(
                            count * args.batch_size, dataset_sizes[phase], total_loss.item()))
                    else:
                        print('step: ({}/{}) | label loss: {:.4f} | id loss: {:.4f}'.format(
                            count * args.batch_size, dataset_sizes[phase],
                            label_loss.item(), id_loss.item()))

            if is_train:
                # Step the schedule after this epoch's optimizer updates;
                # stepping before the first update shifts the decay by one
                # epoch on current PyTorch versions.
                scheduler.step()

            # Normalise by what was actually processed: a loader may drop a
            # partial batch, so the dataset size can overstate the sample
            # count. ``max`` guards against an empty loader.
            epoch_loss = running_loss / max(num_batches, 1)
            epoch_acc = running_corrects / max(num_samples, 1)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)

            if phase == 'val':
                if epoch % 10 == 0:
                    save_network(model, epoch)
                draw_curve(epoch)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # The model already holds the weights of the last epoch, so it is saved
    # directly. (Re-loading its own ``state_dict()`` would be a no-op and
    # would fail when ``num_epochs`` is 0.)
    save_network(model, 'last')
######################################################################
# Main
# -----
# Script entry point: train for ``--num-epoch`` epochs using the model,
# optimizer and learning-rate scheduler built above.
train_model(model, optimizer, exp_lr_scheduler, num_epochs=args.num_epoch)