# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> tracker
@IDE     PyCharm
@Author  fx221@cam.ac.uk
@Date    29/02/2024 16:58
=================================================='''
import time
import cv2
import numpy as np
import torch
import pycolmap

from localization.frame import Frame
from localization.base_model import dynamic_load
import localization.matchers as matchers
from localization.match_features_batch import confs as matcher_confs
from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches
from tools.common import resize_img


class Tracker:
    def __init__(self, locMap, matcher, config):
        self.locMap = locMap
        self.matcher = matcher
        self.config = config
        self.loc_config = config['localization']

        self.lost = True
        self.curr_frame = None
        self.last_frame = None

        # Mutual nearest-neighbor matcher; only used by the alternative matching path
        # that is currently disabled (see the string-literal blocks further below).
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        Model = dynamic_load(matchers, 'nearest_neighbor')
        self.nn_matcher = Model(matcher_confs['NNM']['model']).eval().to(device)
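
    # Minimal usage sketch (hedged): the surrounding system is assumed to provide a
    # localization map, a learned matcher, and a config dict with a 'localization' section.
    # The names `loc_map`, `learned_matcher`, `prev_frame`, and `new_frame` below are
    # hypothetical placeholders, not part of this file:
    #
    #   tracker = Tracker(locMap=loc_map, matcher=learned_matcher, config=config)
    #   tracker.last_frame = prev_frame      # tracking requires an already-localized frame
    #   ok = tracker.run(new_frame)          # True on success, False when tracking is lost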

    def run(self, frame: Frame):
        print('Start tracking...')
        show = self.config['localization']['show']
        self.curr_frame = frame

        ref_img = self.last_frame.image
        curr_img = self.curr_frame.image
        q_kpts = frame.keypoints

        t_start = time.time()
        ret = self.track_last_frame(curr_frame=self.curr_frame, last_frame=self.last_frame)
        self.curr_frame.time_loc = self.curr_frame.time_loc + time.time() - t_start

        if show:
            curr_matched_kpts = ret['matched_keypoints']
            ref_matched_kpts = ret['matched_ref_keypoints']
            img_loc_matching = plot_matches(img1=curr_img, img2=ref_img,
                                            pts1=curr_matched_kpts,
                                            pts2=ref_matched_kpts,
                                            inliers=np.array([True for i in range(curr_matched_kpts.shape[0])]),
                                            radius=9, line_thickness=3)
            self.curr_frame.image_matching = img_loc_matching
            q_ref_img_matching = resize_img(img_loc_matching, nh=512)

        if not ret['success']:
            # Visualization only when enabled; the failure exit itself must not depend on `show`.
            if show:
                show_text = 'Tracking FAILED!'
                img_inlier = vis_inlier(img=curr_img, kpts=curr_matched_kpts,
                                        inliers=[False for i in range(curr_matched_kpts.shape[0])], radius=9 + 2,
                                        thickness=2)
                q_img_inlier = cv2.putText(img=img_inlier, text=show_text, org=(30, 30),
                                           fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
                                           thickness=2, lineType=cv2.LINE_AA)
                q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
                cv2.imshow('loc', q_img_loc)
                key = cv2.waitKey(self.loc_config['show_time'])
                if key == ord('q'):
                    cv2.destroyAllWindows()
                    exit(0)
            return False

        ret['matched_scene_name'] = self.last_frame.scene_name
        success = self.verify_and_update(q_frame=self.curr_frame, ret=ret)
        if not success:
            return False

        if ret['num_inliers'] < 256:
            # refinement is necessary for tracking last frame
            t_start = time.time()
            ret = self.locMap.sub_maps[self.last_frame.matched_scene_name].refine_pose(
                self.curr_frame, refinement_method=self.loc_config['refinement_method'])
            self.curr_frame.time_ref = self.curr_frame.time_ref + time.time() - t_start

            ret['matched_scene_name'] = self.last_frame.scene_name
            success = self.verify_and_update(q_frame=self.curr_frame, ret=ret)

        if show:
            q_err, t_err = self.curr_frame.compute_pose_error()
            num_matches = ret['matched_keypoints'].shape[0]
            num_inliers = ret['num_inliers']
            show_text = 'Tracking, k/m/i: {:d}/{:d}/{:d}'.format(q_kpts.shape[0], num_matches, num_inliers)
            q_img_inlier = vis_inlier(img=curr_img, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
                                      radius=9 + 2, thickness=2)
            q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
                                       fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
                                       thickness=2, lineType=cv2.LINE_AA)
            show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err)
            q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80),
                                       fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
                                       thickness=2, lineType=cv2.LINE_AA)
            self.curr_frame.image_inlier = q_img_inlier
            q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
            cv2.imshow('loc', q_img_loc)
            key = cv2.waitKey(self.loc_config['show_time'])
            if key == ord('q'):
                cv2.destroyAllWindows()
                exit(0)

        # A successful track means the tracker is no longer lost.
        self.lost = not success
        return success
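
    # The `ret` dict consumed by verify_and_update / update_current_frame is produced by
    # track_last_frame above and, as used in this file, carries:
    #   'success', 'qvec' (w, x, y, z), 'tvec', 'num_inliers', 'inliers',
    #   'matched_keypoints', 'matched_keypoint_ids', 'matched_ref_keypoints',
    #   'matched_xyzs', 'matched_point3D_ids', 'matched_sids',
    #   'reference_frame_id', 'matched_scene_name'.
    # refine_pose (defined elsewhere in the sub-map code) is assumed to return the same keys.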

    def verify_and_update(self, q_frame: Frame, ret: dict):
        num_matches = ret['matched_keypoints'].shape[0]
        num_inliers = ret['num_inliers']
        q_frame.qvec = ret['qvec']
        q_frame.tvec = ret['tvec']
        q_err, t_err = q_frame.compute_pose_error()

        if num_inliers < self.loc_config['min_inliers']:
            print_text = 'Failed due to insufficient inliers ({:d}), q_err: {:.2f}, t_err: {:.2f}'.format(
                ret['num_inliers'], q_err, t_err)
            print(print_text)
            q_frame.tracking_status = False
            q_frame.clear_localization_track()
            return False
        else:
            print_text = 'Succeeded! Found {}/{} 2D-3D inliers, q_err: {:.2f}, t_err: {:.2f}'.format(
                num_inliers, num_matches, q_err, t_err)
            print(print_text)
            q_frame.tracking_status = True
            self.update_current_frame(curr_frame=q_frame, ret=ret)
            return True

    def update_current_frame(self, curr_frame: Frame, ret: dict):
        curr_frame.qvec = ret['qvec']
        curr_frame.tvec = ret['tvec']
        curr_frame.matched_scene_name = ret['matched_scene_name']
        curr_frame.reference_frame_id = ret['reference_frame_id']

        inliers = np.array(ret['inliers'])
        curr_frame.matched_keypoints = ret['matched_keypoints'][inliers]
        curr_frame.matched_xyzs = ret['matched_xyzs'][inliers]
        curr_frame.matched_point3D_ids = ret['matched_point3D_ids'][inliers]
        curr_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inliers]
        curr_frame.matched_sids = ret['matched_sids'][inliers]
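
    # From its use below, Frame.keypoints is assumed to be an (N, 3) array of
    # (x, y, score) rows; descriptors, xyzs, point3D_ids and seg_ids are assumed to be
    # row-aligned with it, with point3D_id < 0 marking keypoints that have no 3D point.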

    def track_last_frame(self, curr_frame: Frame, last_frame: Frame):
        curr_kpts = curr_frame.keypoints[:, :2]
        curr_scores = curr_frame.keypoints[:, 2]
        curr_descs = curr_frame.descriptors
        curr_kpt_ids = np.arange(curr_kpts.shape[0])

        last_kpts = last_frame.keypoints[:, :2]
        last_scores = last_frame.keypoints[:, 2]
        last_descs = last_frame.descriptors
        last_xyzs = last_frame.xyzs
        last_point3D_ids = last_frame.point3D_ids
        last_sids = last_frame.seg_ids

        # '''
        indices = self.matcher({
            'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(),
            'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(),
            'scores0': torch.from_numpy(curr_scores)[None].cuda().float(),
            'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height),

            'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(),
            'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(),
            'scores1': torch.from_numpy(last_scores)[None].cuda().float(),
            'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height),
        })['matches0'][0].cpu().numpy()

        # Alternative: plain mutual nearest-neighbor matching (kept disabled via the string literal below).
        '''
        indices = self.nn_matcher({
            'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None],
            'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None],
        })['matches0'][0].cpu().numpy()
        '''

        # Keep only matches whose last-frame keypoint carries a valid 3D point.
        valid = (indices >= 0)
        matched_point3D_ids = last_point3D_ids[indices[valid]]
        point3D_mask = (matched_point3D_ids >= 0)
        matched_point3D_ids = matched_point3D_ids[point3D_mask]
        matched_sids = last_sids[indices[valid]][point3D_mask]
        matched_kpts = curr_kpts[valid][point3D_mask]
        matched_kpt_ids = curr_kpt_ids[valid][point3D_mask]
        matched_xyzs = last_xyzs[indices[valid]][point3D_mask]
        matched_last_kpts = last_kpts[indices[valid]][point3D_mask]

        print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0],
                                                                  last_kpts.shape[0]))
        # print('tracking: ', matched_kpts.shape, matched_xyzs.shape)

        # PnP + RANSAC on the 2D-3D correspondences.
        ret = pycolmap.absolute_pose_estimation(
            matched_kpts + 0.5, matched_xyzs, curr_frame.camera,
            estimation_options={'ransac': {'max_error': self.config['localization']['threshold']}},
            refinement_options={},
        )
        if ret is None:
            ret = {'success': False}
        else:
            ret['success'] = True
            ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
            ret['tvec'] = ret['cam_from_world'].translation

        # Attach the correspondences regardless of success so the caller can visualize failures.
        ret['matched_keypoints'] = matched_kpts
        ret['matched_keypoint_ids'] = matched_kpt_ids
        ret['matched_ref_keypoints'] = matched_last_kpts
        ret['matched_xyzs'] = matched_xyzs
        ret['matched_point3D_ids'] = matched_point3D_ids
        ret['matched_sids'] = matched_sids
        ret['reference_frame_id'] = last_frame.reference_frame_id
        ret['matched_scene_name'] = last_frame.matched_scene_name
        return ret
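
    # Notes on the pose-estimation call above (hedged, based on pycolmap/COLMAP's documented
    # conventions rather than on this repository): COLMAP measures keypoints at pixel centres,
    # hence the `+ 0.5` offset, and Rigid3d.rotation.quat is stored as (x, y, z, w), so the
    # [[3, 0, 1, 2]] reordering yields COLMAP's (w, x, y, z) qvec.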

    def track_last_frame_fast(self, curr_frame: Frame, last_frame: Frame):
        curr_kpts = curr_frame.keypoints[:, :2]
        curr_scores = curr_frame.keypoints[:, 2]
        curr_descs = curr_frame.descriptors
        curr_kpt_ids = np.arange(curr_kpts.shape[0])

        # Keep only last-frame keypoints that already have a valid 3D point.
        last_point3D_ids = last_frame.point3D_ids
        point3D_mask = (last_point3D_ids >= 0)
        last_kpts = last_frame.keypoints[:, :2][point3D_mask]
        last_scores = last_frame.keypoints[:, 2][point3D_mask]
        last_descs = last_frame.descriptors[point3D_mask]
        last_xyzs = last_frame.xyzs[point3D_mask]
        last_sids = last_frame.seg_ids[point3D_mask]
        last_point3D_ids = last_point3D_ids[point3D_mask]  # keep ids aligned with the masked arrays above

        # Discard current keypoints outside the bounding box of the retained last-frame keypoints.
        minx = np.min(last_kpts[:, 0])
        maxx = np.max(last_kpts[:, 0])
        miny = np.min(last_kpts[:, 1])
        maxy = np.max(last_kpts[:, 1])
        curr_mask = (curr_kpts[:, 0] >= minx) * (curr_kpts[:, 0] <= maxx) * (curr_kpts[:, 1] >= miny) * (
                curr_kpts[:, 1] <= maxy)
        curr_kpts = curr_kpts[curr_mask]
        curr_scores = curr_scores[curr_mask]
        curr_descs = curr_descs[curr_mask]
        curr_kpt_ids = curr_kpt_ids[curr_mask]

        # '''
        indices = self.matcher({
            'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(),
            'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(),
            'scores0': torch.from_numpy(curr_scores)[None].cuda().float(),
            'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height),

            'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(),
            'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(),
            'scores1': torch.from_numpy(last_scores)[None].cuda().float(),
            'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height),
        })['matches0'][0].cpu().numpy()

        # Alternative: plain mutual nearest-neighbor matching (kept disabled via the string literal below).
        '''
        indices = self.nn_matcher({
            'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None],
            'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None],
        })['matches0'][0].cpu().numpy()
        '''

        valid = (indices >= 0)
        matched_point3D_ids = last_point3D_ids[indices[valid]]
        matched_sids = last_sids[indices[valid]]
        matched_kpts = curr_kpts[valid]
        matched_kpt_ids = curr_kpt_ids[valid]
        matched_xyzs = last_xyzs[indices[valid]]
        matched_last_kpts = last_kpts[indices[valid]]

        print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0],
                                                                  last_kpts.shape[0]))
        # print('tracking: ', matched_kpts.shape, matched_xyzs.shape)

        # Same PnP + RANSAC call as in track_last_frame (the legacy camera._asdict()/max_error_px
        # signature is replaced here so both paths use the same pycolmap API).
        ret = pycolmap.absolute_pose_estimation(
            matched_kpts + 0.5, matched_xyzs, curr_frame.camera,
            estimation_options={'ransac': {'max_error': self.config['localization']['threshold']}},
            refinement_options={},
        )
        if ret is None:
            ret = {'success': False}
        else:
            ret['success'] = True
            ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
            ret['tvec'] = ret['cam_from_world'].translation

        ret['matched_keypoints'] = matched_kpts
        ret['matched_keypoint_ids'] = matched_kpt_ids
        ret['matched_ref_keypoints'] = matched_last_kpts
        ret['matched_xyzs'] = matched_xyzs
        ret['matched_point3D_ids'] = matched_point3D_ids
        ret['matched_sids'] = matched_sids
        ret['reference_frame_id'] = last_frame.reference_frame_id
        ret['matched_scene_name'] = last_frame.matched_scene_name
        return ret
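
    # Design note: track_last_frame_fast differs from track_last_frame by pre-filtering the
    # last frame to keypoints with valid 3D points and by discarding current keypoints that
    # fall outside the bounding box of those points, so the matcher handles fewer keypoints.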

    def match_frame(self, frame: Frame, reference_frame: Frame):
        print('match: ', frame.keypoints.shape, reference_frame.keypoints.shape)
        matches = self.matcher({
            'descriptors0': torch.from_numpy(frame.descriptors)[None].cuda().float(),
            'keypoints0': torch.from_numpy(frame.keypoints[:, :2])[None].cuda().float(),
            'scores0': torch.from_numpy(frame.keypoints[:, 2])[None].cuda().float(),
            'image_shape0': (1, 3, frame.image_size[0], frame.image_size[1]),

            # 'descriptors0': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(),
            # 'keypoints0': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(),
            # 'scores0': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(),
            # 'image_shape0': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]),

            'descriptors1': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(),
            'keypoints1': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(),
            'scores1': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(),
            'image_shape1': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]),
        })['matches0'][0].cpu().numpy()

        # Keep matches whose reference keypoint has an associated 3D point.
        ids1 = np.arange(matches.shape[0])
        ids2 = matches
        ids1 = ids1[matches >= 0]
        ids2 = ids2[matches >= 0]
        mask_p3ds = reference_frame.points3d_mask[ids2]
        ids1 = ids1[mask_p3ds]
        ids2 = ids2[mask_p3ds]
        return ids1, ids2
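
    # Hypothetical usage of match_frame (names are placeholders, not part of this file):
    #   ids1, ids2 = tracker.match_frame(frame=query_frame, reference_frame=ref_frame)
    # ids1 indexes query_frame.keypoints and ids2 indexes reference_frame.keypoints,
    # restricted to matches whose reference keypoint has a valid 3D point.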