dk - start.py change to main.py
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import argparse
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from torchvision import transforms as T
|
||||
from net import get_model
|
||||
|
||||
from clearml import Task # 1. ClearML 임포트
|
||||
|
||||
num_cls_dict = {'market': 30, 'duke': 23}
|
||||
num_ids_dict = {'market': 751, 'duke': 702}
|
||||
|
||||
transforms = T.Compose([
|
||||
T.Resize(size=(288, 144)),
|
||||
T.ToTensor(),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
task = Task.init(
|
||||
project_name="Person_Attribute_Recognition",
|
||||
task_name="MarketDuke_Agent"
|
||||
)
|
||||
|
||||
|
||||
class PredictDecoder(object):
|
||||
|
||||
def __init__(self, dataset):
|
||||
with open('./doc/label.json', 'r') as f:
|
||||
self.label_list = json.load(f)[dataset]
|
||||
with open('./doc/attribute.json', 'r') as f:
|
||||
self.attribute_dict = json.load(f)[dataset]
|
||||
self.dataset = dataset
|
||||
self.num_label = len(self.label_list)
|
||||
|
||||
def decode(self, pred):
|
||||
pred = pred.squeeze(dim=0)
|
||||
results = []
|
||||
for idx in range(self.num_label):
|
||||
name, choice = self.attribute_dict[self.label_list[idx]]
|
||||
value = choice[pred[idx]]
|
||||
if value:
|
||||
results.append((name, value))
|
||||
return results
|
||||
|
||||
|
||||
def load_network(network, dataset, model_name):
|
||||
save_path = os.path.join('./checkpoints', dataset, model_name, 'net_last.pth')
|
||||
network.load_state_dict(torch.load(save_path))
|
||||
print(f'[+] 모델 로드 완료: {save_path}')
|
||||
return network
|
||||
|
||||
|
||||
def preprocess_image(image):
|
||||
src = transforms(image)
|
||||
return src.unsqueeze(dim=0)
|
||||
|
||||
|
||||
def load_image_from_url_or_path(image_source):
|
||||
if image_source.startswith(('http://', 'https://')):
|
||||
response = requests.get(image_source)
|
||||
response.raise_for_status()
|
||||
return Image.open(BytesIO(response.content)).convert('RGB')
|
||||
|
||||
if not os.path.isfile(image_source):
|
||||
print(f"Error: Image not found: {image_source}")
|
||||
exit(1)
|
||||
|
||||
return Image.open(image_source).convert('RGB')
|
||||
|
||||
|
||||
def main(image, dataset='market', backbone='resnet50', use_id=False):
|
||||
assert dataset in ['market', 'duke']
|
||||
assert backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121']
|
||||
|
||||
model_name = f'{backbone}_nfc_id' if use_id else f'{backbone}_nfc'
|
||||
num_label = num_cls_dict[dataset]
|
||||
num_id = num_ids_dict[dataset]
|
||||
|
||||
model = get_model(model_name, num_label, use_id=use_id, num_id=num_id)
|
||||
model = load_network(model, dataset, model_name)
|
||||
model.eval()
|
||||
|
||||
src = preprocess_image(image)
|
||||
|
||||
with torch.no_grad():
|
||||
if not use_id:
|
||||
out = model.forward(src)
|
||||
else:
|
||||
out, _ = model.forward(src)
|
||||
|
||||
pred = torch.gt(out, torch.ones_like(out) / 2)
|
||||
|
||||
decoder = PredictDecoder(dataset)
|
||||
results = decoder.decode(pred)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print(" Person 상세 속성 분석 결과 ")
|
||||
print("=" * 50)
|
||||
for name, value in results:
|
||||
print(f'{name}: {value}')
|
||||
print("=" * 50)
|
||||
|
||||
result_data = {"results": results, "status": "PASS"}
|
||||
task.upload_artifact(name="final_result", artifact_object=result_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Person Attribute Recognition Agent")
|
||||
parser.add_argument("--image_url", type=str)
|
||||
parser.add_argument("--xywh", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
image_url = args.image_url
|
||||
if image_url is None:
|
||||
print("Error: Image URL or path is required")
|
||||
exit(1)
|
||||
|
||||
oimg = load_image_from_url_or_path(image_url)
|
||||
|
||||
xywh = args.xywh
|
||||
if xywh is None:
|
||||
print("Error: XYWH is required")
|
||||
exit(1)
|
||||
|
||||
xywh = xywh.split(",")
|
||||
x = int(xywh[0])
|
||||
y = int(xywh[1])
|
||||
w = int(xywh[2])
|
||||
h = int(xywh[3])
|
||||
|
||||
cimg = oimg.crop((x, y, x + w, y + h))
|
||||
cimg.save("cropped_image.jpg")
|
||||
print(f"[+] 타겟 영역 크롭 완료 -> 'cropped_image.jpg' 저장 완료.")
|
||||
|
||||
main(cimg)
|
||||
Reference in New Issue
Block a user