first commit
This commit is contained in:
@@ -0,0 +1,200 @@
|
||||
import os
|
||||
import argparse
|
||||
import scipy.io
|
||||
import torch
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||
# from sklearn.exceptions import UndefinedMetricWarning
|
||||
from datafolder.folder import Test_Dataset
|
||||
from net import get_model
|
||||
|
||||
|
||||
######################################################################
|
||||
# Settings
|
||||
# ---------
|
||||
use_gpu = True
|
||||
dataset_dict = {
|
||||
'market' : 'Market-1501',
|
||||
'duke' : 'DukeMTMC-reID',
|
||||
}
|
||||
num_cls_dict = { 'market':30, 'duke':23 }
|
||||
num_ids_dict = { 'market':751, 'duke':702 }
|
||||
|
||||
|
||||
######################################################################
|
||||
# Argument
|
||||
# ---------
|
||||
parser = argparse.ArgumentParser(description='Testing')
|
||||
parser.add_argument('--data-path', default='/home/xxx/reid/', type=str, help='path to the dataset')
|
||||
parser.add_argument('--dataset', default='market', type=str, help='dataset')
|
||||
parser.add_argument('--backbone', default='resnet50', type=str, help='model')
|
||||
parser.add_argument('--batch-size', default=50, type=int, help='batch size')
|
||||
parser.add_argument('--num-epoch', default=60, type=int, help='num of epoch')
|
||||
parser.add_argument('--num-workers', default=2, type=int, help='num_workers')
|
||||
parser.add_argument('--which-epoch',default='last', type=str, help='0,1,2,3...or last')
|
||||
parser.add_argument('--print-table',action='store_true', help='print results with table format')
|
||||
parser.add_argument('--use-id', action='store_true', help='use identity loss')
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.dataset in ['market', 'duke']
|
||||
assert args.backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121']
|
||||
|
||||
dataset_name = dataset_dict[args.dataset]
|
||||
model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone)
|
||||
data_dir = args.data_path
|
||||
model_dir = os.path.join('./checkpoints', args.dataset, model_name)
|
||||
result_dir = os.path.join('./result', args.dataset, model_name)
|
||||
|
||||
if not os.path.isdir(result_dir):
|
||||
os.makedirs(result_dir)
|
||||
if not os.path.isdir(model_dir):
|
||||
os.makedirs(model_dir)
|
||||
|
||||
|
||||
######################################################################
|
||||
# Function
|
||||
# ---------
|
||||
def load_network(network):
|
||||
save_path = os.path.join(model_dir,'net_%s.pth'%args.which_epoch)
|
||||
network.load_state_dict(torch.load(save_path))
|
||||
print('Resume model from {}'.format(save_path))
|
||||
return network
|
||||
|
||||
|
||||
def get_dataloader():
|
||||
image_datasets = {}
|
||||
image_datasets['gallery'] = Test_Dataset(data_dir, dataset_name=dataset_name, query_gallery='gallery')
|
||||
image_datasets['query'] = Test_Dataset(data_dir, dataset_name=dataset_name, query_gallery='query')
|
||||
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size,
|
||||
shuffle=True, num_workers=args.num_workers)
|
||||
for x in ['gallery', 'query']}
|
||||
return dataloaders
|
||||
|
||||
|
||||
def check_metric_vaild(y_pred, y_true):
|
||||
if y_true.min() == y_true.max() == 0: # precision
|
||||
return False
|
||||
if y_pred.min() == y_pred.max() == 0: # recall
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
######################################################################
|
||||
# Load Data
|
||||
# ---------
|
||||
# Note that we only perform evaluation on gallery set.
|
||||
test_loader = get_dataloader()['gallery']
|
||||
|
||||
attribute_list = test_loader.dataset.labels()
|
||||
num_label = len(attribute_list)
|
||||
num_sample = len(test_loader.dataset)
|
||||
num_id = num_ids_dict[args.dataset]
|
||||
|
||||
|
||||
######################################################################
|
||||
# Model
|
||||
# ---------
|
||||
model = get_model(model_name, num_label, use_id=args.use_id, num_id=num_id)
|
||||
model = load_network(model)
|
||||
if use_gpu:
|
||||
model = model.cuda()
|
||||
model.train(False) # Set model to evaluate mode
|
||||
|
||||
|
||||
######################################################################
|
||||
# Testing
|
||||
# ---------
|
||||
preds_tensor = np.empty(shape=[0, num_label], dtype=np.byte) # shape = (num_sample, num_label)
|
||||
labels_tensor = np.empty(shape=[0, num_label], dtype=np.byte) # shape = (num_sample, num_label)
|
||||
|
||||
# Iterate over data.
|
||||
with torch.no_grad():
|
||||
for count, (images, labels, ids, file_name) in enumerate(test_loader):
|
||||
# move input to GPU
|
||||
if use_gpu:
|
||||
images = images.cuda()
|
||||
# forward
|
||||
if not args.use_id:
|
||||
pred_label = model(images)
|
||||
else:
|
||||
pred_label, _ = model(images)
|
||||
|
||||
preds = torch.gt(pred_label, torch.ones_like(pred_label)/2)
|
||||
# transform to numpy format
|
||||
labels = labels.cpu().numpy()
|
||||
preds = preds.cpu().numpy()
|
||||
# append
|
||||
preds_tensor = np.append(preds_tensor, preds, axis=0)
|
||||
labels_tensor = np.append(labels_tensor, labels, axis=0)
|
||||
# print info
|
||||
if count*args.batch_size % 5000 == 0:
|
||||
print('Step: {}/{}'.format(count*args.batch_size, num_sample))
|
||||
|
||||
# Evaluation.
|
||||
accuracy_list = []
|
||||
precision_list = []
|
||||
recall_list = []
|
||||
f1_score_list = []
|
||||
average_precision = 0.0
|
||||
average_recall = 0.0
|
||||
average_f1score = 0.0
|
||||
valid_count = 0
|
||||
for i, name in enumerate(attribute_list):
|
||||
y_true, y_pred = labels_tensor[:, i], preds_tensor[:, i]
|
||||
accuracy_list.append(accuracy_score(y_true, y_pred))
|
||||
if check_metric_vaild(y_pred, y_true): # exclude ill-defined cases
|
||||
precision_list.append(precision_score(y_true, y_pred, average='binary'))
|
||||
recall_list.append(recall_score(y_true, y_pred, average='binary'))
|
||||
f1_score_list.append(f1_score(y_true, y_pred, average='binary'))
|
||||
average_precision += precision_list[-1]
|
||||
average_recall += recall_list[-1]
|
||||
average_f1score += f1_score_list[-1]
|
||||
valid_count += 1
|
||||
else:
|
||||
precision_list.append(-1)
|
||||
recall_list.append(-1)
|
||||
f1_score_list.append(-1)
|
||||
|
||||
average_acc = np.mean(accuracy_list)
|
||||
average_precision = average_precision / valid_count
|
||||
average_recall = average_recall / valid_count
|
||||
average_f1score = average_f1score / valid_count
|
||||
|
||||
|
||||
######################################################################
|
||||
# Print
|
||||
# ---------
|
||||
print("\n"
|
||||
"The Precision, Recall and F-score are ignored for some ill-defined cases."
|
||||
"\n")
|
||||
|
||||
if args.print_table:
|
||||
from prettytable import PrettyTable
|
||||
table = PrettyTable(['attribute', 'accuracy', 'precision', 'recall', 'f1 score'])
|
||||
for i, name in enumerate(attribute_list):
|
||||
table.add_row([name,
|
||||
'%.3f' % accuracy_list[i],
|
||||
'%.3f' % precision_list[i] if precision_list[i] >= 0.0 else '-',
|
||||
'%.3f' % recall_list[i] if recall_list[i] >= 0.0 else '-',
|
||||
'%.3f' % f1_score_list[i] if f1_score_list[i] >= 0.0 else '-',
|
||||
])
|
||||
print(table)
|
||||
|
||||
|
||||
print('Average accuracy: {:.4f}'.format(average_acc))
|
||||
# print('Average precision: {:.4f}'.format(average_precision))
|
||||
# print('Average recall: {:.4f}'.format(average_recall))
|
||||
print('Average f1 score: {:.4f}'.format(average_f1score))
|
||||
|
||||
# Save results.
|
||||
result = {
|
||||
'average_acc' : average_acc,
|
||||
'average_f1score' : average_f1score,
|
||||
'accuracy_list' : accuracy_list,
|
||||
'precision_list' : precision_list,
|
||||
'recall_list' : recall_list,
|
||||
'f1_score_list' : f1_score_list,
|
||||
}
|
||||
scipy.io.savemat(os.path.join(result_dir, 'acc.mat'), result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user