Commit eef4e78c authored by Lamping, Christian's avatar Lamping, Christian
Browse files

Update utils.py

parent c72b18dd
......@@ -1030,3 +1030,68 @@ def init_distributed_mode(args):
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def get_prediction(model, img, threshold):
"""
This function generates a prediction for a given model an image, then
post-process the prediction given a threshold and the name of the classes.
"""
CATEGORY_NAMES = [
'__background__', 'hen'
]
device = torch.device('cuda:0') # Default CUDA device
# Uncomment if loading image as a path
# img = Image.open(img_path)
# transform = T.Compose([T.ToTensor()])
# img = transform(img).to(device)
pred = model([img.to(device)])
masks, pred_boxes, pred_class, features, mask_features = None, None, None, None, None
pred_score = list(pred[0]['scores'].cpu().detach().numpy())
try:
pred_t = [pred_score.index(x) for x in pred_score if x>threshold][-1]
masks = (pred[0]['masks']>0.5).squeeze().detach().cpu().numpy()
if len(masks.shape)<3:
masks = np.expand_dims(masks, 0)
pred_class = [CATEGORY_NAMES[i] for i in list(pred[0]['labels'].cpu().numpy())]
pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].cpu().detach().numpy())]
masks = masks[:pred_t+1]
pred_boxes = pred_boxes[:pred_t+1]
pred_class = pred_class[:pred_t+1]
pred_score = pred_score[:pred_t+1]
except:
print("No objects detected")
pass
return masks, pred_boxes, pred_class, pred_score
def random_colour_masks(image):
"""
Give a mask a random color
"""
colours = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180],[250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]
r = np.zeros_like(image).astype(np.uint8)
g = np.zeros_like(image).astype(np.uint8)
b = np.zeros_like(image).astype(np.uint8)
color = colours[random.randrange(0,10)]
r[image == 1], g[image == 1], b[image == 1] = color
coloured_mask = np.stack([r, g, b], axis=2)
return coloured_mask, color
def instance_segmentation_api(model, img, threshold=0.5, rect_th=1, text_size=1, text_th=1):
"""
Given an image and a model, this function predicts the outputs and apply
them into the image
"""
masks, boxes, pred_cls, pred_score = get_prediction(model, img, threshold)
img = np.array(transforms.ToPILImage(mode='RGB')(img))
desired_masks = []; desired_boxes = []; desired_cls = []; desired_features = []; desired_mask_features = []; desired_scores = []
if masks is not None:
for i in range(len(masks)):
rgb_mask, color = random_colour_masks(masks[i])
img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)
cv2.rectangle(img, boxes[i][0], boxes[i][1], color=color, thickness=rect_th)
text = pred_cls[i] + ': ' + str(round(pred_score[i], 2))
cv2.putText(img, text, boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, color=color, thickness=text_th)
return Image.fromarray(img), masks, boxes, pred_cls, pred_score
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment