Skip to content

Commit

Permalink
change callback_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
Dazhi Zhong committed Jan 6, 2022
1 parent 40f4bd2 commit 0a3119f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 32 deletions.
20 changes: 6 additions & 14 deletions cfg_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def resize_and_center_crop(image, size):
image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
return TF.center_crop(image, size[::-1])

def callback_fn(info):
if info['i'] % 50==0:
out = info['pred'].add(1).div(2)
save_image(out, f"interm_output_{info['i']:05d}.png")
if IS_NOTEBOOK:
display.display(display.Image(f"interm_output_{info['i']:05d}.png",height=300))

def main():
p = argparse.ArgumentParser(description=__doc__,
Expand Down Expand Up @@ -132,13 +138,6 @@ def main():

torch.manual_seed(args.seed)

def callback_fn(pred, i):
if i % 50==0 or i==args.steps:
out = pred.add(1).div(2)
save_image(out, f"interm_output_{i:05d}.png")
if IS_NOTEBOOK:
display.display(display.Image(f"interm_output_{i:05d}.png",height=300))

def cfg_model_fn(x, t):
n = x.shape[0]
n_conds = len(target_embeds)
Expand Down Expand Up @@ -235,13 +234,6 @@ def run_diffusion_cfg(prompts,images=None,steps=1000,init=None,model="cc12m_1_cf

torch.manual_seed(args.seed)

def callback_fn(pred, i):
if i % display_freq==0 or i==args.steps:
out = pred.add(1).div(2)
save_image(out, f"interm_output_{i:05d}.png")
if IS_NOTEBOOK:
display.display(display.Image(f"interm_output_{i:05d}.png",height=300))

def cfg_model_fn(x, t):
n = x.shape[0]
n_conds = len(target_embeds)
Expand Down
20 changes: 6 additions & 14 deletions clip_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def resize_and_center_crop(image, size):
image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
return TF.center_crop(image, size[::-1])

def callback_fn(info):
if info['i'] % 50==0:
out = info['pred'].add(1).div(2)
save_image(out, f"interm_output_{info['i']:05d}.png")
if IS_NOTEBOOK:
display.display(display.Image(f"interm_output_{info['i']:05d}.png",height=300))

def main():
p = argparse.ArgumentParser(description=__doc__,
Expand Down Expand Up @@ -176,13 +182,6 @@ def main():

torch.manual_seed(args.seed)

def callback_fn(pred, i):
if i % 50==0 or i==args.steps:
out = pred.add(1).div(2)
save_image(out, f"interm_output_{i:05d}.png")
if IS_NOTEBOOK:
display.display(display.Image(f"interm_output_{i:05d}.png",height=300))

def cond_fn(x, t, pred, clip_embed):
clip_in = normalize(make_cutouts((pred + 1) / 2))
image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
Expand Down Expand Up @@ -295,13 +294,6 @@ def run_diffusion(prompts,images=None,steps=1000,init=None,model="yfcc_2",size=[

torch.manual_seed(args.seed)

def callback_fn(pred, i):
if i % display_freq==0 or i==args.steps:
out = pred.add(1).div(2)
save_image(out, f"interm_output_{i:05d}.png")
if IS_NOTEBOOK:
display.display(display.Image(f"interm_output_{i:05d}.png",height=300))

def cond_fn(x, t, pred, clip_embed):
clip_in = normalize(make_cutouts((pred + 1) / 2))
image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
Expand Down
4 changes: 0 additions & 4 deletions diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def sample(model, x, steps, eta, extra_args, callback=None):
if eta:
x += torch.randn_like(x) * ddim_sigma

if callback_fn:
callback_fn(pred,i)

# If we are on the last timestep, output the denoised image
return pred
Expand Down Expand Up @@ -101,8 +99,6 @@ def cond_sample(model, x, steps, eta, extra_args, cond_fn, callback=None):
if eta:
x += torch.randn_like(x) * ddim_sigma

if callback_fn:
callback_fn(pred,i)

# If we are on the last timestep, output the denoised image
return pred

0 comments on commit 0a3119f

Please sign in to comment.