"""Person attribute recognition agent for the 'market' / 'duke' attribute models.

Loads an image from a URL or local path, crops the target box given as
``x,y,w,h``, runs the selected backbone's attribute classifier on the crop,
prints the decoded attributes and uploads them to ClearML as an artifact.
"""
import argparse
import json
import os
import sys
from io import BytesIO

import requests
import torch
from clearml import Task  # ClearML experiment tracking
from PIL import Image
from torchvision import transforms as T

from net import get_model

# Number of attribute outputs per dataset, passed to get_model as num_label.
# NOTE(review): presumably equals len(label list) in ./doc/label.json — confirm.
num_cls_dict = {'market': 30, 'duke': 23}
# Number of identity classes per dataset, passed to get_model as num_id.
num_ids_dict = {'market': 751, 'duke': 702}

SUPPORTED_DATASETS = ('market', 'duke')
SUPPORTED_BACKBONES = ('resnet50', 'resnet34', 'resnet18', 'densenet121')

# Resize to 288x144 (height, width), convert to a tensor and normalize with
# the standard ImageNet channel mean/std.
transforms = T.Compose([
    T.Resize(size=(288, 144)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Created at import time; main() uploads its result artifact to this task.
task = Task.init(
    project_name="Person_Attribute_Recognition",
    task_name="MarketDuke_Agent",
)


class PredictDecoder(object):
    """Turn a binary attribute prediction into human-readable (name, value) pairs."""

    def __init__(self, dataset):
        """Load the label order and attribute vocabulary for *dataset*.

        Args:
            dataset: key selecting the entry in ./doc/label.json and
                ./doc/attribute.json.
        """
        with open('./doc/label.json', 'r', encoding='utf-8') as f:
            self.label_list = json.load(f)[dataset]
        with open('./doc/attribute.json', 'r', encoding='utf-8') as f:
            self.attribute_dict = json.load(f)[dataset]
        self.dataset = dataset
        self.num_label = len(self.label_list)

    def decode(self, pred):
        """Decode a prediction tensor into a list of ``(name, value)`` tuples.

        Args:
            pred: boolean tensor with a leading batch dimension of 1 (it is
                squeezed away); element ``idx`` corresponds to
                ``label_list[idx]``.

        Returns:
            ``(name, value)`` pairs for every label whose selected choice is
            truthy; labels mapping to an empty/falsy choice are omitted.
        """
        pred = pred.squeeze(dim=0)
        results = []
        for idx, label in enumerate(self.label_list):
            name, choice = self.attribute_dict[label]
            # Explicit int() turns the 0-d bool tensor into a plain 0/1 list
            # index instead of relying on the tensor's __index__ conversion.
            value = choice[int(pred[idx])]
            if value:
                results.append((name, value))
        return results


def load_network(network, dataset, model_name, map_location='cpu'):
    """Load ``./checkpoints/<dataset>/<model_name>/net_last.pth`` into *network*.

    Args:
        network: the torch module to populate.
        dataset: dataset sub-directory under ./checkpoints.
        model_name: model sub-directory under the dataset directory.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved from a GPU run still loads on a CPU-only host
            (inference in main() runs on CPU tensors).

    Returns:
        The same *network* instance, with weights loaded.
    """
    save_path = os.path.join('./checkpoints', dataset, model_name, 'net_last.pth')
    state_dict = torch.load(save_path, map_location=map_location)
    network.load_state_dict(state_dict)
    print(f'[+] 모델 로드 완료: {save_path}')
    return network


def preprocess_image(image):
    """Apply the module-level transforms and add a batch dimension of 1."""
    src = transforms(image)
    return src.unsqueeze(dim=0)


def load_image_from_url_or_path(image_source, timeout=30):
    """Load an RGB PIL image from an http(s) URL or a local file path.

    Args:
        image_source: URL starting with ``http://``/``https://``, or a path.
        timeout: seconds to wait for the HTTP server; without one,
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The image converted to ``'RGB'`` mode.

    Raises:
        requests.HTTPError: if the URL responds with an error status.
        SystemExit: (exit code 1) if the local path is not an existing file.
    """
    if image_source.startswith(('http://', 'https://')):
        response = requests.get(image_source, timeout=timeout)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    if not os.path.isfile(image_source):
        print(f"Error: Image not found: {image_source}", file=sys.stderr)
        sys.exit(1)
    # convert() fully decodes the pixels, so the file handle can be closed.
    with Image.open(image_source) as img:
        return img.convert('RGB')


def main(image, dataset='market', backbone='resnet50', use_id=False):
    """Run attribute recognition on *image*, print the results and upload them.

    Args:
        image: PIL image of the (already cropped) person.
        dataset: one of ``SUPPORTED_DATASETS``.
        backbone: one of ``SUPPORTED_BACKBONES``.
        use_id: load the ``*_nfc_id`` model variant, which returns a pair
            whose first element is the attribute output.

    Raises:
        ValueError: if *dataset* or *backbone* is not supported.
    """
    # Explicit checks instead of assert, which is stripped under `python -O`.
    if dataset not in SUPPORTED_DATASETS:
        raise ValueError(f"dataset must be one of {SUPPORTED_DATASETS}, got {dataset!r}")
    if backbone not in SUPPORTED_BACKBONES:
        raise ValueError(f"backbone must be one of {SUPPORTED_BACKBONES}, got {backbone!r}")

    model_name = f'{backbone}_nfc_id' if use_id else f'{backbone}_nfc'
    num_label = num_cls_dict[dataset]
    num_id = num_ids_dict[dataset]

    model = get_model(model_name, num_label, use_id=use_id, num_id=num_id)
    model = load_network(model, dataset, model_name)
    model.eval()

    src = preprocess_image(image)
    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks run.
        out = model(src)
        if use_id:
            out, _ = out

    # An attribute counts as present when its output exceeds 0.5.
    # NOTE(review): assumes the model emits per-attribute probabilities — confirm.
    pred = out > 0.5

    decoder = PredictDecoder(dataset)
    results = decoder.decode(pred)

    print("\n" + "=" * 50)
    print(" Person 상세 속성 분석 결과 ")
    print("=" * 50)
    for name, value in results:
        print(f'{name}: {value}')
    print("=" * 50)

    result_data = {"results": results, "status": "PASS"}
    task.upload_artifact(name="final_result", artifact_object=result_data)


def _parse_xywh(text):
    """Parse ``"x,y,w,h"`` into a PIL crop box ``(left, upper, right, lower)``.

    Float values (e.g. from a detector) are accepted and truncated to int.

    Raises:
        ValueError: if there are not exactly four numbers, or w/h <= 0.
    """
    parts = [part.strip() for part in text.split(",")]
    if len(parts) != 4:
        raise ValueError(f"expected 4 comma-separated values (x,y,w,h), got {text!r}")
    x, y, w, h = (int(float(part)) for part in parts)
    if w <= 0 or h <= 0:
        raise ValueError(f"width and height must be positive, got w={w}, h={h}")
    return x, y, x + w, y + h


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Person Attribute Recognition Agent")
    parser.add_argument("--image_url", type=str, help="image URL or local file path")
    parser.add_argument("--xywh", type=str, help="target box as x,y,w,h in pixels")
    args = parser.parse_args()

    image_url = args.image_url
    if image_url is None:
        print("Error: Image URL or path is required", file=sys.stderr)
        sys.exit(1)

    xywh = args.xywh
    if xywh is None:
        print("Error: XYWH is required", file=sys.stderr)
        sys.exit(1)
    # Validate the box before doing any (possibly slow) network download.
    try:
        box = _parse_xywh(xywh)
    except ValueError as err:
        print(f"Error: invalid XYWH: {err}", file=sys.stderr)
        sys.exit(1)

    oimg = load_image_from_url_or_path(image_url)
    cimg = oimg.crop(box)
    cimg.save("cropped_image.jpg")
    print("[+] 타겟 영역 크롭 완료 -> 'cropped_image.jpg' 저장 완료.")

    main(cimg)