improve non-cuda-ray mode
main.py +11 -4
nerf/sd.py +1 -1
nerf/utils.py +15 -13
main.py CHANGED

@@ -28,8 +28,8 @@ if __name__ == '__main__':
     parser.add_argument('--ckpt', type=str, default='latest')
     parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
     parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
-    parser.add_argument('--num_steps', type=int, default=
-    parser.add_argument('--upsample_steps', type=int, default=
+    parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
+    parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
     parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
     parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
     parser.add_argument('--albedo_iters', type=int, default=15000, help="training iters that only use albedo shading")

@@ -40,8 +40,8 @@ if __name__ == '__main__':
     parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
     parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
     # rendering resolution in training, decrease this if CUDA OOM.
-    parser.add_argument('--w', type=int, default=
-    parser.add_argument('--h', type=int, default=
+    parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
+    parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
     parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")

     ### dataset options

@@ -55,6 +55,7 @@ if __name__ == '__main__':
     parser.add_argument('--angle_front', type=float, default=30, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")

     parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
+    parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
     parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")

     ### GUI options

@@ -72,10 +73,16 @@ if __name__ == '__main__':
     if opt.O:
         opt.fp16 = True
         opt.dir_text = True
+        # use occupancy grid to prune ray sampling, faster rendering.
         opt.cuda_ray = True
+        opt.lambda_entropy = 1e-4
+        opt.lambda_opacity = 0
+
     elif opt.O2:
         opt.fp16 = True
         opt.dir_text = True
+        opt.lambda_entropy = 1e-3
+        opt.lambda_opacity = 1e-3 # no occupancy grid, so use a stronger opacity loss.

     if opt.backbone == 'vanilla':
         from nerf.network import NeRFNetwork
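Note: at the new defaults, the non-cuda-ray path samples a fixed budget per ray. A back-of-envelope sketch of the per-step cost (my own illustration, not code from this commit; names mirror the flags above):

    # rough cost of one training render at the new defaults
    w = h = 64                # --w / --h: training render resolution
    num_steps = 64            # --num_steps: coarse samples per ray
    upsample_steps = 64       # --upsample_steps: importance samples per ray

    rays = w * h                                   # 4096 rays per image
    queries = rays * (num_steps + upsample_steps)  # 524288 network queries
    print(f'{rays} rays x {num_steps + upsample_steps} samples = {queries} queries')

Doubling --w/--h quadruples the query count, which is why the comment above suggests decreasing them on CUDA OOM.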
nerf/sd.py CHANGED

@@ -20,7 +20,7 @@ class StableDiffusion(nn.Module):
             print(f'[INFO] loaded hugging face access token from ./TOKEN!')
         except FileNotFoundError as e:
             self.token = True
-            print(f'[INFO] try to load hugging face access token from the default
+            print(f'[INFO] try to load hugging face access token from the default place, make sure you have run `huggingface-cli login`.')

         self.device = device
         self.num_train_timesteps = 1000
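Note: the touched line is the fallback branch of the token loader. A minimal standalone sketch of that pattern (the helper name is hypothetical; only the try/except shape comes from the diff):

    def load_hf_token(path='./TOKEN'):
        try:
            with open(path, 'r') as f:
                token = f.read().strip()  # explicit access token from ./TOKEN
            print(f'[INFO] loaded hugging face access token from {path}!')
        except FileNotFoundError:
            # passing True as the auth token makes diffusers fall back to the
            # cached credentials, hence the `huggingface-cli login` reminder
            token = True
            print(f'[INFO] try to load hugging face access token from the default place, make sure you have run `huggingface-cli login`.')
        return token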
nerf/utils.py CHANGED

@@ -330,11 +330,11 @@ class Trainer(object):
         if rand > 0.8:
             shading = 'albedo'
             ambient_ratio = 1.0
-        elif rand > 0.4:
-            shading = 'textureless'
-            ambient_ratio = 0.1
+        # elif rand > 0.4:
+        #     shading = 'textureless'
+        #     ambient_ratio = 0.1
         else:
-            shading = '
+            shading = 'lambertian'
             ambient_ratio = 0.1

         # _t = time.time()

@@ -355,22 +355,24 @@ class Trainer(object):

         # encode pred_rgb to latents
         # _t = time.time()
-
+        loss = self.guidance.train_step(text_z, pred_rgb)
         # torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')

         # occupancy loss
         pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)
-        # mask_ws = outputs['mask'].reshape(B, 1, H, W) # near < far

-
-
-
-
-
-
+        if self.opt.lambda_opacity > 0:
+            loss_opacity = (pred_ws ** 2).mean()
+            loss = loss + self.opt.lambda_opacity * loss_opacity
+
+        if self.opt.lambda_entropy > 0:
+            alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
+            # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
+            loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
+
+            loss = loss + self.opt.lambda_entropy * loss_entropy

-        if 'loss_orient' in outputs:
+        if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
             loss_orient = outputs['loss_orient']
             loss = loss + self.opt.lambda_orient * loss_orient
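Note: a self-contained sketch of the two regularizers added above (my own illustration; the shapes and the 1e-3 weights follow the diff and the -O2 preset):

    import torch

    B, H, W = 1, 64, 64
    pred_ws = torch.rand(B, 1, H, W)   # stands in for outputs['weights_sum']
    lambda_opacity = lambda_entropy = 1e-3
    loss = torch.tensor(0.0)

    if lambda_opacity > 0:
        # pushes accumulated per-ray opacity toward 0, suppressing floaters
        loss = loss + lambda_opacity * (pred_ws ** 2).mean()

    if lambda_entropy > 0:
        # binary entropy of opacity: minimal at 0 or 1, so each ray is
        # pushed to be either fully empty or fully opaque
        alphas = pred_ws.clamp(1e-5, 1 - 1e-5)
        entropy = -alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)
        loss = loss + lambda_entropy * entropy.mean()

Without CUDA raymarching there is no occupancy grid to prune empty space, so the -O2 preset leans on the stronger opacity loss instead, as the inline comment in main.py notes.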