StableMTL: Repurposing Latent Diffusion Models for Multi-Task Learning from Partially Annotated Synthetic Datasets
Anh-Quan Cao · Ivan Lopes · Raoul de Charette
If you find this work or code useful, please cite our paper and give this repo a star ⭐:
@InProceedings{stablemtl,
title = {StableMTL: Repurposing Latent Diffusion Models for Multi-Task Learning from Partially Annotated Synthetic Datasets},
author = {Anh-Quan Cao and Ivan Lopes and Raoul de Charette},
year = {2025},
booktitle = {arXiv}
}
- 1. 📦 Installation
- 2. ⚙️ Environment Variables Setup
- 3. 🗄️ Datasets
- 5. 📊 Evaluation
- 6. 🔬 Training
- 7. 🙏 Acknowledgement
- Clone the repository
git clone https://github.com/astra-vision/StableMTL.git
cd StableMTL
- Install PyTorch 2.3.1
# Install the appropriate version for your CUDA setup
pip install torch==2.3.1 torchvision==0.18.1
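You can optionally verify the install with a quick sanity check (not part of the original instructions):
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"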
- Install dependencies
pip install --no-cache-dir -r requirements.txt
- Configure environment variables
source env.sh

Before proceeding with dataset downloads or model training/evaluation, you need to configure your environment.
- Define paths in env.sh: Open the env.sh file in the root of this repository. Modify or add the following environment variables to point to your desired directories (see the sketch after this list):
  - CODE_DIR: Path to the StableMTL source code (e.g., $(pwd) if running from the repo root).
  - RAW_DATA_DIR: Path where raw datasets will be downloaded.
  - PREPROCESSED_DIR: Path where preprocessed datasets will be stored.
  - BASE_CKPT_DIR: Path where base model checkpoints (like Stable Diffusion) will be stored.
  - OUTPUT_DIR: Path where training outputs and resulting checkpoints will be saved.
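For reference, a minimal env.sh might look like the following; the paths are placeholders and should be adapted to your setup:
export CODE_DIR=$(pwd)
export RAW_DATA_DIR=/path/to/raw_data
export PREPROCESSED_DIR=/path/to/preprocessed_data
export BASE_CKPT_DIR=/path/to/base_checkpoints
export OUTPUT_DIR=/path/to/outputs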
- Source env.sh: After saving your changes to env.sh, source it in your terminal session:
source env.sh
You only need to do this once per terminal session, or whenever you open a new terminal or modify env.sh. All subsequent commands in this README assume these variables are correctly set.
All dataset download and preprocessing scripts expect the environment variables (CODE_DIR, RAW_DATA_DIR, PREPROCESSED_DIR) to be set as described above. Make sure you have sourced env.sh before running any scripts.
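A quick way to confirm the variables are set (an optional check, not part of the original scripts):
echo "$CODE_DIR" "$RAW_DATA_DIR" "$PREPROCESSED_DIR"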
- Download Hypersim dataset
  - Download the Hypersim dataset using the official Hypersim download script.
  - Download the scene split file (metadata_images_split_scene_v1.csv) from the Hypersim repository.
  - The Hypersim dataset should be placed in $RAW_DATA_DIR/Hypersim/evermotion_dataset/scenes
  - The scene split file should be placed in $CODE_DIR/data_split/hypersim/metadata_images_split_scene_v1.csv
- Preprocess Hypersim dataset
python $CODE_DIR/dataset_preprocess/hypersim/preprocess_hypersim.py \
    --output_dir $PREPROCESSED_DIR/hypersim \
    --split_csv $CODE_DIR/data_split/hypersim/metadata_images_split_scene_v1.csv \
    --dataset_dir $RAW_DATA_DIR/Hypersim/evermotion_dataset/scenes
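To spot-check the preprocessing output (an optional, hypothetical check; the train/ subfolder matches the dataset structure shown further below):
ls $PREPROCESSED_DIR/hypersim/train | head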
- Download vkitti to $RAW_DATA_DIR/VirtualKitti2
wget http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_rgb.tar
wget http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_depth.tar
wget http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_forwardFlow.tar
wget http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_classSegmentation.tar
wget http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_forwardSceneFlow.tar
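The downloaded archives then need to be extracted under $RAW_DATA_DIR/VirtualKitti2. A minimal sketch, assuming the tar files were downloaded into the current working directory:
# Extract every vkitti archive into the raw data directory (assumption: archives are in the current directory)
mkdir -p $RAW_DATA_DIR/VirtualKitti2
for f in vkitti_2.0.3_*.tar; do
    tar -xf "$f" -C $RAW_DATA_DIR/VirtualKitti2
done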
- Create a symlink from the original dataset to the preprocessed directory
mkdir -p $PREPROCESSED_DIR/vkitti_v2   # create the target directory if it does not exist yet
ln -s $RAW_DATA_DIR/VirtualKitti2 $PREPROCESSED_DIR/vkitti_v2/original
- Generate normal maps
cd $CODE_DIR/depth-to-normal-translator/python
python process_vkitti2.py --output_root $PREPROCESSED_DIR/vkitti_v2
- Create directory and download dataset
source env.sh
mkdir $RAW_DATA_DIR/FlyingThings3D
cd $RAW_DATA_DIR/FlyingThings3D
bash $CODE_DIR/dataset_preprocess/flying_things_3D/download_flying3d.sh
- Download dataset components
source env.sh
cd $RAW_DATA_DIR/FlyingThings3D
wget --no-check-certificate https://lmb.informatik.uni-freiburg.de/data/FlyingThings3D_subset/FlyingThings3D_subset_image_clean.tar.bz2
wget --no-check-certificate https://lmb.informatik.uni-freiburg.de/data/FlyingThings3D_subset/FlyingThings3D_subset_flow.tar.bz2
wget --no-check-certificate https://lmb.informatik.uni-freiburg.de/data/FlyingThings3D_subset/FlyingThings3D_subset_disparity.tar.bz2
wget --no-check-certificate https://lmb.informatik.uni-freiburg.de/data/FlyingThings3D_subset/FlyingThings3D_subset_disparity_change.tar.bz2
wget --no-check-certificate https://lmb.informatik.uni-freiburg.de/data/FlyingThings3D_subset/FlyingThings3D_subset_flow_occlusions.tar.bz2
- Extract the dataset
source env.sh
tar -xvf FlyingThings3D_subset_image_clean.tar.bz2 -C $RAW_DATA_DIR
tar -xvf FlyingThings3D_subset_flow.tar.bz2 -C $RAW_DATA_DIR
tar -xvf FlyingThings3D_subset_disparity.tar.bz2 -C $RAW_DATA_DIR
tar -xvf FlyingThings3D_subset_disparity_change.tar.bz2 -C $RAW_DATA_DIR
tar -xvf FlyingThings3D_subset_flow_occlusions.tar.bz2 -C $RAW_DATA_DIR
- Preprocess FlyingThings3D dataset
source env.sh
python $CODE_DIR/dataset_preprocess/flying_things_3D/preprocess.py \
    --input_dir $RAW_DATA_DIR/FlyingThings3D_subset \
    --output_dir $PREPROCESSED_DIR/FlyingThings3D_preprocessed
- Download Cityscapes dataset
Remember to replace myusername and mypassword with your Cityscapes account credentials.
wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=myusername&password=mypassword&submit=Login' https://www.cityscapes-dataset.com/login/
wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1  # gtFine_trainvaltest.zip
wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3  # leftImg8bit_trainvaltest.zip
- Extract the dataset
source env.sh
mkdir -p $PREPROCESSED_DIR/cityscapes
unzip gtFine_trainvaltest.zip -d $PREPROCESSED_DIR/cityscapes
unzip leftImg8bit_trainvaltest.zip -d $PREPROCESSED_DIR/cityscapes
- Download KITTI flow 2015 dataset
source env.sh
mkdir -p $PREPROCESSED_DIR/kitti/flow_2015
cd $PREPROCESSED_DIR/kitti/flow_2015
wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow.zip
wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow_calib.zip
- Extract the dataset
source env.sh
unzip data_scene_flow.zip -d $PREPROCESSED_DIR/kitti/flow_2015
unzip data_scene_flow_calib.zip -d $PREPROCESSED_DIR/kitti/flow_2015
- Download KITTI Eigen test split
source env.sh
wget https://share.phys.ethz.ch/~pf/bingkedata/marigold/evaluation_dataset/kitti/kitti_eigen_split_test.tar
- Extract the dataset
source env.sh
mkdir -p $PREPROCESSED_DIR/kitti/kitti_eigen_split_test
tar -xvf kitti_eigen_split_test.tar -C $PREPROCESSED_DIR/kitti/kitti_eigen_split_test
- Download DIODE validation depth and normal dataset
wget http://diode-dataset.s3.amazonaws.com/val.tar.gz
wget http://diode-dataset.s3.amazonaws.com/val_normals.tar.gz
- Extract and organize the dataset
source env.sh
mkdir -p $PREPROCESSED_DIR/diode
tar -xvf val.tar.gz -C $PREPROCESSED_DIR/diode
tar -xvf val_normals.tar.gz -C $PREPROCESSED_DIR/diode
mv $PREPROCESSED_DIR/diode/val $PREPROCESSED_DIR/diode/diode_val
- Download MID Intrinsic dataset and albedo ground truth
source env.sh
wget https://data.csail.mit.edu/multilum/multi_illumination_test_mip2_exr.zip
wget https://huggingface.co/anhquancao/StableMTL/resolve/main/midi_test_albedo.zip
- Extract the dataset
source env.sh
mkdir -p $RAW_DATA_DIR/mid_intrinsics/test
unzip multi_illumination_test_mip2_exr.zip -d $RAW_DATA_DIR/mid_intrinsics/test
unzip midi_test_albedo.zip -d $RAW_DATA_DIR/mid_intrinsics/test
- Preprocess the dataset
source env.sh
python $CODE_DIR/dataset_preprocess/mid_intrinsics/preprocess.py \
    --input_dir $RAW_DATA_DIR/mid_intrinsics \
    --output_dir $PREPROCESSED_DIR/mid_intrinsics
After downloading and preprocessing, your dataset structure should be organized as follows:
$PREPROCESSED_DIR/
├── cityscapes/ # Cityscapes dataset
├── diode/ # DIODE depth and normal dataset
│ └── diode_val/
├── FlyingThings3D_preprocessed/ # FlyingThings3D dataset
├── hypersim/ # Hypersim dataset
│ └── train/
├── kitti/ # KITTI datasets
│ ├── flow_2015/ # KITTI flow dataset
│ └── kitti_eigen_split_test/ # KITTI depth dataset
├── mid_intrinsics/ # MID Intrinsic dataset
└── vkitti_v2/ # Virtual KITTI dataset
└── original/ # Symlink to original vkitti data
Note: Make sure all datasets are properly downloaded and preprocessed before proceeding to the evaluation or training steps.
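As an optional sanity check (a hypothetical snippet, not part of the original scripts), you can verify that the expected top-level directories exist:
source env.sh
for d in cityscapes diode FlyingThings3D_preprocessed hypersim kitti mid_intrinsics vkitti_v2; do
    [ -d "$PREPROCESSED_DIR/$d" ] && echo "OK      $d" || echo "MISSING $d"
done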
To evaluate the pre-trained single-stream model:
- Download and extract checkpoint
source env.sh
wget https://huggingface.co/anhquancao/StableMTL/resolve/main/single_stream.tar.gz -O $OUTPUT_DIR/single_stream.tar.gz
tar -xvf $OUTPUT_DIR/single_stream.tar.gz -C $OUTPUT_DIR
- Run evaluation
source env.sh
python eval_mtl.py --config config/dataset/dataset_test.yaml \
    --resume_run=$OUTPUT_DIR/single_stream/checkpoint/latest \
    --base_ckpt_dir=$BASE_CKPT_DIR --base_data_dir=$PREPROCESSED_DIR
This will evaluate the single-stream model on all test datasets configured in the dataset_test.yaml file.
To evaluate the pre-trained multi-stream model:
- Download and extract checkpoint
source env.sh
wget https://huggingface.co/anhquancao/StableMTL/resolve/main/multi_stream.tar.gz -O $OUTPUT_DIR/multi_stream.tar.gz
tar -xvf $OUTPUT_DIR/multi_stream.tar.gz -C $OUTPUT_DIR
- Download pre-trained single-stream UNet checkpoint
The multi-stream model uses components from the single-stream UNet.
source env.sh
mkdir -p $CODE_DIR/checkpoint
wget https://huggingface.co/anhquancao/StableMTL/resolve/main/single_stream_unet.pth -O $CODE_DIR/checkpoint/single_stream_unet.pth
- Run evaluation
source env.sh
python eval_mtl.py --config config/dataset/dataset_test.yaml \
    --resume_run=$OUTPUT_DIR/multi_stream/checkpoint/latest \
    --base_ckpt_dir=$BASE_CKPT_DIR --base_data_dir=$PREPROCESSED_DIR
This will evaluate the multi-stream model on all test datasets configured in the dataset_test.yaml file.
Download the Stable Diffusion 2 base checkpoint; this is needed to train the single-stream model.
- Create base checkpoint directory
source env.sh
mkdir -p $BASE_CKPT_DIR
- Download the checkpoint from Hugging Face
from huggingface_hub import snapshot_download
snapshot_download(repo_id="stabilityai/stable-diffusion-2")
- Create a symlink from the downloaded checkpoint to the base checkpoint directory (the snapshot hash below may differ on your machine)
ln -s ~/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2/snapshots/1e128c8891e52218b74cde8f26dbfc701cb99d79 $BASE_CKPT_DIR/stable-diffusion-2
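Alternatively, a minimal sketch that downloads the checkpoint directly to $BASE_CKPT_DIR/stable-diffusion-2 (the same path as the symlink target above) and avoids looking up the snapshot hash, assuming your huggingface_hub version supports the local_dir argument:
source env.sh
# Download the full repo snapshot straight into the base checkpoint directory
python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='stabilityai/stable-diffusion-2', local_dir='$BASE_CKPT_DIR/stable-diffusion-2')"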
Download the pre-trained single-stream checkpoint; this is needed to train the multi-stream model.
source env.sh
wget https://huggingface.co/anhquancao/StableMTL/resolve/main/single_stream_unet.pth -O $BASE_CKPT_DIR/single_stream_unet.pth
Train the single-stream model
source env.sh
python train_stablemtl.py --config config/train_stablemtl_s.yaml \
--subfix=SingleStream \
--output_dir=$OUTPUT_DIR \
--base_ckpt_dir=$BASE_CKPT_DIR --base_data_dir=$PREPROCESSED_DIR
Train the multi-stream model. Make sure you have downloaded the pre-trained single-stream checkpoint as described in section 6.1.2 before launching:
source env.sh
python train_stablemtl.py --config config/train_stablemtl.yaml \
--subfix=MultiStream \
--output_dir=$OUTPUT_DIR \
--base_ckpt_dir=$BASE_CKPT_DIR --base_data_dir=$PREPROCESSED_DIR
To train the multi-stream model on multiple GPUs, launch the training with accelerate:
source env.sh
accelerate launch --config_file config/accelerator/multigpus_2_fp32.yaml \
--main_process_port 29512 \
train_stablemtl.py --config config/train_stablemtl.yaml \
--n_gpus=2 \
--subfix=MultiStream \
--output_dir=$OUTPUT_DIR \
--base_ckpt_dir=$BASE_CKPT_DIR --base_data_dir=$PREPROCESSED_DIR
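To pin the run to specific GPUs, you can set CUDA_VISIBLE_DEVICES (a standard PyTorch/accelerate environment variable, not specific to this repo) before running the same accelerate command as above:
export CUDA_VISIBLE_DEVICES=0,1   # select GPUs 0 and 1, matching --n_gpus=2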
This code builds upon the following excellent open-source repositories:
This work was supported by the French Agence Nationale de la Recherche (ANR) under project SIGHT (ANR-20-CE23-0016). Computations were performed using HPC resources from GENCI-IDRIS (Grants AD011014102R2, AD011014102R1, AD011014389R1, and AD011012808R3). The authors also thank the CLEPS infrastructure at Inria Paris for additional support.