Skip to content

Commit

Permalink
Merge branch 'dev-MVC' of github.com:SystemsGenetics/granny into dev-MVC
Browse files Browse the repository at this point in the history
  • Loading branch information
spficklin committed Jul 12, 2024
2 parents 21cbfb1 + cbbdf5f commit d26b839
Showing 1 changed file with 32 additions and 35 deletions.
67 changes: 32 additions & 35 deletions Granny/Analyses/Segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any, Dict, List, Tuple
from urllib import request

import matplotlib.pyplot as plt
import cv2
import numpy as np
import pandas as pd
from Granny.Analyses.Analysis import Analysis
Expand All @@ -33,7 +33,6 @@
from Granny.Models.IO.RGBImageFile import RGBImageFile
from Granny.Models.Values.FileNameValue import FileNameValue
from Granny.Models.Values.ImageListValue import ImageListValue
from matplotlib import patches
from numpy.typing import NDArray


Expand All @@ -42,7 +41,7 @@ class SegmentationConfig:
MODELS: Dict[str, Dict[str, str]] = {
"pome_fruit-v1_0": {
"full_name": "granny-v1_0-pome_fruit-v1_0.pt",
"url": "https://osf.io/dqzyn/download/",
"url": "https://osf.io/vyfhm/download/",
}
}

Expand Down Expand Up @@ -102,12 +101,12 @@ def __init__(self):
"tray_infos",
)
)
self.full_images = ImageListValue(
self.masked_images = ImageListValue(
"f_img",
"full_masked_image",
"The output directory where the full-masked images are written.",
)
self.full_images.setValue(
self.masked_images.setValue(
os.path.join(
os.curdir,
"results",
Expand All @@ -126,7 +125,6 @@ def _getModelUrl(self, model_name: str):
model_url = ""
try:
model_url = self.models[model_name]["url"]
print(f"Model URL: {model_url}")
except KeyError:
print(f"Key '{model_name}' not found in configuration.")
return model_url
Expand Down Expand Up @@ -163,7 +161,7 @@ def _segmentInstances(self, image: NDArray[np.uint8]) -> List[Any]:

return results

def _writeMaskedImage(self, tray_image: Image) -> None:
def _extractMaskedImage(self, tray_image: Image) -> Image:
""""""
[result] = tray_image.getSegmentationResults()
masks = result.masks.cpu()
Expand All @@ -179,9 +177,9 @@ def _writeMaskedImage(self, tray_image: Image) -> None:
hsv = [(i / num_instances, 1, brightness) for i in range(num_instances)]
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
random.shuffle(colors)
_, ax = plt.subplots()
for i in range(num_instances):
mask = masks.data[i].numpy()
(r, g, b) = colors[i]
for c in range(3):
result[:, :, c] = np.where(
mask == 1,
Expand All @@ -191,31 +189,21 @@ def _writeMaskedImage(self, tray_image: Image) -> None:

x1, y1, x2, y2 = coords[i]
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
ax.text(
x1, y1 + 10, "{:.3f}".format(confs[i]), color="w", size=7, backgroundcolor="none"
)
p = patches.Rectangle(
cv2.rectangle(result, (x1, y1), (x2, y2), (r * 255, g * 255, b * 255), 5)
cv2.putText(
result,
"{:.3f}".format(confs[i]),
(x1, y1),
x2 - x1,
y2 - y1,
linewidth=1,
edgecolor=colors[i],
facecolor="none",
linestyle="dashed",
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=2,
color=(255, 255, 255),
thickness=3,
)
ax.add_patch(p)
plt.axis("off")
plt.imshow(result)
plt.tight_layout()
plt.savefig(
os.path.join(
self.full_images.getValue(),
tray_image.getImageName(),
),
bbox_inches="tight",
pad_inches=0,
dpi=300,
image_instance: Image = RGBImage(
pathlib.Path(tray_image.getImageName()).stem + f"_masked_image" + ".png"
)
image_instance.setImage(result)
return image_instance

def _sortInstances(self, boxes: NDArray[np.float32], img_shape: Tuple[int, int]):
"""
Expand Down Expand Up @@ -370,6 +358,7 @@ def performAnalysis(self) -> List[Image]:
# performs segmentation on each image one-by-one
segmented_images: List[Image] = []
tray_images: List[Image] = []
masked_images: List[Image] = []
for image_instance in self.images:
# set ImageIO with specific file path
self.image_io.setFilePath(image_instance.getFilePath())
Expand All @@ -389,22 +378,30 @@ def performAnalysis(self) -> List[Image]:
image_instance.setSegmentationResults(results=result)

try:
# extracts individual instances and tray information
# extracts individual instances
image_instances = self._extractImage(image_instance)
tray_info = self._extractTrayInfo(image_instance)
# and tray information
tray_infos = self._extractTrayInfo(image_instance)
# and masked image
masked_image = self._extractMaskedImage(image_instance)

# save to list for output
segmented_images.extend(image_instances)
tray_images.extend(tray_info)
# writes masked image
self._writeMaskedImage(image_instance)
tray_images.extend(tray_infos)
masked_images.append(masked_image)
except:
AttributeError("Error with the results.")

# 1. sets the output ImageListValue with the list of segmented images
# 2. writes the segmented images to "segmented_images" folder
# 3. writes the tray information to "tray_info" folder
# 4. writes the full masked images to "full_masked_images" folder
self.seg_images.setImageList(segmented_images)
self.seg_images.writeValue()

self.tray_infos.setImageList(tray_images)
self.tray_infos.writeValue()

self.masked_images.setImageList(masked_images)
self.masked_images.writeValue()
return segmented_images

0 comments on commit d26b839

Please sign in to comment.