diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 6b422b57955f..a1ea73f0d2c2 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -15,6 +15,8 @@
 """Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
 
 import argparse
+import itertools
+import json
 import logging
 import math
 import os
@@ -195,6 +197,44 @@ def parse_args():
         action="store_true",
         help="whether to randomly flip images horizontally",
     )
+    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to also train the text encoder")
+
+    # lora args
+    parser.add_argument("--use_peft", action="store_true", help="Whether to use the peft library to add the LoRA layers")
+    parser.add_argument("--lora_r", type=int, default=4, help="LoRA rank, only used if `--use_peft` is set")
+    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha, only used if `--use_peft` is set")
+    parser.add_argument("--lora_dropout", type=float, default=0.0, help="LoRA dropout, only used if `--use_peft` is set")
+    parser.add_argument(
+        "--lora_bias",
+        type=str,
+        default="none",
+        help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only', only used if `--use_peft` is set",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_r",
+        type=int,
+        default=4,
+        help="LoRA rank for the text encoder, only used if `--use_peft` and `--train_text_encoder` are set",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_alpha",
+        type=int,
+        default=32,
+        help="LoRA alpha for the text encoder, only used if `--use_peft` and `--train_text_encoder` are set",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_dropout",
+        type=float,
+        default=0.0,
+        help="LoRA dropout for the text encoder, only used if `--use_peft` and `--train_text_encoder` are set",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_bias",
+        type=str,
+        default="none",
+        help="Bias type for the text encoder LoRA. Can be 'none', 'all' or 'lora_only', only used if `--use_peft` and `--train_text_encoder` are set",
+    )
+
     parser.add_argument(
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
@@ -429,11 +469,6 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
-    # freeze parameters of models to save more memory
-    unet.requires_grad_(False)
-    vae.requires_grad_(False)
-
-    text_encoder.requires_grad_(False)
 
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
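As a quick sanity check of the flags added above, here is a small hypothetical snippet (not part of the patch) that exercises the updated `parse_args()`. It assumes it runs next to the patched `train_text_to_image_lora.py`; the model and dataset identifiers are illustrative placeholders only.

# Hypothetical smoke test for the new PEFT/LoRA flags (not part of the patch).
import sys

from train_text_to_image_lora import parse_args  # assumes the patched script is importable

sys.argv = [
    "train_text_to_image_lora.py",
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",   # example model id
    "--dataset_name", "lambdalabs/pokemon-blip-captions",                  # example dataset
    "--use_peft",                  # switch to the peft-backed LoRA path
    "--train_text_encoder",        # also attach LoRA to the text encoder
    "--lora_r", "8",
    "--lora_alpha", "32",
    "--lora_text_encoder_r", "8",
    "--lora_text_encoder_alpha", "32",
]

args = parse_args()
assert args.use_peft and args.train_text_encoder
print(args.lora_r, args.lora_alpha, args.lora_text_encoder_r)  # 8 32 8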
@@ -443,43 +478,79 @@ def main():
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    if args.use_peft:
+
+        from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict
+
+        UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+        TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+        config = LoraConfig(
+            r=args.lora_r,
+            lora_alpha=args.lora_alpha,
+            target_modules=UNET_TARGET_MODULES,
+            lora_dropout=args.lora_dropout,
+            bias=args.lora_bias,
+        )
+        unet = LoraModel(config, unet)
+
+        vae.requires_grad_(False)
+        if args.train_text_encoder:
+
+            config = LoraConfig(
+                r=args.lora_text_encoder_r,
+                lora_alpha=args.lora_text_encoder_alpha,
+                target_modules=TEXT_ENCODER_TARGET_MODULES,
+                lora_dropout=args.lora_text_encoder_dropout,
+                bias=args.lora_text_encoder_bias,
+            )
+            text_encoder = LoraModel(config, text_encoder)
+    else:
+
+        # freeze parameters of models to save more memory
+        unet.requires_grad_(False)
+        vae.requires_grad_(False)
+
+        text_encoder.requires_grad_(False)
+
+        # now we will add new LoRA weights to the attention layers
+        # It's important to realize here how many attention weights will be added and of which sizes
+        # The sizes of the attention layers consist only of two different variables:
+        # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
+        # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+        # Let's first see how many attention processors we will have to set.
+        # For Stable Diffusion, it should be equal to:
+        # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+        # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+        # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+        # => 32 layers
+
+        # Set correct lora layers
+        lora_attn_procs = {}
+        for name in unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = unet.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+
+        unet.set_attn_processor(lora_attn_procs)
+        lora_layers = AttnProcsLayers(unet.attn_processors)
+
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
-    # now we will add new LoRA weights to the attention layers
-    # It's important to realize here how many attention weights will be added and of which sizes
-    # The sizes of the attention layers consist only of two different variables:
-    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
-    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
-
-    # Let's first see how many attention processors we will have to set.
-    # For Stable Diffusion, it should be equal to:
-    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
-    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
-    # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
-    # => 32 layers
-
-    # Set correct lora layers
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
-
-    unet.set_attn_processor(lora_attn_procs)
-
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -493,8 +564,6 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
-
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -518,13 +587,28 @@ def main():
     else:
         optimizer_cls = torch.optim.AdamW
 
-    optimizer = optimizer_cls(
-        lora_layers.parameters(),
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-    )
+    if args.use_peft:
+        # Optimizer creation
+        params_to_optimize = (
+            itertools.chain(unet.parameters(), text_encoder.parameters())
+            if args.train_text_encoder
+            else unet.parameters()
+        )
+        optimizer = optimizer_cls(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+    else:
+        optimizer = optimizer_cls(
+            lora_layers.parameters(),
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
 
     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -645,9 +729,19 @@ def collate_fn(examples):
     )
 
     # Prepare everything with our `accelerator`.
-    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.use_peft:
+        if args.train_text_encoder:
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+            )
+        else:
+            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                unet, optimizer, train_dataloader, lr_scheduler
+            )
+    else:
+        lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            lora_layers, optimizer, train_dataloader, lr_scheduler
+        )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -705,6 +799,8 @@ def collate_fn(examples):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
             # Skip steps until we reach the resumed step
@@ -751,7 +847,14 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = lora_layers.parameters()
+                    if args.use_peft:
+                        params_to_clip = (
+                            itertools.chain(unet.parameters(), text_encoder.parameters())
+                            if args.train_text_encoder
+                            else unet.parameters()
+                        )
+                    else:
+                        params_to_clip = lora_layers.parameters()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
@@ -786,6 +889,7 @@ def collate_fn(examples):
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     unet=accelerator.unwrap_model(unet),
+                    text_encoder=accelerator.unwrap_model(text_encoder),
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
@@ -821,8 +925,24 @@ def collate_fn(examples):
     # Save the lora layers
    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = unet.to(torch.float32)
-        unet.save_attn_procs(args.output_dir)
+        if args.use_peft:
+            lora_config = {}
+            state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
+            lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+            if args.train_text_encoder:
+                text_encoder_state_dict = get_peft_model_state_dict(
+                    text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+                )
+                text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+                state_dict.update(text_encoder_state_dict)
+                lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+
+            accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
+            with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f:
+                json.dump(lora_config, f)
+        else:
+            unet = unet.to(torch.float32)
+            unet.save_attn_procs(args.output_dir)
 
     if args.push_to_hub:
         save_model_card(
@@ -839,10 +959,45 @@ def collate_fn(examples):
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
         )
-        pipeline = pipeline.to(accelerator.device)
-
-        # load attention processors
-        pipeline.unet.load_attn_procs(args.output_dir)
+        if args.use_peft:
+
+            def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
+                with open(os.path.join(ckpt_dir, f"{instance_prompt}_lora_config.json"), "r") as f:
+                    lora_config = json.load(f)
+                print(lora_config)
+
+                checkpoint = os.path.join(ckpt_dir, f"{instance_prompt}_lora.pt")
+                lora_checkpoint_sd = torch.load(checkpoint)
+                unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+                text_encoder_lora_ds = {
+                    k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+                }
+
+                unet_config = LoraConfig(**lora_config["peft_config"])
+                pipe.unet = LoraModel(unet_config, pipe.unet)
+                set_peft_model_state_dict(pipe.unet, unet_lora_ds)
+
+                if "text_encoder_peft_config" in lora_config:
+                    text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
+                    pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
+                    set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)
+
+                if dtype in (torch.float16, torch.bfloat16):
+                    pipe.unet.half()
+                    pipe.text_encoder.half()
+
+                pipe.to(device)
+                return pipe
+
+            pipeline = load_and_set_lora_ckpt(
+                pipeline, args.output_dir, args.instance_prompt, accelerator.device, weight_dtype
+            )
+
+        else:
+            pipeline = pipeline.to(accelerator.device)
+
+            # load attention processors
+            pipeline.unet.load_attn_procs(args.output_dir)
 
         # run inference
         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
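For reference, a minimal standalone sketch (not part of the patch) of how the checkpoint written by the `use_peft` branch could be reloaded for inference outside the training script. The output directory, file prefix, base model id, and prompt are illustrative assumptions; the peft calls mirror the `LoraConfig`/`LoraModel`/`set_peft_model_state_dict` usage already present in the diff.

# Standalone inference sketch: reload the PEFT LoRA weights saved above (assumptions noted inline).
import json
import os

import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, LoraModel, set_peft_model_state_dict

output_dir = "sd-pokemon-lora"  # the --output_dir used at training time (example value)
prefix = "pokemon"              # whatever prefix the *_lora.pt / *_lora_config.json files were saved under (example)
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example base model

# Rebuild the LoRA wrappers from the saved config, then load the trained weights.
with open(os.path.join(output_dir, f"{prefix}_lora_config.json")) as f:
    lora_config = json.load(f)
state_dict = torch.load(os.path.join(output_dir, f"{prefix}_lora.pt"), map_location="cpu")

unet_sd = {k: v for k, v in state_dict.items() if "text_encoder_" not in k}
text_encoder_sd = {k.replace("text_encoder_", ""): v for k, v in state_dict.items() if "text_encoder_" in k}

pipe.unet = LoraModel(LoraConfig(**lora_config["peft_config"]), pipe.unet)
set_peft_model_state_dict(pipe.unet, unet_sd)
if "text_encoder_peft_config" in lora_config:
    pipe.text_encoder = LoraModel(LoraConfig(**lora_config["text_encoder_peft_config"]), pipe.text_encoder)
    set_peft_model_state_dict(pipe.text_encoder, text_encoder_sd)

pipe = pipe.to(device)
image = pipe("a cute green pokemon with big eyes", num_inference_steps=30).images[0]
image.save("lora_sample.png")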