first commit

2026-06-29 11:21:57 +09:00
commit e75f32c80d
11 changed files with 895 additions and 0 deletions
@@ -0,0 +1,72 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
.eggs/
dist/
build/
*.egg
.venv/
venv/
env/
# IDE
.vscode/
.idea/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
desktop.ini
# Model weights & exports
# weights/
# *.pt
# *.pth
# *.onnx
# *.engine
# *.weights
# Images (inference input/output)
*_marked.jpg
*_marked.jpeg
*_marked.png
image.jpg
*.jpg
*.jpeg
*.png
*.bmp
*.webp
!image1.jpg
!image2.jpg
!image3.jpg
# Ultralytics outputs
runs/
output/
outputs/
results/
# ClearML
clearml.conf
*.log
# Jupyter
.ipynb_checkpoints/
# Environment & secrets
.env
.env.*
*.pem
# etc
datafolder/
doc/
net/
test_sample/
@@ -0,0 +1,181 @@
# Person-Attribute-Recognition-MarketDuke
A simple PyTorch baseline for the **pedestrian attribute recognition** task, evaluated on the Market-1501-attribute and DukeMTMC-reID-attribute datasets.
## Dataset
You can get the [Market-1501-attribute](https://github.com/vana77/Market-1501_Attribute) and [DukeMTMC-reID-attribute](https://github.com/vana77/DukeMTMC-attribute) annotations from [here](https://github.com/vana77). You also need to download the Market-1501 and DukeMTMC-reID datasets.
Then create a folder named 'attribute' under your dataset path and put the corresponding annotations into it.
For example,<br>
```
├── dataset
│   ├── DukeMTMC-reID
│       ├── bounding_box_test
│       ├── bounding_box_train
│       ├── query
│       ├── attribute
│           ├── duke_attribute.mat
```
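The layout above can be checked before training with a short script. This is a minimal sketch and not part of the repository: the helper name `missing_entries` is hypothetical, and the sub-folder names are simply copied from the tree shown above.

```python
import os

# Sub-folders expected under <data_path>/<dataset_name>, copied from the tree above.
# Hypothetical helper for illustration; it is not part of this repository.
EXPECTED_DIRS = ["bounding_box_test", "bounding_box_train", "query", "attribute"]


def missing_entries(data_path, dataset_name, mat_name):
    """Return the expected folders/files that do not exist on disk."""
    root = os.path.join(data_path, dataset_name)
    expected = [os.path.join(root, d) for d in EXPECTED_DIRS]
    expected.append(os.path.join(root, "attribute", mat_name))
    return [p for p in expected if not os.path.exists(p)]


if __name__ == "__main__":
    data_path = os.path.expanduser("~/dataset")
    for path in missing_entries(data_path, "DukeMTMC-reID", "duke_attribute.mat"):
        print("missing:", path)
```

An empty result means every folder and the annotation file were found.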
## Model
Trained models are provided. You may download them from [Google Drive](https://drive.google.com/drive/folders/1JTdjuEbxSLypnfUzVuuxLj1uSKAacfd0?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1bByCxZp9bSs8YYZPbuK21A) (extraction code: jpks).
After downloading, move the `checkpoints` folder to your project's root directory.
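The scripts look for a checkpoint under `./checkpoints/<dataset>/<model name>/`, where the model name is built from the backbone and the `--use-id` flag. The sketch below reproduces that naming as the scripts in this repository construct it; the function name `checkpoint_path` is hypothetical and only for illustration.

```python
import os


def checkpoint_path(dataset, backbone="resnet50", use_id=False, epoch="last"):
    """Build the checkpoint path the way train.py / test.py / inference.py do."""
    # Models trained with the identity loss get an "_id" suffix.
    model_name = "{}_nfc_id".format(backbone) if use_id else "{}_nfc".format(backbone)
    return os.path.join("./checkpoints", dataset, model_name, "net_{}.pth".format(epoch))


print(checkpoint_path("market"))
print(checkpoint_path("duke", use_id=True))
```

If a downloaded folder does not match this structure, loading the weights will fail with a file-not-found error.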
## Dependencies
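* Python 3.5 (the script that imports `clearml` uses f-strings, which need Python >= 3.6)
* PyTorch >= 0.4.1
* torchvision >= 0.2.1
* matplotlib, scikit-learn (`sklearn`), scipy, prettytable (optional)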
## Usage
```
python3 train.py --data-path ~/dataset --dataset [market | duke] --backbone resnet50 [--use-id]
python3 test.py --data-path ~/dataset --dataset [market | duke] --backbone resnet50 [--print-table]
python3 inference.py test_sample/test_market.jpg [--dataset market] [--backbone resnet50]
```
## Result
We use a **binary classification** setting (each attribute is treated as an independent binary classification problem) with a classification threshold of **0.5**.
***Note that precision, recall and F1 score are shown as '-' for ill-defined cases, i.e. attributes with no positive ground-truth labels or no positive predictions.***
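The per-attribute evaluation can be sketched in plain Python as follows. This is an illustrative re-implementation under stated assumptions, not the code in `test.py` (which uses scikit-learn); the function names `binarize` and `binary_metrics` are hypothetical. It uses the same strict `> 0.5` threshold and the same rule for ill-defined cases.

```python
def binarize(scores, threshold=0.5):
    """Turn sigmoid scores into 0/1 predictions (strictly greater than the threshold)."""
    return [1 if s > threshold else 0 for s in scores]


def binary_metrics(y_true, y_pred):
    """Return (accuracy, precision, recall, f1) for one attribute.

    Precision, recall and F1 are None when the attribute has no positive
    labels or no positive predictions (the ill-defined cases).
    """
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    accuracy = sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)
    if tp + fn == 0 or tp + fp == 0:
        return accuracy, None, None, None
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1


preds = binarize([0.9, 0.2, 0.6, 0.4])
print(binary_metrics([1, 0, 0, 1], preds))
print(binary_metrics([0, 0], [0, 0]))
```

The second call shows why an attribute can have perfect accuracy while its other metrics stay undefined, as with `downpurple` in the Market-1501 table.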
### Market-1501 gallery
```
+------------+----------+-----------+--------+----------+
| attribute | accuracy | precision | recall | f1 score |
+------------+----------+-----------+--------+----------+
| young | 0.998 | 0.533 | 0.267 | 0.356 |
| teenager | 0.892 | 0.927 | 0.951 | 0.939 |
| adult | 0.895 | 0.582 | 0.450 | 0.508 |
| old | 0.992 | 0.037 | 0.012 | 0.019 |
| backpack | 0.883 | 0.828 | 0.672 | 0.742 |
| bag | 0.790 | 0.608 | 0.378 | 0.467 |
| handbag | 0.893 | 0.254 | 0.065 | 0.104 |
| clothes | 0.946 | 0.956 | 0.984 | 0.970 |
| down | 0.945 | 0.968 | 0.949 | 0.959 |
| up | 0.936 | 0.938 | 0.998 | 0.967 |
| hair | 0.877 | 0.871 | 0.773 | 0.819 |
| hat | 0.982 | 0.812 | 0.505 | 0.623 |
| gender | 0.919 | 0.947 | 0.864 | 0.903 |
| upblack | 0.954 | 0.859 | 0.790 | 0.823 |
| upwhite | 0.926 | 0.846 | 0.882 | 0.863 |
| upred | 0.974 | 0.904 | 0.840 | 0.871 |
| uppurple | 0.985 | 0.703 | 0.815 | 0.755 |
| upyellow | 0.976 | 0.895 | 0.836 | 0.865 |
| upgray | 0.909 | 0.852 | 0.391 | 0.537 |
| upblue | 0.946 | 0.868 | 0.420 | 0.566 |
| upgreen | 0.966 | 0.790 | 0.713 | 0.750 |
| downblack | 0.879 | 0.815 | 0.889 | 0.850 |
| downwhite | 0.956 | 0.608 | 0.550 | 0.578 |
| downpink | 0.989 | 0.795 | 0.782 | 0.788 |
| downpurple | 1.000 | - | - | - |
| downyellow | 0.999 | 0.000 | 0.000 | 0.000 |
| downgray | 0.878 | 0.756 | 0.443 | 0.559 |
| downblue | 0.861 | 0.762 | 0.446 | 0.563 |
| downgreen | 0.978 | 0.766 | 0.295 | 0.426 |
| downbrown | 0.958 | 0.754 | 0.590 | 0.662 |
+------------+----------+-----------+--------+----------+
Average accuracy: 0.9361
Average f1 score: 0.6492
```
### DukeMTMC-ReID gallery
```
+-----------+----------+-----------+--------+----------+
| attribute | accuracy | precision | recall | f1 score |
+-----------+----------+-----------+--------+----------+
| backpack | 0.829 | 0.794 | 0.926 | 0.855 |
| bag | 0.836 | 0.496 | 0.287 | 0.364 |
| handbag | 0.935 | 0.469 | 0.073 | 0.126 |
| boots | 0.905 | 0.784 | 0.791 | 0.787 |
| gender | 0.858 | 0.806 | 0.828 | 0.817 |
| hat | 0.898 | 0.883 | 0.680 | 0.768 |
| shoes | 0.916 | 0.756 | 0.414 | 0.535 |
| top | 0.893 | 0.590 | 0.381 | 0.463 |
| upblack | 0.821 | 0.827 | 0.903 | 0.864 |
| upwhite | 0.959 | 0.750 | 0.509 | 0.606 |
| upred | 0.973 | 0.745 | 0.649 | 0.694 |
| uppurple | 0.995 | 0.258 | 0.123 | 0.167 |
| upgray | 0.900 | 0.611 | 0.333 | 0.432 |
| upblue | 0.943 | 0.766 | 0.519 | 0.619 |
| upgreen | 0.975 | 0.463 | 0.403 | 0.431 |
| upbrown | 0.980 | 0.481 | 0.328 | 0.390 |
| downblack | 0.787 | 0.740 | 0.807 | 0.772 |
| downwhite | 0.945 | 0.771 | 0.395 | 0.522 |
| downred | 0.991 | 0.739 | 0.645 | 0.689 |
| downgray | 0.927 | 0.471 | 0.238 | 0.317 |
| downblue | 0.807 | 0.741 | 0.669 | 0.703 |
| downgreen | 0.997 | - | - | - |
| downbrown | 0.979 | 0.871 | 0.594 | 0.706 |
+-----------+----------+-----------+--------+----------+
Average accuracy: 0.9152
Average f1 score: 0.5739
```
### Inference
```
>> python inference.py test_sample/test_market.jpg --dataset market
age: teenager
carrying backpack: no
carrying bag: no
carrying handbag: no
type of lower-body clothing: dress
length of lower-body clothing: short
sleeve length: short sleeve
hair length: long hair
wearing hat: no
gender: female
color of upper-body clothing: white
color of lower-body clothing: white
>> python inference.py test_sample/test_duke.jpg --dataset duke
carrying backpack: no
carrying bag: yes
carrying handbag: no
wearing boots: no
gender: male
wearing hat: no
color of shoes: dark
length of upper-body clothing: short upper body clothing
color of upper-body clothing: black
color of lower-body clothing: blue
```
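The readable output above comes from mapping each binary prediction to a display name and a pair of values, and skipping entries whose value is empty. The sketch below mirrors that decoding step. The `LABELS` and `ATTRIBUTES` contents are hypothetical stand-ins in the spirit of `doc/label.json` and `doc/attribute.json`, whose real contents are not shown here.

```python
# Hypothetical excerpt; the real doc/label.json and doc/attribute.json may differ.
LABELS = ["young", "teenager", "backpack", "hat"]
ATTRIBUTES = {
    # label: [display name, [value when predicted 0, value when predicted 1]]
    "young": ["age", [None, "young"]],
    "teenager": ["age", [None, "teenager"]],
    "backpack": ["carrying backpack", ["no", "yes"]],
    "hat": ["wearing hat", ["no", "yes"]],
}


def decode(pred):
    """Map a list of boolean predictions to (name, value) pairs, skipping empty values."""
    results = []
    for label, flag in zip(LABELS, pred):
        name, choice = ATTRIBUTES[label]
        value = choice[int(flag)]
        if value:
            results.append((name, value))
    return results


for name, value in decode([False, True, False, False]):
    print("{}: {}".format(name, value))
```

With this layout, a multi-category attribute such as age prints only the category predicted as positive, while yes/no attributes always print.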
## Update
*20-06-03: Added **identity loss** for joint optimization; adjusted the learning rate for better performance.*
*20-06-03: Updated **test.py**, settled the issue of ill-defined metrics.*
*19-09-16: Updated **inference.py**, fixed the error caused by missing data-transform.*
*19-09-06: Updated **test.py**, added **F1 score** for evaluating.*
*19-09-03: Added **inference.py**, thanks @ViswanathaReddyGajjala.*
*19-08-23: Released trained models.*
*19-01-09: Fixed the error caused by an update of market and duke attribute dataset.*
## FAQ
### 1. Why is the attribute order in import_Market1501Attribute.py different for train and test data?
The label order in import_Market1501Attribute.py is consistent with the attribute order of the dataset.
You can load market_attribute.mat in MATLAB and print "market_attribute.train" or "market_attribute.test" to obtain these orders.
### 2. Why do predictions on the Market-1501 dataset have 30 attributes instead of 27?
This repo treats attribute prediction as multiple binary classification problems, but some attributes have more than two categories.
For example, the attribute 'age' in Market-1501 has four categories: young(1), teenager(2), adult(3), old(4). So it is split into four binary attributes: 'young', 'teenager', 'adult' and 'old'.
That's why the predictions for Market-1501 have 30 attributes (27 - 1 + 4 = 30).
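The split of 'age' into four binary attributes can be sketched as a one-hot expansion. This is an illustration only, assuming the 1-to-4 coding given above; the helper name `expand_age` is hypothetical and does not appear in the repository.

```python
AGE_CATEGORIES = ["young", "teenager", "adult", "old"]  # coded 1..4 in the annotations


def expand_age(age_value):
    """Turn a single age code (1..4) into four binary attributes."""
    if age_value not in (1, 2, 3, 4):
        raise ValueError("age code must be 1, 2, 3 or 4")
    return {name: int(age_value == code)
            for code, name in enumerate(AGE_CATEGORIES, start=1)}


print(expand_age(2))

# 27 annotated attributes, with 'age' replaced by its four binary attributes.
num_binary_attributes = 27 - 1 + len(AGE_CATEGORIES)
print(num_binary_attributes)
```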
## Reference
*[1] Lin Y, Zheng L, Zheng Z, et al. Improving person re-identification by attribute and identity learning[J]. Pattern Recognition, 2019.*
Binary file not shown.
@@ -0,0 +1,96 @@
import os
import json
import torch
import argparse
from PIL import Image
from torchvision import transforms as T
from net import get_model
######################################################################
# Settings
# ---------
dataset_dict = {
    'market': 'Market-1501',
    'duke': 'DukeMTMC-reID',
}
num_cls_dict = {'market': 30, 'duke': 23}
num_ids_dict = {'market': 751, 'duke': 702}
transforms = T.Compose([
    T.Resize(size=(288, 144)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
######################################################################
# Argument
# ---------
parser = argparse.ArgumentParser()
parser.add_argument('image_path', help='Path to test image')
parser.add_argument('--dataset', default='market', type=str, help='dataset')
parser.add_argument('--backbone', default='resnet50', type=str, help='model')
parser.add_argument('--use-id', action='store_true', help='use identity loss')
args = parser.parse_args()
assert args.dataset in ['market', 'duke']
assert args.backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121']
model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone)
num_label, num_id = num_cls_dict[args.dataset], num_ids_dict[args.dataset]
######################################################################
# Model and Data
# ---------
def load_network(network):
    save_path = os.path.join('./checkpoints', args.dataset, model_name, 'net_last.pth')
    network.load_state_dict(torch.load(save_path))
    print('Resume model from {}'.format(save_path))
    return network
def load_image(path):
    # Force three channels so Normalize also works on grayscale or RGBA input.
    src = Image.open(path).convert('RGB')
    src = transforms(src)
    src = src.unsqueeze(dim=0)  # add the batch dimension
    return src
model = get_model(model_name, num_label, use_id=args.use_id, num_id=num_id)
model = load_network(model)
model.eval()
src = load_image(args.image_path)
######################################################################
# Inference
# ---------
class predict_decoder(object):
    def __init__(self, dataset):
        with open('./doc/label.json', 'r') as f:
            self.label_list = json.load(f)[dataset]
        with open('./doc/attribute.json', 'r') as f:
            self.attribute_dict = json.load(f)[dataset]
        self.dataset = dataset
        self.num_label = len(self.label_list)
    def decode(self, pred):
        pred = pred.squeeze(dim=0)
        for idx in range(self.num_label):
            name, choice = self.attribute_dict[self.label_list[idx]]
            # Print only attributes whose selected value is non-empty.
            if choice[pred[idx]]:
                print('{}: {}'.format(name, choice[pred[idx]]))
if not args.use_id:
    out = model.forward(src)
else:
    out, _ = model.forward(src)
pred = torch.gt(out, torch.ones_like(out)/2 )  # threshold=0.5
Dec = predict_decoder(args.dataset)
Dec.decode(pred)
@@ -0,0 +1,135 @@
import os
import json
import torch
import argparse
import requests
from io import BytesIO
from PIL import Image
from torchvision import transforms as T
from net import get_model
from clearml import Task  # 1. Import ClearML
num_cls_dict = {'market': 30, 'duke': 23}
num_ids_dict = {'market': 751, 'duke': 702}
transforms = T.Compose([
    T.Resize(size=(288, 144)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
task = Task.init(
    project_name="Person_Attribute_Recognition",
    task_name="MarketDuke_Agent"
)
class PredictDecoder(object):
    def __init__(self, dataset):
        with open('./doc/label.json', 'r') as f:
            self.label_list = json.load(f)[dataset]
        with open('./doc/attribute.json', 'r') as f:
            self.attribute_dict = json.load(f)[dataset]
        self.dataset = dataset
        self.num_label = len(self.label_list)
    def decode(self, pred):
        pred = pred.squeeze(dim=0)
        results = []
        for idx in range(self.num_label):
            name, choice = self.attribute_dict[self.label_list[idx]]
            value = choice[pred[idx]]
            if value:
                results.append((name, value))
        return results
def load_network(network, dataset, model_name):
    save_path = os.path.join('./checkpoints', dataset, model_name, 'net_last.pth')
    network.load_state_dict(torch.load(save_path))
    print(f'[+] Model loaded: {save_path}')
    return network
def preprocess_image(image):
    src = transforms(image)
    return src.unsqueeze(dim=0)
def load_image_from_url_or_path(image_source):
    # Remote images are downloaded; anything else is treated as a local path.
    if image_source.startswith(('http://', 'https://')):
        response = requests.get(image_source)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    if not os.path.isfile(image_source):
        print(f"Error: Image not found: {image_source}")
        exit(1)
    return Image.open(image_source).convert('RGB')
def main(image, dataset='market', backbone='resnet50', use_id=False):
    assert dataset in ['market', 'duke']
    assert backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121']
    model_name = f'{backbone}_nfc_id' if use_id else f'{backbone}_nfc'
    num_label = num_cls_dict[dataset]
    num_id = num_ids_dict[dataset]
    model = get_model(model_name, num_label, use_id=use_id, num_id=num_id)
    model = load_network(model, dataset, model_name)
    model.eval()
    src = preprocess_image(image)
    with torch.no_grad():
        if not use_id:
            out = model.forward(src)
        else:
            out, _ = model.forward(src)
    pred = torch.gt(out, torch.ones_like(out) / 2)  # threshold=0.5
    decoder = PredictDecoder(dataset)
    results = decoder.decode(pred)
    print("\n" + "=" * 50)
    print(" Person detailed attribute analysis results ")
    print("=" * 50)
    for name, value in results:
        print(f'{name}: {value}')
    print("=" * 50)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Person Attribute Recognition Agent")
    parser.add_argument("--image_url", type=str)
    parser.add_argument("--xywh", type=str)
    args = parser.parse_args()
    image_url = args.image_url
    if image_url is None:
        print("Error: Image URL or path is required")
        exit(1)
    oimg = load_image_from_url_or_path(image_url)
    xywh = args.xywh
    if xywh is None:
        print("Error: XYWH is required")
        exit(1)
    # --xywh is "x,y,w,h": top-left corner plus width and height of the person box.
    xywh = xywh.split(",")
    x = int(xywh[0])
    y = int(xywh[1])
    w = int(xywh[2])
    h = int(xywh[3])
    cimg = oimg.crop((x, y, x + w, y + h))
    cimg.save("cropped_image.jpg")
    print("[+] Target region cropped and saved to 'cropped_image.jpg'.")
    main(cimg)
@@ -0,0 +1,200 @@
import os
import argparse
import scipy.io
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# from sklearn.exceptions import UndefinedMetricWarning
from datafolder.folder import Test_Dataset
from net import get_model
######################################################################
# Settings
# ---------
use_gpu = True
dataset_dict = {
    'market' : 'Market-1501',
    'duke' : 'DukeMTMC-reID',
}
num_cls_dict = { 'market':30, 'duke':23 }
num_ids_dict = { 'market':751, 'duke':702 }
######################################################################
# Argument
# ---------
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument('--data-path', default='/home/xxx/reid/', type=str, help='path to the dataset')
parser.add_argument('--dataset', default='market', type=str, help='dataset')
parser.add_argument('--backbone', default='resnet50', type=str, help='model')
parser.add_argument('--batch-size', default=50, type=int, help='batch size')
parser.add_argument('--num-epoch', default=60, type=int, help='num of epoch')
parser.add_argument('--num-workers', default=2, type=int, help='num_workers')
parser.add_argument('--which-epoch',default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--print-table',action='store_true', help='print results with table format')
parser.add_argument('--use-id', action='store_true', help='use identity loss')
args = parser.parse_args()
assert args.dataset in ['market', 'duke']
assert args.backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121']
dataset_name = dataset_dict[args.dataset]
model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone)
data_dir = args.data_path
model_dir = os.path.join('./checkpoints', args.dataset, model_name)
result_dir = os.path.join('./result', args.dataset, model_name)
if not os.path.isdir(result_dir):
    os.makedirs(result_dir)
if not os.path.isdir(model_dir):
    os.makedirs(model_dir)
######################################################################
# Function
# ---------
def load_network(network):
    save_path = os.path.join(model_dir,'net_%s.pth'%args.which_epoch)
    network.load_state_dict(torch.load(save_path))
    print('Resume model from {}'.format(save_path))
    return network
def get_dataloader():
    image_datasets = {}
    image_datasets['gallery'] = Test_Dataset(data_dir, dataset_name=dataset_name, query_gallery='gallery')
    image_datasets['query'] = Test_Dataset(data_dir, dataset_name=dataset_name, query_gallery='query')
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size,
                                                  shuffle=True, num_workers=args.num_workers)
                   for x in ['gallery', 'query']}
    return dataloaders
def check_metric_vaild(y_pred, y_true):
    if y_true.min() == y_true.max() == 0:  # no positive labels: recall is undefined
        return False
    if y_pred.min() == y_pred.max() == 0:  # no positive predictions: precision is undefined
        return False
    return True
######################################################################
# Load Data
# ---------
# Note that we only perform evaluation on gallery set.
test_loader = get_dataloader()['gallery']
attribute_list = test_loader.dataset.labels()
num_label = len(attribute_list)
num_sample = len(test_loader.dataset)
num_id = num_ids_dict[args.dataset]
######################################################################
# Model
# ---------
model = get_model(model_name, num_label, use_id=args.use_id, num_id=num_id)
model = load_network(model)
if use_gpu:
    model = model.cuda()
model.train(False)  # Set model to evaluate mode
######################################################################
# Testing
# ---------
preds_tensor = np.empty(shape=[0, num_label], dtype=np.byte) # shape = (num_sample, num_label)
labels_tensor = np.empty(shape=[0, num_label], dtype=np.byte) # shape = (num_sample, num_label)
# Iterate over data.
with torch.no_grad():
    for count, (images, labels, ids, file_name) in enumerate(test_loader):
        # move input to GPU
        if use_gpu:
            images = images.cuda()
        # forward
        if not args.use_id:
            pred_label = model(images)
        else:
            pred_label, _ = model(images)
        preds = torch.gt(pred_label, torch.ones_like(pred_label)/2)
        # transform to numpy format
        labels = labels.cpu().numpy()
        preds = preds.cpu().numpy()
        # append
        preds_tensor = np.append(preds_tensor, preds, axis=0)
        labels_tensor = np.append(labels_tensor, labels, axis=0)
        # print info
        if count*args.batch_size % 5000 == 0:
            print('Step: {}/{}'.format(count*args.batch_size, num_sample))
# Evaluation.
accuracy_list = []
precision_list = []
recall_list = []
f1_score_list = []
average_precision = 0.0
average_recall = 0.0
average_f1score = 0.0
valid_count = 0
for i, name in enumerate(attribute_list):
    y_true, y_pred = labels_tensor[:, i], preds_tensor[:, i]
    accuracy_list.append(accuracy_score(y_true, y_pred))
    if check_metric_vaild(y_pred, y_true):  # exclude ill-defined cases
        precision_list.append(precision_score(y_true, y_pred, average='binary'))
        recall_list.append(recall_score(y_true, y_pred, average='binary'))
        f1_score_list.append(f1_score(y_true, y_pred, average='binary'))
        average_precision += precision_list[-1]
        average_recall += recall_list[-1]
        average_f1score += f1_score_list[-1]
        valid_count += 1
    else:
        # -1 marks an ill-defined metric; it is printed as '-' in the table
        precision_list.append(-1)
        recall_list.append(-1)
        f1_score_list.append(-1)
average_acc = np.mean(accuracy_list)
average_precision = average_precision / valid_count
average_recall = average_recall / valid_count
average_f1score = average_f1score / valid_count
######################################################################
# Print
# ---------
print("\n"
      "The Precision, Recall and F-score are ignored for some ill-defined cases."
      "\n")
if args.print_table:
    from prettytable import PrettyTable
    table = PrettyTable(['attribute', 'accuracy', 'precision', 'recall', 'f1 score'])
    for i, name in enumerate(attribute_list):
        table.add_row([name,
                       '%.3f' % accuracy_list[i],
                       '%.3f' % precision_list[i] if precision_list[i] >= 0.0 else '-',
                       '%.3f' % recall_list[i] if recall_list[i] >= 0.0 else '-',
                       '%.3f' % f1_score_list[i] if f1_score_list[i] >= 0.0 else '-',
                       ])
    print(table)
print('Average accuracy: {:.4f}'.format(average_acc))
# print('Average precision: {:.4f}'.format(average_precision))
# print('Average recall: {:.4f}'.format(average_recall))
print('Average f1 score: {:.4f}'.format(average_f1score))
# Save results.
result = {
    'average_acc' : average_acc,
    'average_f1score' : average_f1score,
    'accuracy_list' : accuracy_list,
    'precision_list' : precision_list,
    'recall_list' : recall_list,
    'f1_score_list' : f1_score_list,
}
scipy.io.savemat(os.path.join(result_dir, 'acc.mat'), result)
@@ -0,0 +1,211 @@
# !/usr/local/bin/python3
import os
import time
import argparse
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from datafolder.folder import Train_Dataset
from net import get_model
######################################################################
# Settings
# --------
use_gpu = True
dataset_dict = {
    'market' : 'Market-1501',
    'duke' : 'DukeMTMC-reID',
}
######################################################################
# Argument
# --------
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--data-path', default='/path/to/dataset', type=str, help='path to the dataset')
parser.add_argument('--dataset', default='market', type=str, help='dataset: market, duke')
parser.add_argument('--backbone', default='resnet50', type=str, help='backbone: resnet50, resnet34, resnet18, densenet121')
parser.add_argument('--batch-size', default=32, type=int, help='batch size')
parser.add_argument('--num-epoch', default=60, type=int, help='num of epoch')
parser.add_argument('--num-workers', default=2, type=int, help='num_workers')
parser.add_argument('--use-id', action='store_true', help='use identity loss')
parser.add_argument('--lamba', default=1.0, type=float, help='weight of id loss')
args = parser.parse_args()
assert args.dataset in ['market', 'duke']
assert args.backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121']
dataset_name = dataset_dict[args.dataset]
model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone)
data_dir = args.data_path
model_dir = os.path.join('./checkpoints', args.dataset, model_name)
if not os.path.isdir(model_dir):
    os.makedirs(model_dir)
######################################################################
# Function
# --------
def save_network(network, epoch_label):
    save_filename = 'net_%s.pth'% epoch_label
    save_path = os.path.join(model_dir, save_filename)
    # save CPU weights, then move the network back to the GPU if needed
    torch.save(network.cpu().state_dict(), save_path)
    if use_gpu:
        network.cuda()
    print('Save model to {}'.format(save_path))
######################################################################
# Draw Curve
#-----------
x_epoch = []
y_loss = {} # loss history
y_loss['train'] = []
y_loss['val'] = []
y_err = {}
y_err['train'] = []
y_err['val'] = []
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")
def draw_curve(current_epoch):
    x_epoch.append(current_epoch)
    ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
    ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
    ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
    ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
    if current_epoch == 0:
        ax0.legend()
        ax1.legend()
    fig.savefig( os.path.join(model_dir, 'train.jpg'))
######################################################################
# DataLoader
# ---------
image_datasets = {}
image_datasets['train'] = Train_Dataset(data_dir, dataset_name=dataset_name, train_val='train')
image_datasets['val'] = Train_Dataset(data_dir, dataset_name=dataset_name, train_val='query')
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size,
                                              shuffle=True, num_workers=args.num_workers, drop_last=True)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# images, indices, labels, ids, cams, names = next(iter(dataloaders['train']))
num_label = image_datasets['train'].num_label()
num_id = image_datasets['train'].num_id()
labels_list = image_datasets['train'].labels()
######################################################################
# Model and Optimizer
# ------------------
model = get_model(model_name, num_label, args.use_id, num_id=num_id)
if use_gpu:
    model = model.cuda()
# loss
criterion_bce = nn.BCELoss()
criterion_ce = nn.CrossEntropyLoss()
# optimizer
ignored_params = list(map(id, model.features.parameters()))
classifier_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.SGD([
    {'params': model.features.parameters(), 'lr': 0.01},
    {'params': classifier_params, 'lr': 0.1},
], momentum=0.9, weight_decay=5e-4, nesterov=True)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
######################################################################
# Training the model
# ------------------
def train_model(model, optimizer, scheduler, num_epochs):
    since = time.time()
    for epoch in range(1, num_epochs+1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for count, (images, indices, labels, ids, cams, names) in enumerate(dataloaders[phase]):
                # get the inputs
                labels = labels.float()
                if use_gpu:
                    images = images.cuda()
                    labels = labels.cuda()
                    indices = indices.cuda()
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                if not args.use_id:
                    pred_label = model(images)
                    total_loss = criterion_bce(pred_label, labels)
                else:
                    pred_label, pred_id = model(images)
                    label_loss = criterion_bce(pred_label, labels)
                    id_loss = criterion_ce(pred_id, indices)
                    total_loss = label_loss + args.lamba * id_loss
                # backward + optimize only if in training phase
                if phase == 'train':
                    total_loss.backward()
                    optimizer.step()
                preds = torch.gt(pred_label, torch.ones_like(pred_label)/2 )
                # statistics
                running_loss += total_loss.item()
                running_corrects += torch.sum(preds == labels.byte()).item() / num_label
                if count % 100 == 0:
                    if not args.use_id:
                        print('step: ({}/{}) | label loss: {:.4f}'.format(
                            count*args.batch_size, dataset_sizes[phase], total_loss.item()))
                    else:
                        print('step: ({}/{}) | label loss: {:.4f} | id loss: {:.4f}'.format(
                            count*args.batch_size, dataset_sizes[phase], label_loss.item(), id_loss.item()))
            epoch_loss = running_loss / len(dataloaders[phase])
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0-epoch_acc)
            # keep a reference to the weights of the latest epoch
            if phase == 'val':
                last_model_wts = model.state_dict()
                if epoch % 10 == 0:
                    save_network(model, epoch)
                draw_curve(epoch)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # reload the last-epoch weights (no best-model selection is performed)
    model.load_state_dict(last_model_wts)
    save_network(model, 'last')
######################################################################
# Main
# -----
train_model(model, optimizer, exp_lr_scheduler, num_epochs=args.num_epoch)