Skip to content

Commit

Permalink
Merge pull request #35 from qwertpi/master
Browse files Browse the repository at this point in the history
Fixed display_heatmap
  • Loading branch information
Philippe Rémy authored Apr 17, 2019
2 parents e44e920 + 61d464a commit 3b3bf0e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions keract/keract.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ def display_activations(activations, cmap=None, save=False):
plt.close(fig)


def display_heatmaps(activations, image, save=False):
def display_heatmaps(activations, input_image, save=False):
"""
Plot heatmaps of activations for all filters overlayed on the input image for each layer
:param activations: dict mapping layers to corresponding activations (1, output_h, output_w, num_filters)
:param image: input image for the overlay
:param input_image: input image for the overlay
:param save: bool- if the plot should be saved
:return: None
"""
Expand Down Expand Up @@ -155,11 +155,11 @@ def display_heatmaps(activations, image, save=False):
img = acts[0, :, :, i]
# scale the activations (which will form our heat map) to be in range 0-1
img = scaler.transform(img)
# resize heatmap to be same dimensions of image
# resize heatmap to be same dimensions of input_image
img = Image.fromarray(img)
img = img.resize((image.shape[0], image.shape[1]), Image.BILINEAR)
img = img.resize((input_image.shape[0], input_image.shape[1]), Image.BILINEAR)
img = np.array(img)
axes.flat[i].imshow(img / 255.0)
axes.flat[i].imshow(input_image / 255.0)
# overlay a 70% transparent heat map onto the image
# Lowest activations are dark, highest are dark red, mid are yellow
axes.flat[i].imshow(img, alpha=0.3, cmap='jet', interpolation='bilinear')
Expand Down

0 comments on commit 3b3bf0e

Please sign in to comment.