Generalizable Visual Reinforcement Learning with Segment Anything Model

Ziyu Wang1*   Yanjie Ze2*   Yifei Sun2,3   Zhecheng Yuan1,2,4   Huazhe Xu1,2,4
1Tsinghua University, IIIS   2Shanghai Qi Zhi Institute   3Tongji University   4Shanghai AI Lab
*Equal contribution


Abstract

Learning policies that can generalize to unseen environments is a fundamental challenge in visual reinforcement learning (RL). While most current methods focus on acquiring robust visual representations through auxiliary supervision, pre-training, or data augmentation, the potential of modern vision foundation models remains underleveraged. In this work, we introduce Segment Anything Model for Generalizable visual RL (SAM-G), a novel framework that leverages the promptable segmentation ability of Segment Anything Model (SAM) to enhance the generalization capabilities of visual RL agents. We utilize image features from DINOv2 and SAM to find correspondences that serve as point prompts to SAM, which then produces high-quality masked images for agents directly. Evaluated across 8 DMControl tasks and 3 Adroit tasks, SAM-G significantly improves visual generalization without altering the RL agents' architecture, only their observations. Notably, SAM-G achieves 44% and 29% relative improvements in the challenging video hard setting on DMControl and Adroit respectively, compared to state-of-the-art methods.

Visual generalization results across 2 domains and 4 settings. SAM-G robustly improves the visual generalization ability of visual RL agents such as DrQ-v2 and PIE-G. Notably, in the challenging video hard setting, SAM-G surpasses the previous state-of-the-art method PIE-G with 44% and 29% relative improvements on DMControl and Adroit, respectively.

Generalization

SAM-G segments objects of interest and generalizes to unseen environments. We outline and colorize the masks predicted by SAM.


Tasks shown: Walker Walk, Finger Spin, Ball in Cup, Cheetah Run, Cartpole Swingup, Hopper Stand, Walker Stand, Adroit Door, Adroit Hammer, Adroit Pen.
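
The colorized, outlined masks shown above can be reproduced with a few lines of OpenCV. The snippet below is a rough sketch rather than the project's plotting code; overlay_mask, image, and mask are placeholder names, and the mask is assumed to be a binary array such as SAM outputs.

import cv2
import numpy as np

def overlay_mask(image, mask, color=(255, 140, 0), alpha=0.5):
    # image: HxWx3 uint8 frame; mask: HxW binary mask (e.g. predicted by SAM).
    mask = mask.astype(np.uint8)
    out = image.astype(np.float32)
    # Colorize the masked region by alpha-blending it with the chosen color.
    out[mask > 0] = (1 - alpha) * out[mask > 0] + alpha * np.array(color, np.float32)
    out = out.astype(np.uint8)
    # Outline the mask boundary.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(out, contours, -1, color, thickness=2)
    return out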

Visualization of Point Prompts

We visualize the point prompts found by correspondence. Positive points are shown as green stars and negative points as red crosses.


Tasks shown: Walker Walk, Finger Spin, Ball in Cup, Cheetah Run, Cartpole Swingup, Hopper Stand, Walker Stand, Adroit Door, Adroit Hammer, Adroit Pen.
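
A minimal matplotlib sketch for this kind of visualization (not the project's plotting code; show_point_prompts, pos_points, and neg_points are placeholder names):

import numpy as np
import matplotlib.pyplot as plt

def show_point_prompts(image, pos_points, neg_points):
    # pos_points / neg_points: arrays of (x, y) pixel coordinates.
    pos_points, neg_points = np.asarray(pos_points), np.asarray(neg_points)
    plt.imshow(image)
    if pos_points.size:
        plt.scatter(pos_points[:, 0], pos_points[:, 1],
                    marker='*', c='green', s=120, label='positive')
    if neg_points.size:
        plt.scatter(neg_points[:, 0], neg_points[:, 1],
                    marker='x', c='red', s=80, label='negative')
    plt.legend()
    plt.axis('off')
    plt.show()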

Method

We use point features (extracted by vision foundation models) from the training environment to find correspondences in the test image, which yield point prompts. The mask decoder then iteratively refines the predicted mask given these point prompts.
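
Below is a minimal sketch of this prompting pipeline, written against the public segment_anything package (our implementation additionally replaces SAM's image encoder with EfficientViT). Feature extraction is abstracted away, and names such as ref_feats, dense_feats, and test_image are placeholders rather than identifiers from the released code.

import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry

def point_prompts_from_correspondence(ref_feats, dense_feats):
    # ref_feats: (K, C) features of annotated points from a training frame.
    # dense_feats: (C, H, W) feature map of the current test frame.
    C, H, W = dense_feats.shape
    flat = dense_feats.reshape(C, -1)
    flat = flat / flat.norm(dim=0, keepdim=True)
    ref = ref_feats / ref_feats.norm(dim=1, keepdim=True)
    sim = ref @ flat                              # (K, H*W) cosine similarity
    idx = sim.argmax(dim=1)                       # best-matching location per reference point
    xs, ys = idx % W, idx // W                    # feature-grid units; rescale to pixels if needed
    coords = torch.stack([xs, ys], dim=1)
    labels = np.ones(len(coords), dtype=int)      # 1 = positive (foreground) prompt;
    return coords.cpu().numpy(), labels           # negatives can be picked from low-similarity areas

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")   # any SAM checkpoint
predictor = SamPredictor(sam)
predictor.set_image(test_image)                   # HxWx3 uint8 test observation
coords, labels = point_prompts_from_correspondence(ref_feats, dense_feats)
masks, scores, logits = predictor.predict(
    point_coords=coords, point_labels=labels, multimask_output=False)
# Iterative refinement: feed the previous low-resolution mask logits back to the decoder.
masks, scores, logits = predictor.predict(
    point_coords=coords, point_labels=labels,
    mask_input=logits, multimask_output=False)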

Wall Time

The original ViT image encoder from SAM imposes a large computational overhead when encoding images. To address this issue, we incorporate the EfficientViT architecture. For reference, we also measure several visual RL algorithms, including DrQ-v2, SVEA, and PIE-G, and we additionally swap EfficientViT for the original SAM, EfficientSAM, and MobileSAM. The wall-time results are shown below.
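
As a rough illustration of how such encoder wall times can be measured (this is not the paper's benchmark script; encoder stands in for SAM's ViT, EfficientSAM, MobileSAM, or EfficientViT):

import time
import torch

@torch.no_grad()
def encoder_wall_time(encoder, image_size=1024, n_iters=50, device="cuda"):
    # Average forward-pass wall time of an image encoder on GPU, in seconds.
    encoder = encoder.to(device).eval()
    x = torch.randn(1, 3, image_size, image_size, device=device)
    for _ in range(5):                  # warm-up iterations
        encoder(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        encoder(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters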


Citation

If you use our method or code in your research, please consider citing the paper as follows:

@article{Wang2023SAMG,
  title={Generalizable Visual Reinforcement Learning with Segment Anything Model},
  author={Ziyu Wang and Yanjie Ze and Yifei Sun and Zhecheng Yuan and Huazhe Xu},
  journal={arXiv},
  year={2023}
}