Gated Domain Units for Multi-source Domain Generalization

Simon Föll, Alina Dubatovka, Eugen Ernst, Siu Lun Chau, Martin Maritsch, Patrik Okanovic, Gudrun Thäter, Joachim M. Buhmann, Felix Wortmann, Krikamol Muandet
DOI: https://doi.org/10.48550/arXiv.2206.12444
2023-05-16
Abstract: The phenomenon of distribution shift (DS) occurs when a dataset at test time differs from the dataset at training time, which can significantly impair the performance of a machine learning model in practical settings due to a lack of knowledge about the data's distribution at test time. To address this problem, we postulate that real-world distributions are composed of latent Invariant Elementary Distributions (I.E.D) across different domains. This assumption implies an invariant structure in the solution space that enables knowledge transfer to unseen domains. To exploit this property for domain generalization, we introduce a modular neural network layer consisting of Gated Domain Units (GDUs) that learn a representation for each latent elementary distribution. During inference, a weighted ensemble of learning machines can be created by comparing new observations with the representations of each elementary distribution. Our flexible framework also accommodates scenarios where explicit domain information is not present. Extensive experiments on image, text, and graph data show consistent performance improvement on out-of-training target domains. These findings support the practicality of the I.E.D assumption and the effectiveness of GDUs for domain generalisation.
Machine Learning
### What problem does this paper attempt to solve?

This paper addresses the **Distribution Shift (DS)** problem that machine learning models encounter in practical applications. When the distribution of the test data differs from that of the training data, model performance can decline significantly. This is common in real-world use, because a deployed model may face data distributions it has never seen.

To address this, the authors propose a new hypothesis: **Invariant Elementary Distributions (I.E.D.)**. They postulate that real-world distributions are composed of latent elementary distributions that remain invariant across different domains. This hypothesis implies an invariant structure in the solution space, which enables knowledge to be transferred from known domains to unseen ones.

Based on this hypothesis, the authors introduce a new neural network layer consisting of **Gated Domain Units (GDUs)**. The GDUs learn a representation of each latent elementary distribution. During inference, new observations are compared with these representations, and the resulting similarities weight an ensemble of learning machines. The framework applies both when explicit domain information is available and when it is not.

### Main contributions

1. **Proposing the I.E.D. hypothesis**: Hypothesize that both the test and source domains are composed of latent elementary distributions, and provide empirical support for the practicality of this hypothesis in domain generalization.
2. **Developing the GDU neural network layer**: Design a modular neural network layer that learns a geometric representation of each elementary distribution and adjusts the ensemble weights during inference according to the similarity between new observations and these distributions.
3. **Experimental verification**: Verify the effectiveness of the method through extensive experiments on the public WILDS benchmark, showing performance improvements on image, text, and graph data.
4. **Implementation and applicability**: Provide implementations in TensorFlow and PyTorch that work with different feature extractors (such as ResNet50, DistilBERT, and GIN-virtual), facilitating rapid adoption by researchers and practitioners.

### Key technologies in the solution

- **Kernel Mean Embedding (KME)**: Used to represent the invariant elementary distributions.
- **Similarity functions**: Include geometry-based similarities (such as cosine similarity and maximum mean discrepancy, MMD) and projection-based similarities.
- **Regularization terms**: Ensure that the model can distinguish different domains and minimize the MMD between the feature mappings and the elementary distributions.

With these components, the authors provide an approach to the distribution shift problem that improves the generalization ability of machine learning models in unseen domains.
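
To make the key technologies above more concrete, the following is a minimal NumPy sketch of how a kernel mean embedding, an MMD estimate, and similarity-based ensemble weights could fit together. It is an illustration under stated assumptions, not the authors' implementation: the Gaussian RBF kernel, the biased MMD estimator, the sum-to-one normalisation of the weights, and every function name (`rbf_kernel`, `mmd_squared`, `cosine_similarity_to_kme`, `gate_weights`, `ensemble_predict`) are hypothetical choices made for this example. In the paper, the elementary distributions are learned inside a network layer; here each one is simply represented by a fixed array of sample points.

```python
import numpy as np


def rbf_kernel(x, y, sigma=1.0):
    """Gaussian RBF kernel matrix between rows of x (n, d) and y (m, d)."""
    sq_dists = (
        np.sum(x**2, axis=1)[:, None]
        + np.sum(y**2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )
    # Clip tiny negative values caused by floating-point round-off.
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * sigma**2))


def mmd_squared(x, y, sigma=1.0):
    """Biased empirical estimate of the squared MMD between samples x and y."""
    return (
        rbf_kernel(x, x, sigma).mean()
        + rbf_kernel(y, y, sigma).mean()
        - 2.0 * rbf_kernel(x, y, sigma).mean()
    )


def cosine_similarity_to_kme(x, basis, sigma=1.0):
    """Cosine similarity between phi(x) and the empirical KME of `basis`.

    <phi(x), mu_V> is the mean of k(x, v_j) over the basis points, and
    ||phi(x)|| equals 1 for the RBF kernel because k(x, x) = 1.
    """
    inner = rbf_kernel(x, basis, sigma).mean(axis=1)
    norm_mu = np.sqrt(rbf_kernel(basis, basis, sigma).mean())
    return inner / norm_mu


def gate_weights(x, bases, sigma=1.0):
    """Per-observation weights over the bases, shape (n, K).

    Assumption for this sketch: similarities are normalised to sum to one.
    """
    sims = np.stack(
        [cosine_similarity_to_kme(x, v, sigma) for v in bases], axis=1
    )
    return sims / sims.sum(axis=1, keepdims=True)


def ensemble_predict(x, bases, heads, sigma=1.0):
    """Weighted ensemble: each head's output is scaled by its gate weight."""
    weights = gate_weights(x, bases, sigma)           # (n, K)
    preds = np.stack([h(x) for h in heads], axis=1)   # (n, K)
    return (weights * preds).sum(axis=1)
```

As a usage note, calling `gate_weights(x, [basis_a, basis_b])` on observations drawn near `basis_a` should put most of the weight on the first column, so `ensemble_predict` is dominated by the first head. An observation lying between the two bases receives a mixture of both heads.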