commit e75f32c80db16bf27245edbef9c3ca7e2bf6c446 Author: Daekeun Date: Mon Jun 29 11:21:57 2026 +0900 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3b35c41 --- /dev/null +++ b/.gitignore @@ -0,0 +1,72 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +*.egg-info/ +.eggs/ +dist/ +build/ +*.egg +.venv/ +venv/ +env/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db +desktop.ini + +# Model weights & exports +# weights/ +# *.pt +# *.pth +# *.onnx +# *.engine +# *.weights + +# Images (inference input/output) +*_marked.jpg +*_marked.jpeg +*_marked.png +image.jpg +*.jpg +*.jpeg +*.png +*.bmp +*.webp +!image1.jpg +!image2.jpg +!image3.jpg + +# Ultralytics outputs +runs/ +output/ +outputs/ +results/ + +# ClearML +clearml.conf +*.log + +# Jupyter +.ipynb_checkpoints/ + +# Environment & secrets +.env +.env.* +*.pem + + +# etc +datafolder/ +doc/ +net/ +test_sample/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..f1bb08e --- /dev/null +++ b/README.md @@ -0,0 +1,181 @@ +# Person-Attribute-Recognition-MarketDuke +A simple baseline implemented in PyTorch for **pedestrian attribute recognition** task, evaluating on Market-1501-attribute and DukeMTMC-reID-attribute dataset. + +## Dataset +You can get [Market-1501-attribute](https://github.com/vana77/Market-1501_Attribute) and [DukeMTMC-reID-attribute](https://github.com/vana77/DukeMTMC-attribute) annotations from [here](https://github.com/vana77). Also you need to download Market-1501 and DukeMTMC-reID dataset. + +Then, create a folder named 'attribute' under your dataset path, and put corresponding annotations into the folder. + +For example,
+``` +├── dataset +│ ├── DukeMTMC-reID +│ ├── bounding_box_test +│ ├── bounding_box_train +│ ├── query +│ ├── attribute +│ ├── duke_attribute.mat +``` + +## Model +Trained models are provided. You may download them from [Google Drive](https://drive.google.com/drive/folders/1JTdjuEbxSLypnfUzVuuxLj1uSKAacfd0?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1bByCxZp9bSs8YYZPbuK21A) (提取码:jpks). + +After downloading, move the `checkpoints` folder to your project's root directory. + +## Dependencies +* Python 3.5 +* PyTorch >= 0.4.1 +* torchvision >= 0.2.1 +* matplotlib, sklearn, prettytable (optional) + +## Usage +``` +python3 train.py --data-path ~/dataset --dataset [market | duke] --backbone resnet50 [--use-id] + +python3 test.py --data-path ~/dataset --dataset [market | duke] --backbone resnet50 [--print-table] + +python3 inference.py test_sample/test_market.jpg [--dataset market] [--backbone resnet50] +``` + +## Result + +We use **binary classification** settings (each attribute is treated as an independent binary classification problem), and the classification threshold is **0.5**. 
+ +***Note that the precision, recall and f1 score are denoted as '-' for some ill-defined cases.*** + +### Market-1501 gallery +``` ++------------+----------+-----------+--------+----------+ +| attribute | accuracy | precision | recall | f1 score | ++------------+----------+-----------+--------+----------+ +| young | 0.998 | 0.533 | 0.267 | 0.356 | +| teenager | 0.892 | 0.927 | 0.951 | 0.939 | +| adult | 0.895 | 0.582 | 0.450 | 0.508 | +| old | 0.992 | 0.037 | 0.012 | 0.019 | +| backpack | 0.883 | 0.828 | 0.672 | 0.742 | +| bag | 0.790 | 0.608 | 0.378 | 0.467 | +| handbag | 0.893 | 0.254 | 0.065 | 0.104 | +| clothes | 0.946 | 0.956 | 0.984 | 0.970 | +| down | 0.945 | 0.968 | 0.949 | 0.959 | +| up | 0.936 | 0.938 | 0.998 | 0.967 | +| hair | 0.877 | 0.871 | 0.773 | 0.819 | +| hat | 0.982 | 0.812 | 0.505 | 0.623 | +| gender | 0.919 | 0.947 | 0.864 | 0.903 | +| upblack | 0.954 | 0.859 | 0.790 | 0.823 | +| upwhite | 0.926 | 0.846 | 0.882 | 0.863 | +| upred | 0.974 | 0.904 | 0.840 | 0.871 | +| uppurple | 0.985 | 0.703 | 0.815 | 0.755 | +| upyellow | 0.976 | 0.895 | 0.836 | 0.865 | +| upgray | 0.909 | 0.852 | 0.391 | 0.537 | +| upblue | 0.946 | 0.868 | 0.420 | 0.566 | +| upgreen | 0.966 | 0.790 | 0.713 | 0.750 | +| downblack | 0.879 | 0.815 | 0.889 | 0.850 | +| downwhite | 0.956 | 0.608 | 0.550 | 0.578 | +| downpink | 0.989 | 0.795 | 0.782 | 0.788 | +| downpurple | 1.000 | - | - | - | +| downyellow | 0.999 | 0.000 | 0.000 | 0.000 | +| downgray | 0.878 | 0.756 | 0.443 | 0.559 | +| downblue | 0.861 | 0.762 | 0.446 | 0.563 | +| downgreen | 0.978 | 0.766 | 0.295 | 0.426 | +| downbrown | 0.958 | 0.754 | 0.590 | 0.662 | ++------------+----------+-----------+--------+----------+ +Average accuracy: 0.9361 +Average f1 score: 0.6492 +``` + +### DukeMTMC-ReID gallery +``` ++-----------+----------+-----------+--------+----------+ +| attribute | accuracy | precision | recall | f1 score | ++-----------+----------+-----------+--------+----------+ +| backpack | 0.829 | 0.794 | 0.926 | 
0.855 | +| bag | 0.836 | 0.496 | 0.287 | 0.364 | +| handbag | 0.935 | 0.469 | 0.073 | 0.126 | +| boots | 0.905 | 0.784 | 0.791 | 0.787 | +| gender | 0.858 | 0.806 | 0.828 | 0.817 | +| hat | 0.898 | 0.883 | 0.680 | 0.768 | +| shoes | 0.916 | 0.756 | 0.414 | 0.535 | +| top | 0.893 | 0.590 | 0.381 | 0.463 | +| upblack | 0.821 | 0.827 | 0.903 | 0.864 | +| upwhite | 0.959 | 0.750 | 0.509 | 0.606 | +| upred | 0.973 | 0.745 | 0.649 | 0.694 | +| uppurple | 0.995 | 0.258 | 0.123 | 0.167 | +| upgray | 0.900 | 0.611 | 0.333 | 0.432 | +| upblue | 0.943 | 0.766 | 0.519 | 0.619 | +| upgreen | 0.975 | 0.463 | 0.403 | 0.431 | +| upbrown | 0.980 | 0.481 | 0.328 | 0.390 | +| downblack | 0.787 | 0.740 | 0.807 | 0.772 | +| downwhite | 0.945 | 0.771 | 0.395 | 0.522 | +| downred | 0.991 | 0.739 | 0.645 | 0.689 | +| downgray | 0.927 | 0.471 | 0.238 | 0.317 | +| downblue | 0.807 | 0.741 | 0.669 | 0.703 | +| downgreen | 0.997 | - | - | - | +| downbrown | 0.979 | 0.871 | 0.594 | 0.706 | ++-----------+----------+-----------+--------+----------+ +Average accuracy: 0.9152 +Average f1 score: 0.5739 +``` + +### Inference +``` +>> python inference.py test_sample/test_market.jpg --dataset market +age: teenager +carrying backpack: no +carrying bag: no +carrying handbag: no +type of lower-body clothing: dress +length of lower-body clothing: short +sleeve length: short sleeve +hair length: long hair +wearing hat: no +gender: female +color of upper-body clothing: white +color of lower-body clothing: white + +>> python inference.py test_sample/test_duke.jpg --dataset duke +carrying backpack: no +carrying bag: yes +carrying handbag: no +wearing boots: no +gender: male +wearing hat: no +color of shoes: dark +length of upper-body clothing: short upper body clothing +color of upper-body clothing: black +color of lower-body clothing: blue +``` + +## Update +*20-06-03: Added **identity loss** for joint optimization; adjusted the learning rate for better performance.* + +*20-06-03: Updated **test.py**, settled 
the issue of ill-defined metrics.* + +*19-09-16: Updated **inference.py**, fixed the error caused by missing data-transform.* + +*19-09-06: Updated **test.py**, added **F1 score** for evaluation.* + +*19-09-03: Added **inference.py**, thanks @ViswanathaReddyGajjala.* + +*19-08-23: Released trained models.* + +*19-01-09: Fixed the error caused by an update of the market and duke attribute datasets.* + +## FAQ + +### 1. Why is the attribute order in import_Market1501Attribute.py different for train and test data? + +The label order in import_Market1501Attribute.py is consistent with the attribute order of the dataset. + +You can load market_attribute.mat in MATLAB and print "market_attribute.train" or "market_attribute.test" to obtain these orders. + +### 2. Why do predictions on the Market-1501 dataset have 30 attributes instead of 27? + +This repo treats attribute prediction as multiple binary classification problems, but some attributes have more than two categories. + +For example, the attribute 'age' in Market-1501 has four categories: young(1), teenager(2), adult(3), old(4). So it can be split into four attributes: 'young', 'teenager', 'adult' and 'old'. + +That's why the predictions for Market-1501 have 30 attributes. + +## Reference + +*[1] Lin Y, Zheng L, Zheng Z, et al. Improving person re-identification by attribute and identity learning[J]. 
Pattern Recognition, 2019.* diff --git a/checkpoints/duke/resnet50_nfc/net_last.pth b/checkpoints/duke/resnet50_nfc/net_last.pth new file mode 100644 index 0000000..fc85ea3 Binary files /dev/null and b/checkpoints/duke/resnet50_nfc/net_last.pth differ diff --git a/checkpoints/market/resnet50_nfc/net_last.pth b/checkpoints/market/resnet50_nfc/net_last.pth new file mode 100644 index 0000000..7456093 Binary files /dev/null and b/checkpoints/market/resnet50_nfc/net_last.pth differ diff --git a/image1.jpg b/image1.jpg new file mode 100644 index 0000000..d95348f Binary files /dev/null and b/image1.jpg differ diff --git a/image2.jpg b/image2.jpg new file mode 100644 index 0000000..920bb71 Binary files /dev/null and b/image2.jpg differ diff --git a/image3.jpg b/image3.jpg new file mode 100644 index 0000000..2254e82 Binary files /dev/null and b/image3.jpg differ diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..9b1f975 --- /dev/null +++ b/inference.py @@ -0,0 +1,96 @@ +import os +import json +import torch +import argparse +from PIL import Image +from torchvision import transforms as T +from net import get_model + + +###################################################################### +# Settings +# --------- +dataset_dict = { + 'market' : 'Market-1501', + 'duke' : 'DukeMTMC-reID', +} +num_cls_dict = { 'market':30, 'duke':23 } +num_ids_dict = { 'market':751, 'duke':702 } + +transforms = T.Compose([ + T.Resize(size=(288, 144)), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +]) + + +###################################################################### +# Argument +# --------- +parser = argparse.ArgumentParser() +parser.add_argument('image_path', help='Path to test image') +parser.add_argument('--dataset', default='market', type=str, help='dataset') +parser.add_argument('--backbone', default='resnet50', type=str, help='model') +parser.add_argument('--use-id', action='store_true', help='use identity loss') +args = 
parser.parse_args() + +assert args.dataset in ['market', 'duke'] +assert args.backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121'] + +model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone) +num_label, num_id = num_cls_dict[args.dataset], num_ids_dict[args.dataset] + + +###################################################################### +# Model and Data +# --------- +def load_network(network): + save_path = os.path.join('./checkpoints', args.dataset, model_name, 'net_last.pth') + network.load_state_dict(torch.load(save_path)) + print('Resume model from {}'.format(save_path)) + return network + +def load_image(path): + src = Image.open(path) + src = transforms(src) + src = src.unsqueeze(dim=0) + return src + + +model = get_model(model_name, num_label, use_id=args.use_id, num_id=num_id) +model = load_network(model) +model.eval() + +src = load_image(args.image_path) + +###################################################################### +# Inference +# --------- +class predict_decoder(object): + + def __init__(self, dataset): + with open('./doc/label.json', 'r') as f: + self.label_list = json.load(f)[dataset] + with open('./doc/attribute.json', 'r') as f: + self.attribute_dict = json.load(f)[dataset] + self.dataset = dataset + self.num_label = len(self.label_list) + + def decode(self, pred): + pred = pred.squeeze(dim=0) + for idx in range(self.num_label): + name, chooce = self.attribute_dict[self.label_list[idx]] + if chooce[pred[idx]]: + print('{}: {}'.format(name, chooce[pred[idx]])) + + +if not args.use_id: + out = model.forward(src) +else: + out, _ = model.forward(src) + +pred = torch.gt(out, torch.ones_like(out)/2 ) # threshold=0.5 + +Dec = predict_decoder(args.dataset) +Dec.decode(pred) + diff --git a/start.py b/start.py new file mode 100644 index 0000000..4e612b0 --- /dev/null +++ b/start.py @@ -0,0 +1,135 @@ +import os +import json +import torch +import argparse +import requests +from io import BytesIO 
+from PIL import Image +from torchvision import transforms as T +from net import get_model + +from clearml import Task # 1. ClearML 임포트 + +num_cls_dict = {'market': 30, 'duke': 23} +num_ids_dict = {'market': 751, 'duke': 702} + +transforms = T.Compose([ + T.Resize(size=(288, 144)), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +]) + +task = Task.init( + project_name="Person_Attribute_Recognition", + task_name="MarketDuke_Agent" +) + + +class PredictDecoder(object): + + def __init__(self, dataset): + with open('./doc/label.json', 'r') as f: + self.label_list = json.load(f)[dataset] + with open('./doc/attribute.json', 'r') as f: + self.attribute_dict = json.load(f)[dataset] + self.dataset = dataset + self.num_label = len(self.label_list) + + def decode(self, pred): + pred = pred.squeeze(dim=0) + results = [] + for idx in range(self.num_label): + name, choice = self.attribute_dict[self.label_list[idx]] + value = choice[pred[idx]] + if value: + results.append((name, value)) + return results + + +def load_network(network, dataset, model_name): + save_path = os.path.join('./checkpoints', dataset, model_name, 'net_last.pth') + network.load_state_dict(torch.load(save_path)) + print(f'[+] 모델 로드 완료: {save_path}') + return network + + +def preprocess_image(image): + src = transforms(image) + return src.unsqueeze(dim=0) + + +def load_image_from_url_or_path(image_source): + if image_source.startswith(('http://', 'https://')): + response = requests.get(image_source) + response.raise_for_status() + return Image.open(BytesIO(response.content)).convert('RGB') + + if not os.path.isfile(image_source): + print(f"Error: Image not found: {image_source}") + exit(1) + + return Image.open(image_source).convert('RGB') + + +def main(image, dataset='market', backbone='resnet50', use_id=False): + assert dataset in ['market', 'duke'] + assert backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121'] + + model_name = f'{backbone}_nfc_id' if use_id else 
f'{backbone}_nfc' + num_label = num_cls_dict[dataset] + num_id = num_ids_dict[dataset] + + model = get_model(model_name, num_label, use_id=use_id, num_id=num_id) + model = load_network(model, dataset, model_name) + model.eval() + + src = preprocess_image(image) + + with torch.no_grad(): + if not use_id: + out = model.forward(src) + else: + out, _ = model.forward(src) + + pred = torch.gt(out, torch.ones_like(out) / 2) + + decoder = PredictDecoder(dataset) + results = decoder.decode(pred) + + print("\n" + "=" * 50) + print(" Person 상세 속성 분석 결과 ") + print("=" * 50) + for name, value in results: + print(f'{name}: {value}') + print("=" * 50) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Person Attribute Recognition Agent") + parser.add_argument("--image_url", type=str) + parser.add_argument("--xywh", type=str) + + args = parser.parse_args() + image_url = args.image_url + if image_url is None: + print("Error: Image URL or path is required") + exit(1) + + oimg = load_image_from_url_or_path(image_url) + + xywh = args.xywh + if xywh is None: + print("Error: XYWH is required") + exit(1) + + xywh = xywh.split(",") + x = int(xywh[0]) + y = int(xywh[1]) + w = int(xywh[2]) + h = int(xywh[3]) + + cimg = oimg.crop((x, y, x + w, y + h)) + cimg.save("cropped_image.jpg") + print(f"[+] 타겟 영역 크롭 완료 -> 'cropped_image.jpg' 저장 완료.") + + main(cimg) diff --git a/test.py b/test.py new file mode 100644 index 0000000..72b0d83 --- /dev/null +++ b/test.py @@ -0,0 +1,200 @@ +import os +import argparse +import scipy.io +import torch +import numpy as np +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score +# from sklearn.exceptions import UndefinedMetricWarning +from datafolder.folder import Test_Dataset +from net import get_model + + +###################################################################### +# Settings +# --------- +use_gpu = True +dataset_dict = { + 'market' : 'Market-1501', + 'duke' : 'DukeMTMC-reID', +} 
+num_cls_dict = { 'market':30, 'duke':23 } +num_ids_dict = { 'market':751, 'duke':702 } + + +###################################################################### +# Argument +# --------- +parser = argparse.ArgumentParser(description='Testing') +parser.add_argument('--data-path', default='/home/xxx/reid/', type=str, help='path to the dataset') +parser.add_argument('--dataset', default='market', type=str, help='dataset') +parser.add_argument('--backbone', default='resnet50', type=str, help='model') +parser.add_argument('--batch-size', default=50, type=int, help='batch size') +parser.add_argument('--num-epoch', default=60, type=int, help='num of epoch') +parser.add_argument('--num-workers', default=2, type=int, help='num_workers') +parser.add_argument('--which-epoch',default='last', type=str, help='0,1,2,3...or last') +parser.add_argument('--print-table',action='store_true', help='print results with table format') +parser.add_argument('--use-id', action='store_true', help='use identity loss') +args = parser.parse_args() + +assert args.dataset in ['market', 'duke'] +assert args.backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121'] + +dataset_name = dataset_dict[args.dataset] +model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone) +data_dir = args.data_path +model_dir = os.path.join('./checkpoints', args.dataset, model_name) +result_dir = os.path.join('./result', args.dataset, model_name) + +if not os.path.isdir(result_dir): + os.makedirs(result_dir) +if not os.path.isdir(model_dir): + os.makedirs(model_dir) + + +###################################################################### +# Function +# --------- +def load_network(network): + save_path = os.path.join(model_dir,'net_%s.pth'%args.which_epoch) + network.load_state_dict(torch.load(save_path)) + print('Resume model from {}'.format(save_path)) + return network + + +def get_dataloader(): + image_datasets = {} + image_datasets['gallery'] = Test_Dataset(data_dir, 
dataset_name=dataset_name, query_gallery='gallery') + image_datasets['query'] = Test_Dataset(data_dir, dataset_name=dataset_name, query_gallery='query') + dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers) + for x in ['gallery', 'query']} + return dataloaders + + +def check_metric_vaild(y_pred, y_true): + if y_true.min() == y_true.max() == 0: # precision + return False + if y_pred.min() == y_pred.max() == 0: # recall + return False + return True + + +###################################################################### +# Load Data +# --------- +# Note that we only perform evaluation on gallery set. +test_loader = get_dataloader()['gallery'] + +attribute_list = test_loader.dataset.labels() +num_label = len(attribute_list) +num_sample = len(test_loader.dataset) +num_id = num_ids_dict[args.dataset] + + +###################################################################### +# Model +# --------- +model = get_model(model_name, num_label, use_id=args.use_id, num_id=num_id) +model = load_network(model) +if use_gpu: + model = model.cuda() +model.train(False) # Set model to evaluate mode + + +###################################################################### +# Testing +# --------- +preds_tensor = np.empty(shape=[0, num_label], dtype=np.byte) # shape = (num_sample, num_label) +labels_tensor = np.empty(shape=[0, num_label], dtype=np.byte) # shape = (num_sample, num_label) + +# Iterate over data. 
+with torch.no_grad(): + for count, (images, labels, ids, file_name) in enumerate(test_loader): + # move input to GPU + if use_gpu: + images = images.cuda() + # forward + if not args.use_id: + pred_label = model(images) + else: + pred_label, _ = model(images) + + preds = torch.gt(pred_label, torch.ones_like(pred_label)/2) + # transform to numpy format + labels = labels.cpu().numpy() + preds = preds.cpu().numpy() + # append + preds_tensor = np.append(preds_tensor, preds, axis=0) + labels_tensor = np.append(labels_tensor, labels, axis=0) + # print info + if count*args.batch_size % 5000 == 0: + print('Step: {}/{}'.format(count*args.batch_size, num_sample)) + +# Evaluation. +accuracy_list = [] +precision_list = [] +recall_list = [] +f1_score_list = [] +average_precision = 0.0 +average_recall = 0.0 +average_f1score = 0.0 +valid_count = 0 +for i, name in enumerate(attribute_list): + y_true, y_pred = labels_tensor[:, i], preds_tensor[:, i] + accuracy_list.append(accuracy_score(y_true, y_pred)) + if check_metric_vaild(y_pred, y_true): # exclude ill-defined cases + precision_list.append(precision_score(y_true, y_pred, average='binary')) + recall_list.append(recall_score(y_true, y_pred, average='binary')) + f1_score_list.append(f1_score(y_true, y_pred, average='binary')) + average_precision += precision_list[-1] + average_recall += recall_list[-1] + average_f1score += f1_score_list[-1] + valid_count += 1 + else: + precision_list.append(-1) + recall_list.append(-1) + f1_score_list.append(-1) + +average_acc = np.mean(accuracy_list) +average_precision = average_precision / valid_count +average_recall = average_recall / valid_count +average_f1score = average_f1score / valid_count + + +###################################################################### +# Print +# --------- +print("\n" + "The Precision, Recall and F-score are ignored for some ill-defined cases." 
+ "\n") + +if args.print_table: + from prettytable import PrettyTable + table = PrettyTable(['attribute', 'accuracy', 'precision', 'recall', 'f1 score']) + for i, name in enumerate(attribute_list): + table.add_row([name, + '%.3f' % accuracy_list[i], + '%.3f' % precision_list[i] if precision_list[i] >= 0.0 else '-', + '%.3f' % recall_list[i] if recall_list[i] >= 0.0 else '-', + '%.3f' % f1_score_list[i] if f1_score_list[i] >= 0.0 else '-', + ]) + print(table) + + +print('Average accuracy: {:.4f}'.format(average_acc)) +# print('Average precision: {:.4f}'.format(average_precision)) +# print('Average recall: {:.4f}'.format(average_recall)) +print('Average f1 score: {:.4f}'.format(average_f1score)) + +# Save results. +result = { + 'average_acc' : average_acc, + 'average_f1score' : average_f1score, + 'accuracy_list' : accuracy_list, + 'precision_list' : precision_list, + 'recall_list' : recall_list, + 'f1_score_list' : f1_score_list, +} +scipy.io.savemat(os.path.join(result_dir, 'acc.mat'), result) + + diff --git a/train.py b/train.py new file mode 100644 index 0000000..2d9a848 --- /dev/null +++ b/train.py @@ -0,0 +1,211 @@ +# !/usr/local/bin/python3 +import os +import time +import argparse +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +from datafolder.folder import Train_Dataset +from net import get_model + +###################################################################### +# Settings +# -------- +use_gpu = True +dataset_dict = { + 'market' : 'Market-1501', + 'duke' : 'DukeMTMC-reID', +} + + +###################################################################### +# Argument +# -------- +parser = argparse.ArgumentParser(description='Training') +parser.add_argument('--data-path', default='/path/to/dataset', type=str, help='path to the dataset') +parser.add_argument('--dataset', default='market', type=str, help='dataset: market, duke') +parser.add_argument('--backbone', default='resnet50', type=str, 
help='backbone: resnet50, resnet34, resnet18, densenet121') +parser.add_argument('--batch-size', default=32, type=int, help='batch size') +parser.add_argument('--num-epoch', default=60, type=int, help='num of epoch') +parser.add_argument('--num-workers', default=2, type=int, help='num_workers') +parser.add_argument('--use-id', action='store_true', help='use identity loss') +parser.add_argument('--lamba', default=1.0, type=float, help='weight of id loss') +args = parser.parse_args() + +assert args.dataset in ['market', 'duke'] +assert args.backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121'] + +dataset_name = dataset_dict[args.dataset] +model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone) +data_dir = args.data_path +model_dir = os.path.join('./checkpoints', args.dataset, model_name) + +if not os.path.isdir(model_dir): + os.makedirs(model_dir) + + +###################################################################### +# Function +# -------- +def save_network(network, epoch_label): + save_filename = 'net_%s.pth'% epoch_label + save_path = os.path.join(model_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if use_gpu: + network.cuda() + print('Save model to {}'.format(save_path)) + + +###################################################################### +# Draw Curve +#----------- +x_epoch = [] +y_loss = {} # loss history +y_loss['train'] = [] +y_loss['val'] = [] +y_err = {} +y_err['train'] = [] +y_err['val'] = [] + +fig = plt.figure() +ax0 = fig.add_subplot(121, title="loss") +ax1 = fig.add_subplot(122, title="top1err") +def draw_curve(current_epoch): + x_epoch.append(current_epoch) + ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train') + ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val') + ax1.plot(x_epoch, y_err['train'], 'bo-', label='train') + ax1.plot(x_epoch, y_err['val'], 'ro-', label='val') + if current_epoch == 0: + ax0.legend() + ax1.legend() + fig.savefig( 
os.path.join(model_dir, 'train.jpg')) + + +###################################################################### +# DataLoader +# --------- +image_datasets = {} +image_datasets['train'] = Train_Dataset(data_dir, dataset_name=dataset_name, train_val='train') +image_datasets['val'] = Train_Dataset(data_dir, dataset_name=dataset_name, train_val='query') +dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers, drop_last=True) + for x in ['train', 'val']} +dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} + +# images, indices, labels, ids, cams, names = next(iter(dataloaders['train'])) + +num_label = image_datasets['train'].num_label() +num_id = image_datasets['train'].num_id() +labels_list = image_datasets['train'].labels() + + +###################################################################### +# Model and Optimizer +# ------------------ +model = get_model(model_name, num_label, args.use_id, num_id=num_id) +if use_gpu: + model = model.cuda() + +# loss +criterion_bce = nn.BCELoss() +criterion_ce = nn.CrossEntropyLoss() + +# optimizer +ignored_params = list(map(id, model.features.parameters())) +classifier_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) +optimizer = torch.optim.SGD([ + {'params': model.features.parameters(), 'lr': 0.01}, + {'params': classifier_params, 'lr': 0.1}, + ], momentum=0.9, weight_decay=5e-4, nesterov=True) +exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1) + + +###################################################################### +# Training the model +# ------------------ +def train_model(model, optimizer, scheduler, num_epochs): + since = time.time() + + for epoch in range(1, num_epochs+1): + print('Epoch {}/{}'.format(epoch, num_epochs)) + print('-' * 10) + + # Each epoch has a training and validation phase + for phase in ['train', 'val']: + if phase == 'train': + 
scheduler.step() + model.train(True) # Set model to training mode + else: + model.train(False) # Set model to evaluate mode + + running_loss = 0.0 + running_corrects = 0 + + # Iterate over data. + for count, (images, indices, labels, ids, cams, names) in enumerate(dataloaders[phase]): + # get the inputs + labels = labels.float() + if use_gpu: + images = images.cuda() + labels = labels.cuda() + indices = indices.cuda() + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + if not args.use_id: + pred_label = model(images) + total_loss = criterion_bce(pred_label, labels) + else: + pred_label, pred_id = model(images) + label_loss = criterion_bce(pred_label, labels) + id_loss = criterion_ce(pred_id, indices) + total_loss = label_loss + args.lamba * id_loss + + # backward + optimize only if in training phase + if phase == 'train': + total_loss.backward() + optimizer.step() + + preds = torch.gt(pred_label, torch.ones_like(pred_label)/2 ) + # statistics + running_loss += total_loss.item() + running_corrects += torch.sum(preds == labels.byte()).item() / num_label + if count % 100 == 0: + if not args.use_id: + print('step: ({}/{}) | label loss: {:.4f}'.format( + count*args.batch_size, dataset_sizes[phase], total_loss.item())) + else: + print('step: ({}/{}) | label loss: {:.4f} | id loss: {:.4f}'.format( + count*args.batch_size, dataset_sizes[phase], label_loss.item(), id_loss.item())) + + epoch_loss = running_loss / len(dataloaders[phase]) + epoch_acc = running_corrects / dataset_sizes[phase] + + print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) + y_loss[phase].append(epoch_loss) + y_err[phase].append(1.0-epoch_acc) + # deep copy the model + if phase == 'val': + last_model_wts = model.state_dict() + if epoch % 10 == 0: + save_network(model, epoch) + draw_curve(epoch) + + time_elapsed = time.time() - since + print('Training complete in {:.0f}m {:.0f}s'.format( + time_elapsed // 60, time_elapsed % 60)) + + # load best model weights + 
model.load_state_dict(last_model_wts) + save_network(model, 'last') + + +###################################################################### +# Main +# ----- +train_model(model, optimizer, exp_lr_scheduler, num_epochs=args.num_epoch)