136 lines
3.7 KiB
Python
136 lines
3.7 KiB
Python
import os
|
|
import json
|
|
import torch
|
|
import argparse
|
|
import requests
|
|
from io import BytesIO
|
|
from PIL import Image
|
|
from torchvision import transforms as T
|
|
from net import get_model
|
|
|
|
from clearml import Task # 1. ClearML 임포트
|
|
|
|
# Per-dataset sizes handed to get_model() in main():
#   num_cls_dict -> used as ``num_label`` (number of attribute outputs)
#   num_ids_dict -> used as ``num_id`` (presumably the number of training
#                   identities of each dataset -- confirm against net.py)
num_cls_dict = {
    'market': 30,
    'duke': 23,
}
num_ids_dict = {
    'market': 751,
    'duke': 702,
}
|
|
|
|
# Preprocessing pipeline applied to every input image before inference:
# resize to 288x144, convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std values match the commonly used ImageNet
# statistics -- confirm they are the ones the checkpoints were trained with.
transforms = T.Compose(
    [
        T.Resize(size=(288, 144)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
# Register this run with ClearML. This executes at module import time, so
# merely importing the file initializes a task. The ``task`` handle is not
# referenced again in the visible part of the file.
task = Task.init(project_name="Person_Attribute_Recognition", task_name="MarketDuke_Agent")
|
|
|
|
|
|
class PredictDecoder:
    """Translate a per-label prediction into readable ``(name, value)`` pairs.

    The label order and the per-label choice tables are read from
    ``./doc/label.json`` and ``./doc/attribute.json`` (both resolved against
    the current working directory), each keyed by dataset name.
    """

    def __init__(self, dataset):
        """Load the label list and the attribute table for ``dataset``.

        Args:
            dataset: Top-level key to select in both JSON documents.

        Raises:
            OSError: If either JSON file cannot be opened.
            KeyError: If ``dataset`` is missing from either document.
        """
        # JSON is UTF-8 by specification; do not depend on the locale's
        # default encoding (which differs between platforms).
        with open('./doc/label.json', 'r', encoding='utf-8') as f:
            self.label_list = json.load(f)[dataset]
        with open('./doc/attribute.json', 'r', encoding='utf-8') as f:
            self.attribute_dict = json.load(f)[dataset]
        self.dataset = dataset
        self.num_label = len(self.label_list)

    def decode(self, pred):
        """Map a prediction tensor to the list of selected attribute values.

        Args:
            pred: Tensor with one boolean/integer entry per label. A leading
                dimension of size 1 is squeezed away first. Each entry is
                used as an index into that label's choice list.

        Returns:
            A list of ``(name, value)`` tuples in label order. Labels whose
            selected choice is falsy (for example ``None`` or ``""``) are
            left out.
        """
        # Convert once to plain Python values instead of indexing the tensor
        # element by element (one tensor -> Python conversion per label).
        flags = pred.squeeze(dim=0).tolist()
        results = []
        for idx in range(self.num_label):
            name, choice = self.attribute_dict[self.label_list[idx]]
            # A bool flag indexes the choice list as 0 (False) or 1 (True).
            value = choice[flags[idx]]
            if value:
                results.append((name, value))
        return results
|
|
|
|
|
|
def load_network(network, dataset, model_name):
    """Load the saved ``net_last.pth`` weights into ``network``.

    The checkpoint is read from
    ``./checkpoints/<dataset>/<model_name>/net_last.pth`` relative to the
    current working directory.

    Args:
        network: Module whose ``load_state_dict`` receives the checkpoint.
        dataset: Dataset sub-directory name under ``./checkpoints``.
        model_name: Model sub-directory name under the dataset directory.

    Returns:
        The same ``network`` object, with the loaded weights.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint does not match the network
            (raised by ``load_state_dict``).
    """
    save_path = os.path.join('./checkpoints', dataset, model_name, 'net_last.pth')
    # Deserialize onto the CPU so that checkpoints written on a GPU machine
    # also load on hosts without CUDA. load_state_dict() then copies the
    # values into the network's existing parameters, wherever they live.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(save_path, map_location='cpu')
    network.load_state_dict(state_dict)
    print(f'[+] 모델 로드 완료: {save_path}')
    return network
|
|
|
|
|
|
def preprocess_image(image):
    """Apply the module-level ``transforms`` to ``image`` and add a batch axis.

    Args:
        image: Input accepted by the ``transforms`` pipeline.

    Returns:
        The transformed result with a new leading dimension of size 1.
    """
    return transforms(image).unsqueeze(dim=0)
|
|
|
|
|
|
def load_image_from_url_or_path(image_source, timeout=30):
    """Open an image from an HTTP(S) URL or a local file and convert it to RGB.

    Args:
        image_source: Either a URL starting with ``http://`` / ``https://``
            or a path to a local image file.
        timeout: Seconds to wait for the remote server; only used when
            ``image_source`` is a URL.

    Returns:
        The image converted to RGB mode.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
        SystemExit: With exit code 1 (after printing an error message) when
            the local path is not an existing file.
    """
    if image_source.startswith(('http://', 'https://')):
        # Without a timeout, requests can wait on an unresponsive server
        # indefinitely.
        response = requests.get(image_source, timeout=timeout)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')

    if not os.path.isfile(image_source):
        print(f"Error: Image not found: {image_source}")
        # Raise SystemExit directly: the ``exit`` helper is injected by the
        # ``site`` module and is not guaranteed to be available.
        raise SystemExit(1)

    # convert() returns an independent, fully loaded copy, so the underlying
    # file handle can be closed as soon as the conversion is done.
    with Image.open(image_source) as img:
        return img.convert('RGB')
|
|
|
|
|
|
def main(image, dataset='market', backbone='resnet50', use_id=False):
    """Run attribute recognition on one image and print the decoded result.

    Builds the model via ``get_model``, loads its checkpoint with
    ``load_network``, runs a single forward pass without gradients,
    thresholds the output at 0.5 and decodes it with ``PredictDecoder``.

    Args:
        image: Image accepted by ``preprocess_image``.
        dataset: ``'market'`` or ``'duke'``.
        backbone: One of ``'resnet50'``, ``'resnet34'``, ``'resnet18'``,
            ``'densenet121'``.
        use_id: Selects the ``<backbone>_nfc_id`` model variant, whose
            forward pass returns two values; only the first one is used.

    Returns:
        The list of ``(name, value)`` tuples that was printed.

    Raises:
        ValueError: If ``dataset`` or ``backbone`` is not a supported value.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    valid_datasets = ('market', 'duke')
    valid_backbones = ('resnet50', 'resnet34', 'resnet18', 'densenet121')
    if dataset not in valid_datasets:
        raise ValueError(f'dataset must be one of {valid_datasets}, got {dataset!r}')
    if backbone not in valid_backbones:
        raise ValueError(f'backbone must be one of {valid_backbones}, got {backbone!r}')

    model_name = f'{backbone}_nfc_id' if use_id else f'{backbone}_nfc'
    num_label = num_cls_dict[dataset]
    num_id = num_ids_dict[dataset]

    model = get_model(model_name, num_label, use_id=use_id, num_id=num_id)
    model = load_network(model, dataset, model_name)
    model.eval()

    src = preprocess_image(image)

    with torch.no_grad():
        # Call the module itself rather than ``.forward()`` so that any
        # registered hooks are executed.
        if not use_id:
            out = model(src)
        else:
            out, _ = model(src)

    # An output above 0.5 selects the second choice of the attribute.
    pred = out > 0.5

    decoder = PredictDecoder(dataset)
    results = decoder.decode(pred)

    print("\n" + "=" * 50)
    print(" Person 상세 속성 분석 결과 ")
    print("=" * 50)
    for name, value in results:
        print(f'{name}: {value}')
    print("=" * 50)

    return results
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: load an image, crop the region given by
    # --xywh, save the crop to 'cropped_image.jpg' and run main() on it.
    parser = argparse.ArgumentParser(description="Person Attribute Recognition Agent")
    parser.add_argument("--image_url", type=str)
    parser.add_argument("--xywh", type=str)

    args = parser.parse_args()

    if args.image_url is None:
        print("Error: Image URL or path is required")
        raise SystemExit(1)

    if args.xywh is None:
        print("Error: XYWH is required")
        raise SystemExit(1)

    # Validate the crop box before touching the image so that bad input
    # fails fast, without a download, and with a readable message instead
    # of a traceback. Values beyond the fourth are ignored, as before.
    try:
        x, y, w, h = (int(part) for part in args.xywh.split(",")[:4])
    except ValueError:
        print("Error: XYWH must be four comma-separated integers: x,y,w,h")
        raise SystemExit(1) from None

    if w <= 0 or h <= 0:
        print("Error: XYWH width and height must be positive")
        raise SystemExit(1)

    oimg = load_image_from_url_or_path(args.image_url)

    # crop() takes (left, upper, right, lower), so convert width/height.
    cimg = oimg.crop((x, y, x + w, y + h))
    cimg.save("cropped_image.jpg")
    print("[+] 타겟 영역 크롭 완료 -> 'cropped_image.jpg' 저장 완료.")

    main(cimg)
|