MOOSS: Mask-Enhanced Temporal Contrastive Learning for Smooth State Evolution in Visual Reinforcement Learning

Jiarui Sun, M. Ugur Akcal, Wei Zhang, Girish Chowdhary
2024-09-03
Abstract: In visual Reinforcement Learning (RL), learning from pixel-based observations poses significant challenges on sample efficiency, primarily due to the complexity of extracting informative state representations from high-dimensional data. Previous methods such as contrastive-based approaches have made strides in improving sample efficiency but fall short in modeling the nuanced evolution of states. To address this, we introduce MOOSS, a novel framework that leverages a temporal contrastive objective with the help of graph-based spatial-temporal masking to explicitly model state evolution in visual RL. Specifically, we propose a self-supervised dual-component strategy that integrates (1) a graph construction of pixel-based observations for spatial-temporal masking, coupled with (2) a multi-level contrastive learning mechanism that enriches state representations by emphasizing temporal continuity and change of states. MOOSS advances the understanding of state dynamics by disrupting and learning from spatial-temporal correlations, which facilitates policy learning. Our comprehensive evaluation on multiple continuous and discrete control benchmarks shows that MOOSS outperforms previous state-of-the-art visual RL methods in terms of sample efficiency, demonstrating the effectiveness of our method. Our code is released at https://github.com/jsun57/MOOSS.
Computer Vision and Pattern Recognition, Machine Learning
### What problem does this paper attempt to solve?

This paper addresses the low sample efficiency of visual reinforcement learning (visual RL). Visual RL learns from pixel-based observations, which raises two main challenges:

1. **Complexity of high-dimensional data**: Extracting informative state representations from high-dimensional visual data (such as image sequences) is difficult. Contrastive learning methods have improved sample efficiency to some extent, but they do not fully model the subtle changes that occur as states evolve.
2. **Insufficient modeling of state evolution**: Existing methods usually treat samples only as positive or negative, ignoring the fact that states change gradually over time. Moreover, unlike video models, the observation encoder in visual RL processes only one observation at a time, which makes temporal evolution hard to capture.

To address these problems, the authors propose MOOSS (Mask-Enhanced Temporal Contrastive Learning for Smooth State Evolution), a framework that models state evolution explicitly through two components:

1. **Graph-based spatio-temporal masking**: Pixel-based observations are treated as a spatio-temporal graph, and random-walk masking on this graph generates the samples used for contrastive learning. Masking in this way disrupts contiguous blocks of information, which makes the self-supervised task harder and forces the model to learn the spatio-temporal dynamics of the observations more deeply.
2. **Multi-level temporal contrastive objective**: A multi-level temporal contrastive loss encourages the model to attend to gradual, evolving state changes across different temporal distances, rather than merely separating positive from negative samples.
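The summary does not specify how the graph is built or how the walk is run. The following is a minimal sketch under stated assumptions, not the authors' implementation: each of T frames is split into an H×W grid of patches; each patch (t, r, c) is a node; edges link 4-neighboring patches within a frame and the same patch in adjacent frames; and masked nodes are collected by restarting random walks until a target fraction is reached. The names `neighbors`, `random_walk_mask`, `mask_ratio`, and `walk_length` are hypothetical.

```python
import random
from typing import List, Optional, Set, Tuple

Node = Tuple[int, int, int]  # (time step t, patch row r, patch column c)


def neighbors(node: Node, T: int, H: int, W: int) -> List[Node]:
    """Spatial 4-neighbors in the same frame plus the same patch in adjacent frames."""
    t, r, c = node
    candidates = [
        (t, r - 1, c), (t, r + 1, c), (t, r, c - 1), (t, r, c + 1),  # spatial edges
        (t - 1, r, c), (t + 1, r, c),                                  # temporal edges
    ]
    return [(tt, rr, cc) for tt, rr, cc in candidates
            if 0 <= tt < T and 0 <= rr < H and 0 <= cc < W]


def random_walk_mask(T: int, H: int, W: int, mask_ratio: float,
                     walk_length: int, seed: Optional[int] = None) -> Set[Node]:
    """Collect masked nodes by restarting random walks on the spatio-temporal grid graph.

    Hypothetical sketch: the graph layout and restart policy are assumptions.
    """
    assert 0.0 <= mask_ratio <= 1.0 and walk_length >= 1
    rng = random.Random(seed)
    target = int(round(mask_ratio * T * H * W))  # number of patches to mask
    masked: Set[Node] = set()
    while len(masked) < target:
        # Start each walk at a uniformly random node.
        node = (rng.randrange(T), rng.randrange(H), rng.randrange(W))
        for _ in range(walk_length):
            masked.add(node)
            if len(masked) >= target:
                break
            nbrs = neighbors(node, T, H, W)
            if not nbrs:  # isolated node (e.g. a 1x1x1 grid): end this walk
                break
            node = rng.choice(nbrs)  # step along a spatial or temporal edge
    return masked


if __name__ == "__main__":
    mask = random_walk_mask(T=3, H=4, W=4, mask_ratio=0.5, walk_length=5, seed=0)
    print(f"masked {len(mask)} of {3 * 4 * 4} patches")
```

Because each walk moves along spatial and temporal edges, the masked nodes tend to form connected regions that span both space and time, unlike independently sampled patches. With the call above, the returned set holds 24 of the 48 patch coordinates.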
By combining graph-based masking with the multi-level temporal contrastive objective, MOOSS improves sample efficiency on multiple continuous and discrete control benchmarks, outperforming previous state-of-the-art visual RL methods.

#### Formula summary

- **InfoNCE loss**:
  \[
  L_q = -\mathbb{E}\left[\log \frac{\exp(\text{sim}(q, k^+)/\tau)}{\sum_{k \in K} \exp(\text{sim}(q, k)/\tau)}\right]
  \]
  where \(\text{sim}(q, k)\) is the similarity of a query-key pair, \(k^+\) is the positive key, \(K\) is the set of all keys (the positive and the negatives), and \(\tau\) is the temperature parameter.
- **Multi-level temporal contrastive loss**:
  \[
  L_l^q = -\log \frac{\sum_{k^{\Delta = l}} \exp(\text{sim}(q, k)/\tau_l)}{\sum_{k^{\Delta \geq l} \cup k'} \exp(\text{sim}(q, k)/\tau_l)}
  \]
  where \(k^{\Delta = l}\) denotes the key samples that are \(l\) steps away from the query \(q\), \(\tau_l\) is the temperature at level \(l\), and \(\tau_l < \tau_{l+1}\).

Together, these objectives are designed to help the model capture how the state changes progressively over time, which the authors report improves the sample efficiency of visual RL.
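Both losses can be sketched for a single query in plain Python. This is an illustrative sketch, not the authors' code: it works on precomputed similarity scores, reads \(\Delta\) as the absolute time-step distance between query and key, represents the extra keys \(k'\) as a plain list of similarities, and returns one loss per level without assuming how levels are combined. The function names and the `taus` mapping are hypothetical.

```python
import math
from typing import Dict, Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def info_nce(sim_pos: float, sim_all: Sequence[float], tau: float) -> float:
    """InfoNCE for one query: -log(exp(s+/tau) / sum_{k in K} exp(s_k/tau)).

    `sim_all` is the key set K and must contain the positive similarity too.
    """
    return -(sim_pos / tau - logsumexp([s / tau for s in sim_all]))


def multi_level_temporal_loss(
    sims: Sequence[float],        # sim(q, k) for keys from the query's sequence
    deltas: Sequence[int],        # assumed: temporal distance |t_k - t_q| of each key
    extra_sims: Sequence[float],  # sim(q, k') for the extra keys k' in the denominator
    taus: Dict[int, float],       # hypothetical config: level l -> temperature tau_l
) -> Dict[int, float]:
    """Per-level loss L_l^q; levels with no key at distance l are skipped."""
    levels = sorted(taus)
    assert all(taus[a] < taus[b] for a, b in zip(levels, levels[1:])), "need tau_l < tau_{l+1}"
    losses: Dict[int, float] = {}
    for l in levels:
        tau = taus[l]
        # Numerator: keys exactly l steps away from the query.
        pos = [s / tau for s, d in zip(sims, deltas) if d == l]
        if not pos:
            continue
        # Denominator: keys at least l steps away, plus the extra keys k'.
        denom = [s / tau for s, d in zip(sims, deltas) if d >= l]
        denom += [s / tau for s in extra_sims]
        losses[l] = -(logsumexp(pos) - logsumexp(denom))
    return losses
```

At level \(l\) the denominator keeps only keys with \(\Delta \geq l\) (plus \(k'\)): keys closer than \(l\) steps are excluded rather than pushed away, while keys farther than \(l\) steps act as negatives. This lets the objective rank keys by temporal distance instead of splitting them into one positive/negative pair. With a single level, one key at that distance, and the remaining keys passed as `extra_sims`, the level loss reduces to the per-query InfoNCE term.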