[Code Reading] Instance-Adaptive and Geometric-Aware Keypoint Learning for Category-Level 6D Object Pose Estimation

Data Processing

Dataloaders are created in ./provider/create_dataloaders.py. Three configurations are supported: camera_real, camera, and housecat6d. With the camera_real configuration, CAMERA and REAL samples are mixed at a ratio of 3:1.
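
A minimal sketch of how such a 3:1 mix could be realized (illustration only; camera_ds, real_ds, and make_camera_real_loader are hypothetical names, and the repository may build its index lists differently):

import numpy as np
from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler

def make_camera_real_loader(camera_ds, real_ds, batch_size=24):
    """Draw CAMERA and REAL samples at roughly a 3:1 ratio (hypothetical sketch)."""
    dataset = ConcatDataset([camera_ds, real_ds])
    # per-sample weights chosen so that 3/4 of the probability mass falls on CAMERA
    weights = np.concatenate([np.full(len(camera_ds), 3.0 / len(camera_ds)),
                              np.full(len(real_ds), 1.0 / len(real_ds))])
    sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=4)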

Loading the Depth Maps

The method uses the CAMERA25 and REAL275 datasets. CAMERA25 is a synthetic dataset whose depth maps are composed synthetic depths, so they need special handling. Below is the function that loads a composed synthetic depth map (note that it reads images under ./data/camera_full_depths/, which differ from those under ./data/camera/):

import os
import cv2
import numpy as np

def load_composed_depth(img_path):
    """ Load depth image from img_path. """
    img_path_ = img_path.replace('/data/camera/', '/data/camera_full_depths/')
    depth_path = img_path_ + '_composed.png'
    if os.path.exists(depth_path):
        depth = cv2.imread(depth_path, -1)
        if len(depth.shape) == 3:
            # This is encoded depth image, let's convert
            # NOTE: RGB is actually BGR in opencv
            depth16 = depth[:, :, 1]*256 + depth[:, :, 2]
            depth16 = np.where(depth16 == 32001, 0, depth16)
            depth16 = depth16.astype(np.uint16)
        elif len(depth.shape) == 2 and depth.dtype == 'uint16':
            depth16 = depth
        else:
            assert False, '[ Error ]: Unsupported depth type.'
        return depth16
    else:
        print("warning: No data")
        return None

Reading ./data/camera/train/00000/0000_depth.png with OpenCV and visualizing each of its three channels:

import cv2
import matplotlib.pyplot as plt

img_path = './data/camera/train/00000/0000'

depth_path = img_path + '_depth.png'
depth = cv2.imread(depth_path, -1)
plt.figure(figsize=(10, 5))
plt.subplot(131)
plt.title('depth[:, :, 0]')
plt.imshow(depth[:, :, 0])
plt.colorbar(shrink=0.5)
plt.subplot(132)
plt.title('depth[:, :, 1]')
plt.imshow(depth[:, :, 1])
plt.colorbar(shrink=0.5)
plt.subplot(133)
plt.title('depth[:, :, 2]')
plt.imshow(depth[:, :, 2])
plt.colorbar(shrink=0.5)
plt.show()
(Figure: load_composed_depth_src)

Combining channel 1 and channel 2:

import cv2
import numpy as np
import matplotlib.pyplot as plt

img_path = './data/camera/train/00000/0000'

depth_path = img_path + '_depth.png'
depth = cv2.imread(depth_path, -1)

depth_camera = depth[:, :, 1]*256 + depth[:, :, 2]
depth_camera = np.where(depth_camera==32001, 0, depth_camera)
depth_camera = depth_camera.astype(np.uint16)

plt.imshow(depth_camera)
plt.colorbar()
plt.show()
(Figure: load_composed_depth_dist)

Reading ./data/camera_full_depths/train/00000/0000_composed.png and visualizing it directly:

import cv2
import matplotlib.pyplot as plt

img_path = './data/camera_full_depths/train/00000/0000'

depth_path = img_path + '_composed.png'
depth = cv2.imread(depth_path, -1)

plt.imshow(depth)
plt.colorbar()
plt.show()
(Figure: load_composed_depth)

The real depth maps are loaded as follows:

def load_depth(img_path):
    """ Load depth image from img_path. """
    depth_path = img_path + '_depth.png'
    depth = cv2.imread(depth_path, -1)
    if len(depth.shape) == 3:
        # This is encoded depth image, let's convert
        # NOTE: RGB is actually BGR in opencv
        depth16 = depth[:, :, 1]*256 + depth[:, :, 2]
        depth16 = np.where(depth16 == 32001, 0, depth16)
        depth16 = depth16.astype(np.uint16)
    elif len(depth.shape) == 2 and depth.dtype == 'uint16':
        depth16 = depth
    else:
        assert False, '[ Error ]: Unsupported depth type.'
    return depth16

Reading ./data/real/train/scene_1/0000_depth.png directly and visualizing it:

import cv2
import matplotlib.pyplot as plt

img_path = './data/real/train/scene_1/0000'

depth_path = img_path + '_depth.png'
depth_real = cv2.imread(depth_path, -1)

plt.imshow(depth_real)
plt.colorbar()
plt.show()
(Figure: load_depth)

In short, the synthetic dataset reads its depth maps from ./data/camera_full_depths/, while the real dataset reads its depth maps from ./data/real/.

Depth Completion

The depth maps are then hole-filled; both the synthetic and the real depth maps are completed:

import matplotlib.pyplot as plt
from utils.data_utils import fill_missing, load_depth, load_composed_depth

camera_img_path = './data/camera/train/00000/0000'
real_img_path = './data/real/train/scene_1/0000'

camera_depth = load_composed_depth(camera_img_path)
real_depth = load_depth(real_img_path)

fill_missing_camera = fill_missing(camera_depth, 1000.0, 1)
fill_missing_real = fill_missing(real_depth, 1000.0, 1)

plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.imshow(fill_missing_camera)
plt.colorbar(shrink=0.5)
plt.subplot(122)
plt.imshow(fill_missing_real)
plt.colorbar(shrink=0.5)
plt.show()
(Figure: fill_missing_depth)

Loading the Masks

The following reads the contents of ./data/camera/train/00000/0000_label.pkl:

with open(img_path + '_label.pkl', 'rb') as f:
    gts = cPickle.load(f)

The contents of gts are:

{'class_ids': [1, 2, 6, 1],
'bboxes': array([[ 50, 379, 222, 438],
[224, 107, 325, 222],
[ 21, 303, 83, 344],
[ 0, 78, 158, 187]], dtype=int32),
'scales': array([0.3005802 , 0.21199934, 0.17190282, 0.42889243], dtype=float32),
'sizes': array([[0.238512, 0.941394, 0.238512],
[0.691974, 0.20577 , 0.691974],
[0.685634, 0.549036, 0.477982],
[0.210626, 0.926818, 0.31088 ]], dtype=float32),
'rotations': array([[[ 0.98163134, 0.01207083, -0.19040541],
[ 0.14415757, -0.7006575 , 0.69878304],
[-0.12497408, -0.7133957 , -0.6895274 ]],

[[ 0.1414129 , 0.01162317, 0.98988247],
[-0.705418 , -0.7003605 , 0.10899841],
[ 0.69454145, -0.7136947 , -0.09084082]],

[[ 0.02342262, 0.01182115, 0.9996558 ],
[-0.71319616, -0.7005187 , 0.02499446],
[ 0.700573 , -0.7135361 , -0.00797719]],

[[-0.98874557, 0.01170614, 0.14914814],
[-0.11453401, -0.70061237, -0.7042899 ],
[ 0.09625051, -0.7134461 , 0.69406813]]], dtype=float32),
'translations': array([[ 0.1340191 , -0.14495842, 0.8599784 ],
[-0.20857327, 0.04460018, 0.7819705 ],
[ 0.0061691 , -0.40006354, 1.2049764 ],
[-0.31816682, -0.3390898 , 0.9828115 ]], dtype=float32),
'instance_ids': [2, 3, 4, 9],
'model_list': ['ab6792cddc7c4c83afbf338b16b43f53',
'7d7bdea515818eb844638317e9e4ff18',
'73b8b6456221f4ea20d3c05c08e26f',
'a1275bd03ab15100f6dbe3dc17d6cdf7']}

Reading the mask

mask = cv2.imread(img_path + '_mask.png')[:, :, 2]

Only channel 2 is taken here, because this channel carries the per-pixel instance IDs:

import cv2
import numpy as np

mask = cv2.imread('./data/camera/train/00000/0000_mask.png')

unique_ids_0 = np.unique(mask[:, :, 0])
print("Instance IDs in channel 0:", unique_ids_0)
# Instance IDs in channel 0: [ 0 255]

unique_ids_1 = np.unique(mask[:, :, 1])
print("Instance IDs in channel 1:", unique_ids_1)
# Instance IDs in channel 1: [ 0 255]

unique_ids_2 = np.unique(mask[:, :, 2])
print("Instance IDs in channel 2:", unique_ids_2)
# Instance IDs in channel 2: [ 1 2 3 4 9 255]

These match the instance_ids listed above.

Removing the masks of non-target objects

First, one object is selected at random (note that during training a single object is randomly chosen from all objects in the image, whereas during testing all of them are returned):

idx = np.random.randint(0, num_instance)

Using get_bbox to compute a square crop window:

rmin, rmax, cmin, cmax = get_bbox(gts['bboxes'][idx])
def get_bbox(bbox):
    """ Compute square image crop window. """
    y1, x1, y2, x2 = bbox
    img_width = 480
    img_length = 640
    window_size = (max(y2 - y1, x2 - x1) // 40 + 1) * 40
    window_size = min(window_size, 440)
    center = [(y1 + y2) // 2, (x1 + x2) // 2]
    rmin = center[0] - int(window_size / 2)
    rmax = center[0] + int(window_size / 2)
    cmin = center[1] - int(window_size / 2)
    cmax = center[1] + int(window_size / 2)
    if rmin < 0:
        delt = -rmin
        rmin = 0
        rmax += delt
    if cmin < 0:
        delt = -cmin
        cmin = 0
        cmax += delt
    if rmax > img_width:
        delt = rmax - img_width
        rmax = img_width
        rmin -= delt
    if cmax > img_length:
        delt = cmax - img_length
        cmax = img_length
        cmin -= delt
    return rmin, rmax, cmin, cmax

The function produces a square crop centered on the original bounding box, with a side length equal to the longer edge of the box rounded up to the next multiple of 40 (capped at 440), and it handles the case where the resulting crop would extend beyond the image.
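
As a quick worked example, plugging the first bounding box from the gts shown earlier into get_bbox (values computed by hand from the function above):

rmin, rmax, cmin, cmax = get_bbox([50, 379, 222, 438])
# max(222 - 50, 438 - 379) = 172 -> window_size = (172 // 40 + 1) * 40 = 200
# center = [136, 408] -> rows [36, 236), cols [308, 508), all inside the 480 x 640 image
print(rmin, rmax, cmin, cmax)   # 36 236 308 508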

Element-wise comparison and AND operations then remove the masks of the other objects (and pixels without valid depth):

mask = np.equal(mask, gts['instance_ids'][idx])
mask = np.logical_and(mask , depth > 0)

Sampling points from the object

After the object point cloud is obtained later on, a fixed number of points has to be sampled from it so the input matches the dimension the network expects; the sampling indices are computed in this step.

choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]

The target object's region is cropped out of the mask and flattened, and the indices of its non-zero entries are taken.

These indices are then sampled down to a fixed count (1024 in the config file):

if len(choose) <= self.sample_num: # 1024
    choose_idx = np.random.choice(len(choose), self.sample_num)
else:
    choose_idx = np.random.choice(len(choose), self.sample_num, replace=False)

choose = choose[choose_idx]

Converting the depth map to a point cloud

Get the camera intrinsics:

cam_fx, cam_fy, cam_cx, cam_cy = self.intrinsics

Depth normalization (dividing by norm_scale converts the raw depth values to meters):

pts2 = depth.copy() / self.norm_scale

Convert the x and y pixel coordinates to x and y coordinates in the camera frame:

self.xmap = np.array([[i for i in range(640)] for j in range(480)])
self.ymap = np.array([[j for i in range(640)] for j in range(480)])
pts0 = (self.xmap - cam_cx) * pts2 / cam_fx
pts1 = (self.ymap - cam_cy) * pts2 / cam_fy

See the pinhole camera imaging model for the underlying derivation.
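
For reference, the back-projection implemented above is the standard pinhole model: for a pixel \((u, v)\) with normalized depth \(Z\),
\[
X = \frac{(u - c_x)\,Z}{f_x}, \qquad Y = \frac{(v - c_y)\,Z}{f_y}, \qquad Z = \frac{\mathrm{depth}(u, v)}{\mathrm{norm\_scale}}.
\]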

Stack the coordinates and crop:

pts = np.transpose(np.stack([pts0, pts1, pts2]), (1, 2, 0)).astype(np.float32)
pts = pts[rmin:rmax, cmin:cmax, :].reshape((-1, 3))[choose, :]

RGB

The target image is cropped with rmin, rmax, cmin, cmax and then resized (to \(224 \times 224\) in the code), so the sampling indices have to be adjusted accordingly.

In addition, an RGB image read with OpenCV has HWC layout; after transforms.ToTensor() the layout becomes CHW.
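
A small illustration of the layout change (the _color.png file name follows the NOCS convention and is assumed here):

import cv2
import torchvision.transforms as transforms

rgb = cv2.imread('./data/camera/train/00000/0000_color.png')   # H x W x C, BGR
print(rgb.shape)                          # (480, 640, 3)
rgb_tensor = transforms.ToTensor()(rgb)   # scales to [0, 1] and permutes to C x H x W
print(rgb_tensor.shape)                   # torch.Size([3, 480, 640])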

Adjusting the sampling indices

crop_w = rmax - rmin   # width of the original mask crop
ratio = self.img_size / crop_w   # resize ratio
col_idx = choose % crop_w   # column index within the original crop
row_idx = choose // crop_w   # row index within the original crop
choose = (np.floor(row_idx * ratio) * self.img_size + np.floor(col_idx * ratio)).astype(np.int64)
# np.floor(row_idx * ratio): row position after the resize
# np.floor(col_idx * ratio): column position after the resize
# row * self.img_size + col: flattened index in the resized image
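
A small numeric check of this remapping (assuming a 200-pixel crop resized to img_size = 224):

import numpy as np

crop_w, img_size = 200, 224              # assumed crop width and target size
choose = np.array([405])                 # one sampled flat index inside the crop
ratio = img_size / crop_w                # 1.12
col_idx, row_idx = choose % crop_w, choose // crop_w   # col 5, row 2
new_choose = (np.floor(row_idx * ratio) * img_size + np.floor(col_idx * ratio)).astype(np.int64)
print(new_choose)                        # [453] = 2 * 224 + 5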

Loading the object model point clouds

First, the model point clouds of all objects are loaded, here from ./data/obj_models/camera_train.pkl:

self.models = {}
with open(os.path.join(self.data_dir, model_path), 'rb') as f:
    self.models.update(cPickle.load(f))

The file stores the object point clouds as a dictionary, so the point cloud of the selected object can be fetched by its key:

model = self.models[gts['model_list'][idx]].astype(np.float32)

Extract the object's translation, rotation, and size:

translation = gts['translations'][idx].astype(np.float32)   # 3     translation from the object frame to the camera frame
rotation = gts['rotations'][idx].astype(np.float32)   # 3, 3  rotation from the object frame to the camera frame
size = gts['scales'][idx] * gts['sizes'][idx].astype(np.float32)   # 3
# gts['scales'][idx]: scalar object scale
# gts['sizes'][idx]: 3, object dimensions (length, width, height)

Handling symmetric objects

if cat_id in self.sym_ids:
    theta_x = rotation[0, 0] + rotation[2, 2]
    theta_y = rotation[0, 2] - rotation[2, 0]
    r_norm = math.sqrt(theta_x**2 + theta_y**2)
    s_map = np.array([[theta_x/r_norm, 0.0, -theta_y/r_norm],
                      [0.0,            1.0,  0.0           ],
                      [theta_y/r_norm, 0.0,  theta_x/r_norm]])
    rotation = rotation @ s_map
    # rotation about the Y axis; this assumes Y points up, Z points forward,
    # and X points right (a right-handed frame)

The Y axis points up, which is why the middle column of s_map is \([0, 1, 0]\).
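
A quick numeric sketch of why this works (my own check, not code from the repository): s_map is itself a rotation about Y by the negative of the object's azimuth, so right-multiplying cancels the rotation about the symmetry axis.

import math
import numpy as np

def rot_y(a):   # rotation about Y
    return np.array([[ math.cos(a), 0.0, math.sin(a)],
                     [ 0.0,         1.0, 0.0        ],
                     [-math.sin(a), 0.0, math.cos(a)]])

def rot_x(a):   # rotation about X
    return np.array([[1.0, 0.0,          0.0         ],
                     [0.0, math.cos(a), -math.sin(a)],
                     [0.0, math.sin(a),  math.cos(a)]])

rotation = rot_x(0.3) @ rot_y(0.7)   # a tilt composed with a yaw about the symmetry axis

theta_x = rotation[0, 0] + rotation[2, 2]
theta_y = rotation[0, 2] - rotation[2, 0]
r_norm = math.sqrt(theta_x**2 + theta_y**2)
s_map = np.array([[theta_x/r_norm, 0.0, -theta_y/r_norm],
                  [0.0,            1.0,  0.0           ],
                  [theta_y/r_norm, 0.0,  theta_x/r_norm]])

print(np.allclose(rotation @ s_map, rot_x(0.3)))   # True: the yaw about Y has been removed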

Network Architecture

RGB features

RGB features are extracted with DINOv2:

self.rgb_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
for param in self.rgb_extractor.parameters():
    param.requires_grad = False

A 1D convolution is also applied on top of the DINOv2 features:

self.feature_mlp = nn.Sequential(
    nn.Conv1d(384, 128, 1),
)

rgb_local = self.feature_mlp(dino_feature)

Selecting RGB features

During data preprocessing a choose variable was generated; here it is used to sample the RGB features from \(\mathbb{R}^{b \times 128 \times (196 \times 196)}\) down to \(\mathbb{R}^{b \times 128 \times 1024}\).
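
A sketch of how such index-based selection can be done with torch.gather (the shapes below are illustrative only; the actual flattened spatial size depends on how the DINOv2 patch features are resized):

import torch

b, c, n_pix, n_sample = 2, 128, 224 * 224, 1024      # illustrative sizes
rgb_local = torch.randn(b, c, n_pix)                  # flattened per-pixel RGB features
choose = torch.randint(0, n_pix, (b, n_sample))       # indices produced by the dataloader

# expand choose to (b, c, n_sample) and gather along the pixel dimension
rgb_sampled = torch.gather(rgb_local, 2, choose.unsqueeze(1).expand(-1, c, -1))
print(rgb_sampled.shape)                              # torch.Size([2, 128, 1024])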

Adding noise

if self.training:
    delta_r, delta_t, delta_s = generate_augmentation(b)
    pts = (pts - delta_t) / delta_s.unsqueeze(2) @ delta_r

The function that generates the noise is:

def generate_augmentation(batchsize):
    delta_t = torch.rand(batchsize, 1, 3).cuda()   # b, 1, 3
    delta_t = delta_t.uniform_(-0.02, 0.02)   # resample values into [-0.02, 0.02]

    angle_r = torch.randn(batchsize, 3)   # b, 3
    angle_r.uniform_(-20, 20)   # resample values into [-20, 20]
    angle_r = angle_r / 180 * torch.pi   # b, 3  convert degrees to radians

    delta_r_x = torch.eye(3).unsqueeze(0).repeat(batchsize, 1, 1)   # b, 3, 3
    delta_r_y = torch.eye(3).unsqueeze(0).repeat(batchsize, 1, 1)
    delta_r_z = torch.eye(3).unsqueeze(0).repeat(batchsize, 1, 1)

    # rotation about the X axis
    delta_r_x[:, 1, 1] = torch.cos(angle_r[:, 0])
    delta_r_x[:, 1, 2] = -torch.sin(angle_r[:, 0])
    delta_r_x[:, 2, 1] = torch.sin(angle_r[:, 0])
    delta_r_x[:, 2, 2] = torch.cos(angle_r[:, 0])

    # rotation about the Y axis
    delta_r_y[:, 0, 0] = torch.cos(angle_r[:, 1])
    delta_r_y[:, 0, 2] = torch.sin(angle_r[:, 1])
    delta_r_y[:, 2, 0] = -torch.sin(angle_r[:, 1])
    delta_r_y[:, 2, 2] = torch.cos(angle_r[:, 1])

    # rotation about the Z axis
    delta_r_z[:, 0, 0] = torch.cos(angle_r[:, 2])
    delta_r_z[:, 0, 1] = -torch.sin(angle_r[:, 2])
    delta_r_z[:, 1, 0] = torch.sin(angle_r[:, 2])
    delta_r_z[:, 1, 1] = torch.cos(angle_r[:, 2])

    # compose the three rotation matrices
    delta_r = torch.bmm(torch.bmm(delta_r_x, delta_r_y), delta_r_z).cuda()

    delta_s = torch.rand(batchsize, 1).cuda()   # b, 1
    delta_s = delta_s.uniform_(0.8, 1.2)   # resample values into [0.8, 1.2]

    return delta_r, delta_t, delta_s

The model predicts the pose of the perturbed points; at the loss stage the generated noise is used to undo the perturbation, which turns it into a data-augmentation scheme.
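
A hedged sketch of the corresponding ground-truth adjustment, derived from the row-vector convention used in the snippet above (my own derivation; the repository's loss code may organize this differently, and augment_gt_pose is a hypothetical helper):

import torch

def augment_gt_pose(R, t, s, delta_r, delta_t, delta_s):
    """Pose of the perturbed points, given the original GT pose (hypothetical sketch).

    R: (b, 3, 3) object-to-camera rotation, t: (b, 3) translation, s: (b, 3) size.
    """
    R_aug = delta_r.transpose(1, 2) @ R                                   # rotation after the perturbation
    t_aug = (((t - delta_t.squeeze(1)) / delta_s).unsqueeze(1) @ delta_r).squeeze(1)
    s_aug = s / delta_s                                                   # size shrinks/grows with delta_s
    return R_aug, t_aug, s_aug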

Point cloud features

The paper states that PointNet++ is used to extract point cloud features; I have not yet read the PointNet++ paper, so I will not cover it in detail for now.

IAKD

This module takes the RGB features and the point cloud features as input. The concatenated point cloud and RGB features serve as \(K\) and \(V\), a set of learnable query vectors serves as \(Q\), and cross-attention is performed, returning the updated query vectors and an attention map.

The query vectors are then matrix-multiplied with the input features to obtain heatmaps; finally the query vectors and the heatmaps are returned.

The heatmaps further compress the concatenated point cloud and RGB features.
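
A minimal sketch of this module as described above (the dimensions, number of keypoints, and module names are illustrative, not the repository's exact implementation):

import torch
import torch.nn as nn

class IAKDSketch(nn.Module):
    def __init__(self, dim=256, num_kpts=96, num_heads=4):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, num_kpts, dim))        # learnable keypoint queries
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, feat):                                             # feat: (b, n, dim) fused RGB + point features
        b = feat.shape[0]
        q = self.query.expand(b, -1, -1)                                 # (b, k, dim)
        q, attn = self.cross_attn(q, feat, feat)                         # queries attend to the fused features
        heatmap = torch.softmax(q @ feat.transpose(1, 2), dim=-1)        # (b, k, n) per-keypoint heatmaps
        kpt_feat = heatmap @ feat                                        # (b, k, dim) heatmap-weighted aggregation
        return q, heatmap, kpt_feat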

GAFA

A GAFA block is built by stacking a series of convolutions and fully connected layers; two GAFA blocks are then stacked to form the GAFA module, which finally returns the keypoint features.

Training

Both training and testing use the package from Gorilla-Lab-SCUT/gorilla-core.

Testing