Files
model-yolo-person-classify/main.py
T
2026-06-29 16:57:37 +09:00

140 lines
4.1 KiB
Python

import argparse
import json
import os
import sys
from io import BytesIO

import requests
import torch
from clearml import Task # 1. ClearML 임포트
from PIL import Image
from torchvision import transforms as T

from net import get_model
# Number of binary attribute labels the model predicts, per dataset.
num_cls_dict = dict(market=30, duke=23)
# Number of person identities per dataset (used for the optional ID head).
num_ids_dict = dict(market=751, duke=702)

# Preprocessing applied to every person crop before inference:
# resize to (height=288, width=144), convert to a tensor, then normalise each
# channel with the standard ImageNet mean/std values.
transforms = T.Compose(
    [
        T.Resize(size=(288, 144)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Register this run with ClearML.
# NOTE: this runs at import time, so simply importing this module initialises a
# ClearML task; `main` uploads its result artifact to it.
task = Task.init(
    project_name="Person_Attribute_Recognition",
    task_name="model-yolo-person-classify",
)
class PredictDecoder:
    """Turn per-label boolean predictions into readable ``(name, value)`` pairs.

    The label order and attribute metadata are read from ``label.json`` and
    ``attribute.json`` in ``doc_dir``. Each file is a JSON object keyed by
    dataset name.
    """

    def __init__(self, dataset, doc_dir='./doc'):
        """Load the label list and attribute table for ``dataset``.

        Args:
            dataset: Key looked up in both JSON files (``main`` uses
                ``'market'`` or ``'duke'``).
            doc_dir: Directory containing ``label.json`` and
                ``attribute.json``. Defaults to ``./doc`` relative to the
                current working directory.

        Raises:
            KeyError: If ``dataset`` is missing from either file.
        """
        # Explicit UTF-8 so loading does not depend on the platform's locale
        # encoding (e.g. cp949 on Korean Windows).
        with open(os.path.join(doc_dir, 'label.json'), 'r', encoding='utf-8') as f:
            self.label_list = json.load(f)[dataset]
        with open(os.path.join(doc_dir, 'attribute.json'), 'r', encoding='utf-8') as f:
            self.attribute_dict = json.load(f)[dataset]
        self.dataset = dataset
        self.num_label = len(self.label_list)

    def decode(self, pred):
        """Map a prediction to ``(attribute_name, value)`` pairs.

        Args:
            pred: Per-label predictions, optionally with a leading dimension of
                size 1 (it is squeezed away). Element ``i`` corresponds to
                ``self.label_list[i]`` and is converted with ``int()`` to pick
                a choice (``main`` passes a boolean tensor).

        Returns:
            List of ``(name, value)`` tuples in label order. Labels whose
            selected value is empty/falsy are omitted.
        """
        pred = pred.squeeze(dim=0)
        results = []
        for idx, label in enumerate(self.label_list):
            # Each attribute entry is a (display name, choices) pair; the
            # prediction for this label indexes into the choices.
            name, choices = self.attribute_dict[label]
            value = choices[int(pred[idx])]
            if value:
                results.append((name, value))
        return results
def load_network(network, dataset, model_name, checkpoint_dir='./checkpoints', map_location='cpu'):
    """Load the ``net_last.pth`` weights for ``dataset``/``model_name`` into ``network``.

    The checkpoint path is
    ``<checkpoint_dir>/<dataset>/<model_name>/net_last.pth``.

    Args:
        network: Module whose ``state_dict`` keys match the checkpoint.
        dataset: Dataset sub-directory name (e.g. ``'market'``).
        model_name: Model sub-directory name (e.g. ``'resnet50_nfc'``).
        checkpoint_dir: Root checkpoint directory. Defaults to
            ``./checkpoints`` relative to the current working directory.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            that a checkpoint saved from a GPU run also loads on a CPU-only
            machine. Pass ``None`` to restore tensors onto their saved devices.

    Returns:
        ``network`` itself, with the weights loaded in place.
    """
    save_path = os.path.join(checkpoint_dir, dataset, model_name, 'net_last.pth')
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints; consider weights_only=True once the minimum supported
    # torch version allows it.
    state_dict = torch.load(save_path, map_location=map_location)
    network.load_state_dict(state_dict)
    print(f'[+] 모델 로드 완료: {save_path}')
    return network
def preprocess_image(image):
    """Run the module-level ``transforms`` pipeline on ``image`` and prepend a batch axis of size 1."""
    tensor = transforms(image)
    batched = tensor.unsqueeze(0)
    return batched
def load_image_from_url_or_path(image_source, timeout=30):
    """Load an RGB PIL image from an HTTP(S) URL or a local file path.

    Args:
        image_source: An ``http://`` / ``https://`` URL, or a filesystem path.
        timeout: Seconds passed to ``requests.get`` for URL sources, so an
            unresponsive server cannot hang the process indefinitely.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.RequestException: On network failure, timeout, or an HTTP
            error status (via ``raise_for_status``).
        SystemExit: With code 1 if a local path does not exist.
    """
    if image_source.startswith(('http://', 'https://')):
        response = requests.get(image_source, timeout=timeout)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return img.convert('RGB')
    if not os.path.isfile(image_source):
        print(f"Error: Image not found: {image_source}", file=sys.stderr)
        # sys.exit rather than the site-provided `exit` builtin, which is not
        # guaranteed to exist (e.g. `python -S` or frozen executables).
        sys.exit(1)
    # The context manager closes the file handle; convert() returns an
    # independent, fully loaded image.
    with Image.open(image_source) as img:
        return img.convert('RGB')
def main(image, dataset='market', backbone='resnet50', use_id=False):
    """Run person-attribute recognition on a single (already cropped) image.

    Args:
        image: PIL image of one person; passed through ``preprocess_image``.
        dataset: Attribute schema / checkpoint set, ``'market'`` or ``'duke'``.
        backbone: One of ``'resnet50'``, ``'resnet34'``, ``'resnet18'`` or
            ``'densenet121'``.
        use_id: Use the ``<backbone>_nfc_id`` model variant. Its forward pass
            returns a pair, and only the first element is used.

    Side effects:
        Prints the decoded attributes and uploads them to the module-level
        ClearML ``task`` as the ``final_result`` artifact.

    Raises:
        ValueError: If ``dataset`` or ``backbone`` is not supported.
    """
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if dataset not in ('market', 'duke'):
        raise ValueError(f"unsupported dataset: {dataset!r} (expected 'market' or 'duke')")
    if backbone not in ('resnet50', 'resnet34', 'resnet18', 'densenet121'):
        raise ValueError(
            f"unsupported backbone: {backbone!r} "
            "(expected 'resnet50', 'resnet34', 'resnet18' or 'densenet121')"
        )

    model_name = f'{backbone}_nfc_id' if use_id else f'{backbone}_nfc'
    model = get_model(model_name, num_cls_dict[dataset], use_id=use_id, num_id=num_ids_dict[dataset])
    model = load_network(model, dataset, model_name)
    model.eval()

    src = preprocess_image(image)
    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks still run.
        out = model(src)
        if use_id:
            out, _ = out  # keep the attribute output, drop the ID output
        # Threshold each attribute output at 0.5 to get a boolean per label.
        pred = out > 0.5

    results = PredictDecoder(dataset).decode(pred)

    print("\n" + "=" * 50)
    print(" Person 상세 속성 분석 결과 ")
    print("=" * 50)
    for name, value in results:
        print(f'{name}: {value}')
    print("=" * 50)

    result_data = {"results": results, "status": "PASS"}
    task.upload_artifact(name="final_result", artifact_object=result_data)
def parse_xywh(text):
    """Parse an ``"x,y,w,h"`` pixel bounding-box string into four ints.

    Whitespace around each value is ignored.

    Args:
        text: Comma-separated ``x,y,w,h``, e.g. ``"404,290,74,193"``.

    Returns:
        Tuple ``(x, y, w, h)`` of ints.

    Raises:
        ValueError: If there are not exactly four integer values, or if the
            width or height is not positive.
    """
    parts = [part.strip() for part in text.split(",")]
    if len(parts) != 4:
        raise ValueError(f"expected 4 comma-separated values 'x,y,w,h', got {len(parts)}: {text!r}")
    x, y, w, h = (int(part) for part in parts)
    if w <= 0 or h <= 0:
        raise ValueError(f"width and height must be positive, got w={w}, h={h}")
    return x, y, w, h


if __name__ == "__main__":
    # Example:
    # python main.py --image_url "https://acai.ketidev.kr:20443/detect/image/202606/20260619_145116_image.jpg" --xywh "404,290,74,193"
    parser = argparse.ArgumentParser(description="Person Attribute Recognition Agent")
    parser.add_argument("--image_url", type=str,
                        help="HTTP(S) URL or local path of the source image")
    parser.add_argument("--xywh", type=str,
                        help='person box in pixels as "x,y,w,h", e.g. "404,290,74,193"')
    parser.add_argument("--dataset", type=str, default="market", choices=["market", "duke"],
                        help="attribute schema / checkpoint set (default: market)")
    parser.add_argument("--backbone", type=str, default="resnet50",
                        choices=["resnet50", "resnet34", "resnet18", "densenet121"],
                        help="model backbone (default: resnet50)")
    parser.add_argument("--use_id", action="store_true",
                        help="use the model variant with the identity head")
    args = parser.parse_args()

    # Validate all arguments before doing any (possibly slow) network I/O.
    if args.image_url is None:
        print("Error: Image URL or path is required", file=sys.stderr)
        sys.exit(1)
    if args.xywh is None:
        print("Error: XYWH is required", file=sys.stderr)
        sys.exit(1)
    try:
        x, y, w, h = parse_xywh(args.xywh)
    except ValueError as err:
        print(f"Error: invalid --xywh: {err}", file=sys.stderr)
        sys.exit(1)

    oimg = load_image_from_url_or_path(args.image_url)
    cimg = oimg.crop((x, y, x + w, y + h))
    cimg.save("cropped_image.jpg")
    print("[+] 타겟 영역 크롭 완료 -> 'cropped_image.jpg' 저장 완료.")
    main(cimg, dataset=args.dataset, backbone=args.backbone, use_id=args.use_id)