How do we train LoRA for Flux-Fill-Dev? #1180
-
The newly released Flux-Fill-Dev (Inpainting/Outpainting) is incredible. However, the LoRA I previously trained on Flux-Dev doesn't work directly with Flux-Fill-Dev, so I'm planning to retrain a LoRA specifically for Flux-Fill-Dev. From my quick analysis, Flux-Fill-Dev takes as input a 16-channel image latent, a 16-channel masked image latent, and a 64-channel mask. My initial idea for input preparation is to use the 16-channel image latent repeated twice, along with a 64-channel mask filled entirely with zeros (black mask). Any thoughts or suggestions?
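For concreteness, here is a minimal sketch of that proposed input layout in PyTorch (the shapes and names are my own illustration under the 16+16+64 assumption above, not from the Flux codebase):

```python
import torch

# Hypothetical 16-channel VAE latents at latent resolution (H, W)
image_latents = torch.randn(1, 16, 64, 64)

# Reuse the image latents as the "masked image" latents (nothing hidden)
masked_image_latents = image_latents.clone()

# All-zero ("black") mask, expanded to the 64 channels the model expects
mask = torch.zeros(1, 64, 64, 64)

# 16 + 16 + 64 = 96 channels before Flux's 2x2 latent packing
conditioning = torch.cat([image_latents, masked_image_latents, mask], dim=1)
print(conditioning.shape)  # torch.Size([1, 96, 64, 64])
```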
-
Any new updates on your training? I am also trying to train a LoRA for flux-fill-dev.
-
Hello, I have also recently been training with flux-fill-dev as the base model. My data input channels are likewise 16+16+64, but the training results are very poor. How did your training go?
-
Here's my training script. I used just one instance prompt describing the inpainting task instead of a caption for each sample, and the LoRA adapted to my concept within just a couple of hundred steps. There are some known issues: the validation is broken and the masking is hardcoded, but I hope it helps someone adapt it to their use case :)
What I did is really simple and raw, as I discussed:
```python
def pack_fill_latents(latents, batch_size, num_channels_latents, height, width):
    # Step 1: Pack 2x2 latent patches into the channel dimension,
    # (B, C, H, W) -> (B, (H//2)*(W//2), C*4), matching Flux's packed layout
    latents = latents.view(
        batch_size, num_channels_latents, height // 2, 2, width // 2, 2
    )
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(
        batch_size, (height // 2) * (width // 2), num_channels_latents * 4
    )
    # Step 2: Repeat the packed latents along the channel dimension,
    # reusing the image latents as the masked-image conditioning
    repeated_latents = latents.repeat(1, 1, 2)
    return repeated_latents
```
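As a usage sketch (my own assumption about how this output plugs into the model input, not part of the script above): concatenating the repeated packed latents with a 2x2-packed all-zero mask gives 384 channels per token, which matches Flux-Fill-Dev's in_channels.

```python
import torch

batch_size, num_channels_latents, height, width = 1, 16, 64, 64
latents = torch.randn(batch_size, num_channels_latents, height, width)

# 16 channels packed 2x2 -> 64 per token, then repeated -> 128 per token
packed = pack_fill_latents(latents, batch_size, num_channels_latents, height, width)

# Hypothetical all-zero mask: 64 channels at latent resolution, packed 2x2
# the same way -> 64 * 4 = 256 channels per token
packed_mask = torch.zeros(batch_size, (height // 2) * (width // 2), 64 * 4)

# 128 + 256 = 384 channels per token, matching Flux-Fill-Dev's in_channels
model_input = torch.cat([packed, packed_mask], dim=-1)
print(model_input.shape)  # torch.Size([1, 1024, 384])
```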