Commit 19805b84 authored by Lamping, Christian's avatar Lamping, Christian
Browse files

Update utils.py

parent 20ac16ca
...@@ -698,7 +698,7 @@ def display_instances(image, boxes, masks, class_ids, class_names, ...@@ -698,7 +698,7 @@ def display_instances(image, boxes, masks, class_ids, class_names,
def extract_bboxes(mask): def extract_bboxes(mask):
"""Compute bounding boxes from masks. """Compute bounding boxes from masks.
mask: [height, width, num_instances]. Mask pixels are either 1 or 0. mask: [height, width, num_instances]. Mask pixels are either 1 or 0.
Returns: bbox array [num_instances, (y1, x1, y2, x2)]. Returns: bbox array [num_instances, (x1, y1, x2, y2)].
""" """
mask = mask.transpose((1,2,0)) mask = mask.transpose((1,2,0))
boxes = np.zeros([mask.shape[-1], 4], dtype=np.int32) boxes = np.zeros([mask.shape[-1], 4], dtype=np.int32)
...@@ -1034,41 +1034,6 @@ def init_distributed_mode(args): ...@@ -1034,41 +1034,6 @@ def init_distributed_mode(args):
torch.distributed.barrier() torch.distributed.barrier()
setup_for_distributed(args.rank == 0) setup_for_distributed(args.rank == 0)
def get_prediction(model, img, threshold):
"""
This function generates a prediction for a given model an image, then
post-process the prediction given a threshold and the name of the classes.
"""
CATEGORY_NAMES = [
'__background__', 'hen'
]
device = torch.device('cuda:0') # Default CUDA device
# Uncomment if loading image as a path
# img = Image.open(img_path)
# transform = T.Compose([T.ToTensor()])
# img = transform(img).to(device)
pred = model([img.to(device)])
masks, pred_boxes, pred_class, features, mask_features = None, None, None, None, None
pred_score = list(pred[0]['scores'].cpu().detach().numpy())
try:
pred_t = [pred_score.index(x) for x in pred_score if x>threshold][-1]
masks = (pred[0]['masks']>0.5).squeeze().detach().cpu().numpy()
if len(masks.shape)<3:
masks = np.expand_dims(masks, 0)
pred_class = [CATEGORY_NAMES[i] for i in list(pred[0]['labels'].cpu().numpy())]
pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].cpu().detach().numpy())]
masks = masks[:pred_t+1]
pred_boxes = pred_boxes[:pred_t+1]
pred_class = pred_class[:pred_t+1]
pred_score = pred_score[:pred_t+1]
except:
print("No objects detected")
pass
return masks, pred_boxes, pred_class, pred_score
def random_colour_masks(image): def random_colour_masks(image):
""" """
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment