Skip to content

Commit

Permalink
adding torch plot sample
Browse files Browse the repository at this point in the history
  • Loading branch information
aamini committed Jan 5, 2025
1 parent b5b6895 commit 0afb904
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
39 changes: 32 additions & 7 deletions mitdeeplearning/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,46 @@ def display_model(model):
return ipythondisplay.Image("tmp.png")


def plot_sample(x, y, vae):
def plot_sample(x, y, vae, backend='tf'):
"""Plot original and reconstructed images side by side.
Args:
x: Input images array of shape [B, H, W, C] (TF) or [B, C, H, W] (PT)
y: Labels array of shape [B] where 1 indicates a face
vae: VAE model (TensorFlow or PyTorch)
framework: 'tf' or 'pt' indicating which framework to use
"""
plt.figure(figsize=(2, 1))
plt.subplot(1, 2, 1)

idx = np.where(y == 1)[0][0]
if backend == 'tf':
idx = np.where(y == 1)[0][0]
_, _, _, recon = vae(x)
recon = np.clip(recon, 0, 1)

elif backend == 'pt':
y = y.detach().cpu().numpy()
face_indices = np.where(y == 1)[0]
idx = face_indices[0] if len(face_indices) > 0 else 0

with torch.inference_mode():
_, _, _, recon = vae(x)
recon = torch.clamp(recon, 0, 1)
recon = recon.permute(0, 2, 3, 1).detach().cpu().numpy()
x = x.permute(0, 2, 3, 1).detach().cpu().numpy()

else:
raise ValueError("framework must be 'tf' or 'pt'")

plt.subplot(1, 2, 1)
plt.imshow(x[idx])
plt.grid(False)

plt.subplot(1, 2, 2)
_, _, _, recon = vae(x)
recon = np.clip(recon, 0, 1)
plt.subplot(1, 2, 2)
plt.imshow(recon[idx])
plt.grid(False)

# plt.show()
if backend == 'pt':
plt.show()


class LossHistory:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ def get_dist(pkgname):
setup(
name = 'mitdeeplearning', # How you named your package folder (MyLib)
packages = ['mitdeeplearning'], # Chose the same as "name"
version = '0.6.1', # Start with a small number and increase it with every change you make
version = '0.7.2', # Start with a small number and increase it with every change you make
license='MIT', # Chose a license from here: https://help.github.com/articles/licensing-a-repository
description = 'Official software labs for MIT Introduction to Deep Learning (http://introtodeeplearning.com)', # Give a short description about your library
author = 'Alexander Amini', # Type in your name
author_email = '[email protected]', # Type in your E-Mail
url = 'http://introtodeeplearning.com', # Provide either the link to your github or to your website
download_url = 'https://github.com/aamini/introtodeeplearning/archive/v0.6.1.tar.gz', # I explain this later on
download_url = 'https://github.com/aamini/introtodeeplearning/archive/v0.7.2.tar.gz', # I explain this later on
keywords = ['deep learning', 'neural networks', 'tensorflow', 'introduction'], # Keywords that define your package best
install_requires=install_deps,
classifiers=[
Expand Down

0 comments on commit 0afb904

Please sign in to comment.