Skip to content

Commit

Permalink
Training Code (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
GonzaloMartinGarcia authored Oct 23, 2024
1 parent 0da65fb commit 55e511f
Show file tree
Hide file tree
Showing 32 changed files with 2,940 additions and 42 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
.venv/
__pycache__/
.DS_Store
wandb/
.vscode/

/data
/output
/input
/model-finetuned

experiments/depth/*
experiments/normals/*
!experiments/depth/eval_args/
!experiments/normals/eval_args/

metadata_images_split_scene_v1.csv
19 changes: 16 additions & 3 deletions GeoWizard/geowizard/models/geowizard_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Adapted from Marigold :https://github.com/prs-eth/Marigold

# @GonzaloMartinGarcia, all new additions to the GeoWizard code have been marked with # add.
# @GonzaloMartinGarcia
# All new additions to the GeoWizard code have been marked with # add.

from typing import Any, Dict, Union

Expand All @@ -17,7 +18,6 @@
)
from ..models.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import BaseOutput
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
Expand All @@ -27,7 +27,20 @@
from ..utils.image_util import resize_max_res, chw2hwc, colorize_depth_maps
from ..utils.depth_ensemble import ensemble_depths
from ..utils.normal_ensemble import ensemble_normals
from ..utils.noise import pyramid_noise_like

# add
# Pyramid noise from GeoWizard training code.
def pyramid_noise_like(x, timesteps, discount=0.9):
b, c, w_ori, h_ori = x.shape
u = nn.Upsample(size=(w_ori, h_ori), mode='bilinear')
noise = torch.randn_like(x)
scale = 1.5
for i in range(10):
r = np.random.random()*scale + scale # Rather than always going 2x,
w, h = max(1, int(w_ori/(r**i))), max(1, int(h_ori/(r**i)))
noise += u(torch.randn(b, c, w, h).to(x)) * (timesteps[...,None,None,None]/1000) * discount**i
if w==1 or h==1: break # Lowest resolution is 1x1
return noise/noise.std() # Scaled back to roughly unit variance


class DepthNormalPipelineOutput(BaseOutput):
Expand Down
Empty file.
Loading

0 comments on commit 55e511f

Please sign in to comment.