We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
`import torch import ttach as tta import timm import numpy as np from PIL import Image import matplotlib.pyplot as plt import cv2
model = torch.load('E:/PhD_Projects/egmentation models/new model weights/UNet_mitb2_thresh0.3.pth')
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="mean")
image_dir = 'E:/PhD_Projects/segmentation models/patches' image_filename_2 = 'image__02_02.tif' image_path = os.path.join(image_dir, image_filename_2) image = tiff.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
preprocessing_fn_inference = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS) preprocessing_inference=get_preprocessing(preprocessing_fn_inference) sample = preprocessing_inference(image=image) image = sample['image']
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0) pr_mask = tta_model.predict(x_tensor) pr_mask = (pr_mask.squeeze().cpu().numpy().round()) pr_mask = (pr_mask.astype('float') * 255.0/16) #pr_mask = (pr_mask.astype('float') * 255.0/16).astype('uint8')
plt.imshow(pr_mask) plt.show()`
Can anyone help me with this prediction problem? Thank you. @qubvel
The text was updated successfully, but these errors were encountered:
No branches or pull requests
`import torch
import ttach as tta
import timm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
model = torch.load('E:/PhD_Projects/egmentation models/new model weights/UNet_mitb2_thresh0.3.pth')
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="mean")
image_dir = 'E:/PhD_Projects/segmentation models/patches'
image_filename_2 = 'image__02_02.tif'
image_path = os.path.join(image_dir, image_filename_2)
image = tiff.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
preprocessing_fn_inference = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocessing_inference=get_preprocessing(preprocessing_fn_inference)
sample = preprocessing_inference(image=image)
image = sample['image']
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = tta_model.predict(x_tensor)
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
pr_mask = (pr_mask.astype('float') * 255.0/16)
#pr_mask = (pr_mask.astype('float') * 255.0/16).astype('uint8')
=============================================================================
plt.imshow(pr_mask)
plt.show()`
Can anyone help me with this prediction problem? Thank you. @qubvel
The text was updated successfully, but these errors were encountered: