forked from electech6/NeRF-Based-SLAM-Incredible-Insights
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
222 lines (167 loc) · 8.79 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import os
import torch
import numpy as np
import trimesh
import marching_cubes as mcubes
from matplotlib import pyplot as plt
#### GO-Surf ####
def coordinates(voxel_dim, device: torch.device, flatten=True):
    """Return integer index coordinates of every voxel in a grid.

    Args:
        voxel_dim: voxels per axis; either a single integer (same count on all
            three axes) or an indexable whose first three entries are (nx, ny, nz).
        device: device on which the coordinate tensors are created.
        flatten: if True, return a (3, nx*ny*nz) tensor; otherwise an
            (nx, ny, nz, 3) tensor.

    Returns:
        torch.long tensor of voxel indices. Built with 'ij' meshgrid indexing,
        so in the flattened layout the z index varies fastest.
    """
    # isinstance (instead of `type(x) is int`) also accepts NumPy integer
    # scalars such as np.int64, which previously fell through to the indexing
    # branch and raised.
    if isinstance(voxel_dim, (int, np.integer)):
        nx = ny = nz = int(voxel_dim)
    else:
        nx, ny, nz = voxel_dim[0], voxel_dim[1], voxel_dim[2]

    x = torch.arange(0, nx, dtype=torch.long, device=device)
    y = torch.arange(0, ny, dtype=torch.long, device=device)
    z = torch.arange(0, nz, dtype=torch.long, device=device)
    x, y, z = torch.meshgrid(x, y, z, indexing="ij")

    if not flatten:
        return torch.stack([x, y, z], dim=-1)
    return torch.stack((x.flatten(), y.flatten(), z.flatten()))
#### ####
def getVoxels(x_max, x_min, y_max, y_min, z_max, z_min, voxel_size=None, resolution=None):
    """Build 1-D sample coordinates along each axis of an axis-aligned box.

    Args:
        x_max, x_min, y_max, y_min, z_max, z_min: box bounds; anything accepted
            by float() (Python numbers, NumPy scalars, 0-d tensors).
        voxel_size: if given, the spacing between samples. Each axis gets
            round(extent / voxel_size) + 1 samples. Takes precedence over
            `resolution`.
        resolution: number of samples per axis, used when voxel_size is None.

    Returns:
        (tx, ty, tz): 1-D float tensors running from min to max (inclusive).

    Raises:
        ValueError: if neither voxel_size nor resolution is provided.
    """
    # Convert every bound unconditionally. The old code only converted when
    # x_max was not already a float, so mixed inputs (e.g. a float x_max with
    # 0-d tensor mins) slipped through unconverted.
    x_max, x_min, y_max, y_min, z_max, z_min = (
        float(v) for v in (x_max, x_min, y_max, y_min, z_max, z_min)
    )
    bounds = ((x_min, x_max), (y_min, y_max), (z_min, z_max))

    if voxel_size is not None:
        # The +0.0005 bias presumably keeps quotients sitting at/near .5 from
        # rounding down (float error / round-half-to-even); kept unchanged.
        counts = [round((hi - lo) / voxel_size + 0.0005) + 1 for lo, hi in bounds]
    elif resolution is not None:
        counts = [resolution] * 3
    else:
        raise ValueError("getVoxels requires either voxel_size or resolution")

    tx, ty, tz = (torch.linspace(lo, hi, n) for (lo, hi), n in zip(bounds, counts))
    return tx, ty, tz
def get_batch_query_fn(query_fn, num_args=1, device=None):
    """Wrap `query_fn` so it can be evaluated on a row slice [i0:i1] of its inputs.

    When num_args == 1, the returned callable is fn(f, i0, i1). It calls
    query_fn(f[i0:i1, None, :]) after moving that slice to `device`.
    For any other num_args, it is fn(f, f1, i0, i1), which also passes
    f1[i0:i1, :] (moved to `device`) as a second argument.
    """
    def _one_input(f, i0, i1):
        return query_fn(f[i0:i1, None, :].to(device))

    def _two_inputs(f, f1, i0, i1):
        return query_fn(f[i0:i1, None, :].to(device), f1[i0:i1, :].to(device))

    return _one_input if num_args == 1 else _two_inputs
#### NeuralRGBD ####
def _query_in_chunks(batch_fn, n_items, chunk, *inputs):
    """Evaluate batch_fn(*inputs, i0, i1) over [0, n_items) in slices of `chunk` rows.

    Returns the per-slice outputs moved to CPU and concatenated along axis 0,
    as a float32 NumPy array.
    """
    outputs = [
        batch_fn(*inputs, start, start + chunk).detach().cpu().numpy()
        for start in range(0, n_items, chunk)
    ]
    return np.concatenate(outputs, 0).astype(np.float32)


@torch.no_grad()
def extract_mesh(query_fn, config, bounding_box, marching_cube_bound=None, color_func=None, voxel_size=None, resolution=None, isolevel=0.0, scene_name='', mesh_savepath=''):
    '''
    Extract a mesh from the scene model using marching cubes (adapted from NeuralRGBD).

    Args:
        query_fn: callable evaluated on chunks of grid points shaped (chunk, 1, 3).
            Its outputs are reshaped onto the grid and passed to marching cubes.
        config: dict-like. Reads config['grid']['tcnn_encoding'],
            config['data']['sc_factor'], config['data']['translation'] and
            config['mesh']['render_color'].
        bounding_box: tensor indexed as [:, 0] (per-axis min) and [:, 1]
            (per-axis max). Its device is used for all network queries.
        marching_cube_bound: optional tensor with the same layout that limits the
            sampled region. Defaults to bounding_box.
        color_func: optional callable used to assign per-vertex colors.
        voxel_size, resolution: grid spacing or per-axis sample count, forwarded
            to getVoxels (one of them must be set).
        isolevel: iso-value passed to marching cubes.
        scene_name: not used by this function.
        mesh_savepath: file path the mesh is exported to. If empty, nothing is
            written. Missing parent directories are created.

    Returns:
        The extracted trimesh.Trimesh (with vertex colors when color_func is given).
    '''
    chunk = 1024 * 64
    device = bounding_box.device

    # Query the network on a dense 3D grid of points.
    if marching_cube_bound is None:
        marching_cube_bound = bounding_box
    x_min, y_min, z_min = marching_cube_bound[:, 0]
    x_max, y_max, z_max = marching_cube_bound[:, 1]
    tx, ty, tz = getVoxels(x_max, x_min, y_max, y_min, z_max, z_min, voxel_size, resolution)
    query_pts = torch.stack(torch.meshgrid(tx, ty, tz, indexing='ij'), -1).to(torch.float32)
    sh = query_pts.shape
    flat = query_pts.reshape([-1, 3])

    if config['grid']['tcnn_encoding']:
        # Normalize as (p - min) / (max - min) w.r.t. bounding_box before querying.
        bounding_box_cpu = bounding_box.cpu()
        flat = (flat - bounding_box_cpu[:, 0]) / (bounding_box_cpu[:, 1] - bounding_box_cpu[:, 0])

    fn = get_batch_query_fn(query_fn, device=device)
    raw = _query_in_chunks(fn, flat.shape[0], chunk, flat)
    raw = np.reshape(raw, list(sh[:-1]) + [-1])

    print('Running Marching Cubes')
    vertices, triangles = mcubes.marching_cubes(raw.squeeze(), isolevel, truncation=3.0)
    print('done', vertices.shape, triangles.shape)

    # Marching cubes yields vertices in grid-index units; normalize to [0, 1] per axis.
    vertices[:, :3] /= np.array([[tx.shape[0] - 1, ty.shape[0] - 1, tz.shape[0] - 1]])

    # Rescale and translate back into the sampled region.
    tx, ty, tz = (t.cpu().numpy() for t in (tx, ty, tz))
    scale = np.array([tx[-1] - tx[0], ty[-1] - ty[0], tz[-1] - tz[0]])
    offset = np.array([tx[0], ty[0], tz[0]])
    vertices[:, :3] = scale[np.newaxis, :] * vertices[:, :3] + offset

    # Transform to metric units.
    vertices[:, :3] = vertices[:, :3] / config['data']['sc_factor'] - config['data']['translation']

    if color_func is not None and not config['mesh']['render_color']:
        vert_flat = torch.from_numpy(vertices).to(bounding_box)
        if config['grid']['tcnn_encoding']:
            vert_flat = (vert_flat - bounding_box[:, 0]) / (bounding_box[:, 1] - bounding_box[:, 0])
        # vert_flat is now defined on both paths (previously a NameError when
        # tcnn_encoding was False).
        fn_color = get_batch_query_fn(color_func, 1, device=device)
        raw = _query_in_chunks(fn_color, vert_flat.shape[0], chunk, vert_flat)
        color = np.reshape(raw, list(vert_flat.shape[:-1]) + [-1])
        mesh = trimesh.Trimesh(vertices, triangles, process=False, vertex_colors=color)
    elif color_func is not None and config['mesh']['render_color']:
        print('rendering surface color')
        mesh = trimesh.Trimesh(vertices, triangles, process=False)
        # Convert once (not per chunk) and use float32 to match the dtype the
        # geometry grid is queried with above.
        # NOTE(review): unlike the branch above, no tcnn normalization is applied
        # here — confirm color_func expects metric-space points.
        vertex_t = torch.from_numpy(vertices).float()
        normals_t = torch.from_numpy(mesh.vertex_normals).float()
        fn_color = get_batch_query_fn(color_func, 2, device=device)
        raw = _query_in_chunks(fn_color, vertex_t.shape[0], chunk, vertex_t, normals_t)
        color = np.reshape(raw, list(normals_t.shape[:-1]) + [-1])
        mesh = trimesh.Trimesh(vertices, triangles, process=False, vertex_colors=color)
    else:
        mesh = trimesh.Trimesh(vertices, triangles, process=False)

    if mesh_savepath:
        # os.makedirs('') raises, so only create a directory when the path has one.
        save_dir = os.path.dirname(mesh_savepath)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        mesh.export(mesh_savepath)
        print('Mesh saved')
    return mesh
#### ####
#### SimpleRecon ####
def colormap_image(
    image_1hw,
    mask_1hw=None,
    invalid_color=(0.0, 0, 0.0),
    flip=True,
    vmin=None,
    vmax=None,
    return_vminvmax=False,
    colormap="turbo",  # any matplotlib colormap name; "turbo" runs dark blue (low) -> dark red (high)
):
    """
    Colormaps a one-channel tensor using a matplotlib colormap.

    Args:
        image_1hw: the tensor to colormap, expected shape (1, H, W).
        mask_1hw: an optional mask of the same shape; non-zero entries denote
            valid pixels. Float masks are used as blend weights between the
            colormapped value and invalid_color; bool masks are also accepted.
        invalid_color: RGB color used for invalid (masked-out) pixels.
        flip: if True (default), reverse the colormap so low values map to the
            colormap's high end.
        vmin: if provided, used as the minimum when normalizing the tensor;
            otherwise the minimum of the valid values is used.
        vmax: if provided, used as the maximum when normalizing the tensor;
            otherwise the maximum of the valid values is used.
        return_vminvmax: when true, also returns vmin and vmax.
        colormap: name of the matplotlib colormap. Default is turbo.
    Returns:
        image_cm_3hw: the colormapped image, shape (3, H, W).
        vmin, vmax: returned only when return_vminvmax is true.
    """
    # Only pixels whose mask entry is non-zero contribute to the automatic range.
    valid_vals = image_1hw if mask_1hw is None else image_1hw[mask_1hw.bool()]
    if vmin is None:
        vmin = valid_vals.min()
    if vmax is None:
        vmax = valid_vals.max()

    # Sample the colormap at 256 evenly spaced points -> (256, 3) RGB lookup
    # table (alpha channel dropped). plt.get_cmap is used because
    # matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9.
    cmap = torch.Tensor(
        plt.get_cmap(colormap)(torch.linspace(0, 1, 256))[:, :3]
    ).to(image_1hw.device)
    if flip:
        cmap = torch.flip(cmap, (0,))

    h, w = image_1hw.shape[1:]
    # Normalize to [0, 1]. When vmax == vmin, pixels equal to vmin give 0/0 = NaN,
    # whose cast to an integer index is undefined; map NaN to 0 instead. Infinite
    # values (x / 0) still saturate to 0 or 255 via the clamp below.
    image_norm_1hw = torch.nan_to_num((image_1hw - vmin) / (vmax - vmin), nan=0.0)
    # Quantize to lookup-table indices 0..255.
    image_int_1hw = torch.clamp(image_norm_1hw * 255, 0, 255).byte().long()
    # (1, H*W) indices -> (1, H*W, 3) colors -> (1, 3, H*W) -> (3, H, W).
    image_cm_3hw = cmap[image_int_1hw.flatten(start_dim=1)].permute([0, 2, 1]).view([-1, h, w])

    if mask_1hw is not None:
        invalid_color = torch.Tensor(invalid_color).view(3, 1, 1).to(image_1hw.device)
        mask = mask_1hw
        if mask.dtype == torch.bool:
            # `1 - bool_tensor` raises in PyTorch; other mask dtypes are left as-is.
            mask = mask.float()
        image_cm_3hw = image_cm_3hw * mask + invalid_color * (1 - mask)

    if return_vminvmax:
        return image_cm_3hw, vmin, vmax
    return image_cm_3hw