
The resulting images are of poor quality #316

Open
xiaomei2002 opened this issue Apr 15, 2024 · 4 comments


xiaomei2002 commented Apr 15, 2024

First, I tried to train on the CUB200 dataset, but the generated pictures are blurry and there is no bird shape.
So I then tried to train on only one sample, and the result was not good either.

xiaomei2002 (Author) commented:

I used the following code to train on this sunflower picture with its label "flower". The final prompt I entered was also "flower", and I trained for 1000 epochs, but the result was not satisfactory. Could someone please help me look at the problem?

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, \
    OpenAIClipAdapter, DecoderTrainer
from typing import List, Dict
from torchvision.utils import make_grid
import torchvision.transforms as T
from torchvision.utils import save_image
from PIL import Image
from datetime import datetime
import os
import torch.utils.data as data
import json
# openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter()

# # mock data
#
# text = torch.randint(0, 49408, (4, 256)).cuda()
# images = torch.randn(4, 3, 256, 256).cuda()

def read_metadata(text: str) -> List[Dict]:
    data = []
    for line in text.split('\n'):
        if not line:
            continue
        line_json = json.loads(line)
        data.append(line_json)
    return data

class ImgTextDataset(data.Dataset):
    def __init__(self, fp: str):
        self.fp = fp
        with open(fp, 'r') as file:
            metadata = read_metadata(file.read())

        self.img_paths = []
        self.captions = []

        for line in metadata:
            self.img_paths.append(line['url'])
            self.captions.append(line['caption'])

        # Make sure that each image is captioned
        assert len(self.img_paths) == len(self.captions)
        # Apply required image transforms. For my model I need RGB images with 256 x 256 dimensions.
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((256, 256)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = os.path.join(self.fp, self.img_paths[idx])
        caption = self.captions[idx]

        image = Image.open(image_path)

        image_pt = self.image_transform(image).cuda()
        return image_pt, caption

image_size = 256  # Image dimension
batch_size = 1  # Batch size for training, adjust based on GPU memory
learning_rate = 3e-4  # Learning rate for the optimizer
num_epochs = 1000 # Number of epochs for training
log_image_interval = 100 # Interval for logging images
save_dir = "./log_images"  # Directory to save log images
os.makedirs(save_dir, exist_ok=True)  # Create save directory if it doesn't exist
# prior networks (with transformer)

device = torch.device("cuda")  # Not recommended to train on cpu

# Define your image-text dataset
dataset = ImgTextDataset('/root/autodl-tmp/dalle2-train/data/output2.jsonl')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

t = SimpleTokenizer()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

diffusion_prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)
for epoch in range(num_epochs):
    for prior_batch_idx, (images, texts) in enumerate(dataloader):
        loss = diffusion_prior_trainer(
            image=images.cuda(),
            text=t.tokenize(texts).cuda(),
            max_batch_size=4
        )
        diffusion_prior_trainer.update()
        if prior_batch_idx % 500 == 0:
            print(f"prior epoch {epoch}, step {prior_batch_idx}, loss {loss}")
#torch.save(diffusion_prior_trainer.state_dict(), f'model_{epoch}.pth')

# do above for many steps ...

# decoder (with unet)
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    text_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings=True,
).cuda()

decoder = Decoder(
    unet=unet1,
    image_size=image_size,
    clip=clip,
    timesteps=1000
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr=5e-4,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
).cuda()



# Training loop.
# Iterate over the dataloader and pass image tensors and tokenized text to the training wrapper.
# Repeat process N times.

for epoch in range(num_epochs):
    for batch_idx, (images, texts) in enumerate(dataloader):
        loss = decoder_trainer(
            images.cuda(),
            text=t.tokenize(texts).cuda(),
            unet_number=1,
            max_batch_size=4
        )
        decoder_trainer.update(1)
        if batch_idx % 500 == 0:
            print(f"decoder epoch {epoch}, step {batch_idx}, loss {loss}")
    # if (epoch+1) % log_image_interval == 0 :
    #     image_embed = clip.embed_image(images.cuda())
    #     sample = decoder_trainer.sample(image_embed=image_embed[0], text=t.tokenize(texts).cuda())
    #     save_image(sample, f'./log_images/{epoch}_{batch_idx}.png')
    # Periodically save the model.
#torch.save(decoder_trainer.state_dict(), f'model_{epoch}.pth')



dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
).cuda()

gen_images = dalle2(
    ['flower'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
).cuda()

print(gen_images.shape)

from torchvision.utils import save_image

save_image(gen_images, 'image.png')

for img in gen_images:
    img = ToPILImage()(img)
    img.show()

The data for training: [attached sunflower image, captioned "flower"]
The generated image: [attached output image]
Could it be my hyperparameter settings?
Hoping to receive a reply! Thanks! ^ ^

xiaomei2002 changed the title from "Could someone provide a written training code?" to "The resulting images are of poor quality" on Apr 16, 2024
Siddharth-Latthe-07 commented:

Try using a Stable Diffusion model; it is more powerful than DALL-E 2. Thanks to its noising and denoising process, the clarity of the images is good, and you can also use an NVIDIA GPU.
You can refer to this:
https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb
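
For reference, a minimal sketch of text-to-image generation with the Hugging Face diffusers StableDiffusionPipeline, along the lines of the linked notebook. The model id "runwayml/stable-diffusion-v1-5", the prompt, and the sampling settings are illustrative assumptions, and a CUDA GPU is assumed to be available:

import torch
from diffusers import StableDiffusionPipeline

# Load a pretrained Stable Diffusion checkpoint (the model id is an assumption;
# any compatible checkpoint from the Hugging Face Hub should work).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Generate one image from a text prompt and save it to disk.
image = pipe(
    "a sunflower",            # prompt (illustrative)
    guidance_scale=7.5,       # classifier-free guidance strength
    num_inference_steps=50,   # number of denoising steps
).images[0]
image.save("flower_sd.png")

This only samples from a pretrained model; fine-tuning on a custom dataset such as CUB200 would need a separate training script.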


chenrxi commented Aug 11, 2024

What was your final training loss? I think the reason is related to training from scratch, because the scale of the CUB200 dataset is limited (maybe 6,000 images in total). @xiaomei2002

pankil25 commented:

@xiaomei2002 were you able to solve it?
