SAM-Guided Masked Token Prediction for 3D Scene Understanding

Zhimin Chen, Liang Yang, Yingwei Li, Longlong Jing, Bing Li
2024-10-17
Abstract: Foundation models have significantly enhanced 2D task performance, and recent works like Bridge3D have successfully applied these models to improve 3D scene understanding through knowledge distillation, marking considerable advancements. Nonetheless, challenges such as the misalignment between 2D and 3D representations and the persistent long-tail distribution in 3D datasets still restrict the effectiveness of knowledge distillation from 2D to 3D using foundation models. To tackle these issues, we introduce a novel SAM-guided tokenization method that seamlessly aligns 3D transformer structures with region-level knowledge distillation, replacing the traditional KNN-based tokenization techniques. Additionally, we implement a group-balanced re-weighting strategy to effectively address the long-tail problem in knowledge distillation. Furthermore, inspired by the recent success of masked feature prediction, our framework incorporates a two-stage masked token prediction process in which the student model predicts both the global embeddings and the token-wise local embeddings derived from the teacher models trained in the first stage. Our methodology has been validated across multiple datasets, including SUN RGB-D, ScanNet, and S3DIS, for tasks like 3D object detection and semantic segmentation. The results demonstrate significant improvements over current state-of-the-art self-supervised methods, establishing new benchmarks in this field.
Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
This paper addresses two obstacles that arise when knowledge is transferred from 2D foundation models to 3D scene understanding:

1. **Misalignment between 2D and 3D representations**: Traditional 2D-to-3D knowledge distillation loses performance because the two representations do not line up. Specifically, when K-Nearest Neighbor (KNN) grouping is used to tokenize a point cloud, points from different regions can be grouped into the same token, which feeds conflicting information to the 3D network.
2. **Long-tail distribution of 3D datasets**: 3D datasets are usually long-tailed, that is, some categories have far more samples than others. This imbalance causes the model to rely too heavily on common categories and to perform poorly on rare ones.

To overcome these problems, the authors propose a masked token prediction method guided by SAM (Segment Anything Model), with three main components:

- **SAM-guided tokenization**: The masks generated by SAM guide how the point cloud is grouped into tokens, so that each 3D region corresponds to the features of a single 2D region. This avoids the information conflicts introduced by KNN-based tokenization.
- **Group-balanced re-weighting strategy**: The distillation loss between 2D and 3D representations is re-weighted to counter the long-tail distribution of 3D datasets and to improve the model's handling of rare categories.
- **Two-stage masked token prediction framework**: In the first stage, the knowledge of the 2D foundation model is transferred to the 3D network through dense region-level knowledge distillation. In the second stage, the student model sees only the visible part of the 3D input and predicts the global and token-wise local embeddings produced by the teacher model, thereby learning more consistent, context-aware representations.

The authors validate the method on multiple datasets, including SUN RGB-D, ScanNet, and S3DIS, and report significant improvements in 3D object detection and semantic segmentation.
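
The mask-guided grouping and the balanced re-weighting described above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It assumes each 3D point already carries the index of the 2D mask it projects into (`mask_ids`), and it uses inverse-frequency weights with a cosine distillation loss as a stand-in, since the summary does not give the paper's actual formulas. All function names and the `group_ids` labels are hypothetical.

```python
import numpy as np


def mask_guided_tokens(points, mask_ids):
    """Group points into tokens by mask index instead of by KNN.

    points:   (N, 3) array of xyz coordinates.
    mask_ids: (N,) integer array; mask_ids[i] is the (assumed, precomputed)
              index of the 2D mask that point i projects into.
    Returns a dict mapping mask index -> (M_k, 3) array of member points.
    """
    points = np.asarray(points, dtype=float)
    mask_ids = np.asarray(mask_ids)
    return {int(m): points[mask_ids == m] for m in np.unique(mask_ids)}


def balanced_weights(group_ids):
    """Per-token weights that up-weight tokens from rare groups.

    Assumption: plain inverse-frequency weighting, rescaled to mean 1.
    group_ids: (T,) hypothetical group label per token.
    """
    group_ids = np.asarray(group_ids)
    _, inverse, counts = np.unique(
        group_ids, return_inverse=True, return_counts=True
    )
    weights = 1.0 / counts[inverse]
    return weights * len(weights) / weights.sum()


def weighted_distill_loss(student, teacher, weights, eps=1e-8):
    """Weighted mean of (1 - cosine similarity) over tokens.

    student, teacher: (T, D) token embeddings; weights: (T,).
    """
    student = np.asarray(student, dtype=float)
    teacher = np.asarray(teacher, dtype=float)
    s = student / (np.linalg.norm(student, axis=1, keepdims=True) + eps)
    t = teacher / (np.linalg.norm(teacher, axis=1, keepdims=True) + eps)
    per_token = 1.0 - np.sum(s * t, axis=1)
    return float(np.mean(np.asarray(weights) * per_token))
```

With this grouping, every token contains points from exactly one mask, so a region-level teacher feature can be matched to one token without mixing regions. The weights give a token from a group seen once a larger share of the loss than a token from a group seen many times.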