# Run a trained attribute model on one image and print the decoded attributes.
import os
|
|
import json
|
|
import torch
|
|
import argparse
|
|
from PIL import Image
|
|
from torchvision import transforms as T
|
|
from net import get_model
|
|
|
|
|
|
######################################################################
|
|
# Settings
|
|
# ---------
|
|
# Short dataset key (the value accepted by --dataset) -> full dataset name.
dataset_dict = dict(
    market='Market-1501',
    duke='DukeMTMC-reID',
)

# Per-dataset sizes, looked up with the same short key: the number of
# attribute labels and the number of identities.
num_cls_dict = dict(market=30, duke=23)
num_ids_dict = dict(market=751, duke=702)
|
|
|
|
# Preprocessing pipeline applied to the input image: resize to the fixed
# size (288, 144), convert to a tensor, then normalize the three channels
# with the fixed per-channel mean / std values below.
transforms = T.Compose(
    [
        T.Resize(size=(288, 144)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
|
######################################################################
|
|
# Argument
|
|
# ---------
|
|
# Command-line interface: one positional image path plus the options that
# select which trained checkpoint is loaded.
parser = argparse.ArgumentParser()
parser.add_argument('image_path', help='Path to test image')
# `choices=` makes argparse reject unsupported values with a usage message.
# This replaces the former `assert` checks, which are stripped when Python
# runs with -O and only produced a bare AssertionError.
parser.add_argument('--dataset', default='market', type=str,
                    choices=['market', 'duke'], help='dataset')
parser.add_argument('--backbone', default='resnet50', type=str,
                    choices=['resnet50', 'resnet34', 'resnet18', 'densenet121'],
                    help='model')
parser.add_argument('--use-id', action='store_true', help='use identity loss')
args = parser.parse_args()

# Model name is "<backbone>_nfc", with an "_id" suffix when --use-id is set.
model_name = '{}_nfc{}'.format(args.backbone, '_id' if args.use_id else '')

# Label count and identity count for the selected dataset.
num_label = num_cls_dict[args.dataset]
num_id = num_ids_dict[args.dataset]
|
|
|
|
|
|
######################################################################
|
|
# Model and Data
|
|
# ---------
|
|
def load_network(network):
    """Load the saved weights for the selected dataset/model into `network`.

    The checkpoint path is built from the module-level `args.dataset` and
    `model_name`: ``./checkpoints/<dataset>/<model_name>/net_last.pth``.

    Args:
        network: object whose ``load_state_dict`` receives the loaded state.

    Returns:
        The same `network` object, after its state has been loaded.
    """
    save_path = os.path.join('./checkpoints', args.dataset, model_name, 'net_last.pth')
    # map_location='cpu' remaps tensors that were saved from a GPU, so the
    # checkpoint also loads on a machine without CUDA (this script never
    # moves the model off the CPU).
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    state_dict = torch.load(save_path, map_location='cpu')
    network.load_state_dict(state_dict)
    print('Resume model from {}'.format(save_path))
    return network
|
|
|
|
def load_image(path):
    """Read the image at `path` and return it as a model-ready batch of one.

    Args:
        path: location of the image file, passed to ``Image.open``.

    Returns:
        The result of the module-level `transforms` pipeline with an extra
        leading dimension of size 1 added via ``unsqueeze(dim=0)``.
    """
    # Force three channels: the normalization step in `transforms` is given
    # three mean/std values, so grayscale, palette or RGBA files would
    # otherwise fail. The `with` block closes the underlying file once the
    # pixel data has been converted.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    return transforms(rgb).unsqueeze(dim=0)
|
|
|
|
|
|
# Build the network for the chosen configuration, restore its saved weights
# and switch it to evaluation mode.
model = load_network(
    get_model(model_name, num_label, use_id=args.use_id, num_id=num_id)
)
model.eval()

# Preprocessed input built from the image given on the command line.
src = load_image(args.image_path)
|
|
|
|
######################################################################
|
|
# Inference
|
|
# ---------
|
|
class predict_decoder(object):
    """Print human-readable attribute values for a thresholded prediction.

    The label order comes from ``./doc/label.json`` and the per-label
    ``(name, values)`` entries from ``./doc/attribute.json``; both files are
    indexed by the dataset key.
    """

    def __init__(self, dataset):
        """Load the label list and attribute table stored under `dataset`.

        Args:
            dataset: key used to select the entry in both JSON files.

        Raises:
            OSError: if either JSON file cannot be opened.
            KeyError: if `dataset` is missing from either file.
        """
        # Read as UTF-8 explicitly so decoding does not depend on the
        # platform's default locale encoding.
        with open('./doc/label.json', 'r', encoding='utf-8') as f:
            self.label_list = json.load(f)[dataset]
        with open('./doc/attribute.json', 'r', encoding='utf-8') as f:
            self.attribute_dict = json.load(f)[dataset]
        self.dataset = dataset
        self.num_label = len(self.label_list)

    def decode(self, pred):
        """Print ``name: value`` for each label whose selected value is non-empty.

        Args:
            pred: tensor of per-label flags; its leading dimension is removed
                with ``squeeze(dim=0)`` and entry ``i`` is used as an index
                into the value list of label ``i``.
                # assumes shape (1, num_label) with 0/1 entries - TODO confirm
        """
        flags = pred.squeeze(dim=0)
        for idx, label in enumerate(self.label_list):
            name, choices = self.attribute_dict[label]
            # int() turns the tensor element into a plain Python index
            # instead of relying on implicit tensor-to-index conversion.
            choice = choices[int(flags[idx])]
            # Empty values mean "nothing to report" for this label.
            if choice:
                print('{}: {}'.format(name, choice))
|
|
|
|
|
|
# Inference only: disable autograd so no graph is recorded for the forward
# pass. The model is invoked through the module call (not `.forward`) so
# that any registered hooks run as well.
with torch.no_grad():
    if args.use_id:
        # With --use-id the model returns a pair; only the first element
        # is used here.
        out, _ = model(src)
    else:
        out = model(src)

pred = out > 0.5  # threshold=0.5

Dec = predict_decoder(args.dataset)
Dec.decode(pred)
|
|
|