Fine-Tuning Pre-trained Language Models for Robust Causal Representation Learning

Jialin Yu, Yuxiang Zhou, Yulan He, Nevin L. Zhang, Ricardo Silva
2024-10-18
Abstract: The fine-tuning of pre-trained language models (PLMs) has been shown to be effective across various domains. By using domain-specific supervised data, the general-purpose representation derived from PLMs can be transformed into a domain-specific representation. However, these methods often fail to generalize to out-of-domain (OOD) data due to their reliance on non-causal representations, often described as spurious features. Existing methods either make use of adjustments with strong assumptions about lack of hidden common causes, or mitigate the effect of spurious features using multi-domain data. In this work, we investigate how fine-tuned pre-trained language models aid generalizability from single-domain scenarios under mild assumptions, targeting more general and practical real-world scenarios. We show that a robust representation can be derived through a so-called causal front-door adjustment, based on a decomposition assumption, using fine-tuned representations as a source of data augmentation. Comprehensive experiments in both synthetic and real-world settings demonstrate the superior generalizability of the proposed method compared to existing approaches. Our work thus sheds light on the domain generalization problem by introducing links between fine-tuning and causal mechanisms into representation learning.
Machine Learning, Computation and Language
What problem does this paper attempt to address?
### The problems the paper attempts to solve

This paper addresses the poor generalization of pre-trained language models (PLMs) that are fine-tuned on a specific domain and then evaluated on out-of-distribution (OOD) data. Existing methods usually rely on non-causal representations (i.e., spurious features) that are present in the training data but may be absent from the test data, which degrades model performance on OOD data. The paper proposes a method that improves OOD generalization in a single-domain scenario by using the fine-tuned pre-trained language model to learn robust causal representations.

### Background and motivation

1. **Fine-tuning of pre-trained language models**:
   - Pre-trained language models (such as BERT) are trained on large-scale corpora to generate context-aware representations and are widely used in natural language understanding tasks.
   - Fine-tuning these models for a specific task can improve performance, but the resulting models usually rely on non-causal representations that exploit spurious correlations in the training data, so they generalize poorly to OOD data.
2. **Limitations of existing methods**:
   - Existing methods either assume that there are no hidden common causes or use multi-domain data to mitigate the impact of spurious features.
   - Multi-domain data is often difficult to obtain in natural language processing, and data augmentation is more complex for language than for images.

### Research questions

The paper poses the following research question:

- How can pre-trained language models (PLMs) be used to learn robust causal representations that enhance OOD generalization?

### Method overview

1. **Causal perspective analysis**:
   - The standard supervised fine-tuning estimator \( p(y|x) \) can fail in OOD scenarios because the training and test distributions differ.
   - The causal estimator \( p(y|\text{do}(x)) \) addresses this problem, where \( \text{do}(x) \) denotes an intervention that fixes \( X = x \).
2. **Key assumptions**:
   - **Decomposition assumption**: Each input text \( X \) can be decomposed into a causal latent variable \( C \) and a spurious latent variable \( S \), where \( C \) is the sole causal parent of the label \( Y \).
   - **Paired representation assumption**: For each input text \( X \), two representations \( R_0 \) and \( R_1 \) can be obtained from two different environments, where the causal factor \( C \) remains unchanged while the spurious factor \( S \) varies.
   - **Local feature assumption**: Word-level features \( \Phi \) for each input text can be obtained for free from the fine-tuned model, and these features can be used to predict the label \( Y \).
   - **Sufficient mediation assumption**: The causal influence of the local features \( \Phi \) on the label \( Y \) is transmitted only through a set of variables in \( C \).
3. **Identifying causal effects**:
   - Under the above assumptions, causal features can be identified by learning mapping functions from \( R_0 \) and \( R_1 \) to \( C \).
   - The causal estimator \( p(y|\text{do}(x)) \) is then constructed through the front-door adjustment.

### Experimental verification

1. **Experimental setup**:
   - Experiments are carried out on semi-synthetic and real-world datasets.
   - Multiple baseline methods are compared, including standard fine-tuning (SFT), weight-averaging (WSA), and parameter interpolation (WISE).
2. **Experimental results**:
   - On the semi-synthetic dataset, the proposed CTL method significantly outperforms the baseline methods on OOD data.
   - The CTL method is particularly robust under large distribution shifts.

### Conclusion

By introducing a causal mechanism and using pre-trained language models to learn robust causal representations in a single-domain scenario, the paper improves the generalization ability of the model on OOD data.
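### Illustrative sketch: front-door adjustment on a toy model

The front-door adjustment named in the method overview can be illustrated with the textbook discrete formula \( p(y|\text{do}(x)) = \sum_m p(m|x) \sum_{x'} p(y|m,x')\,p(x') \). The sketch below is not the paper's estimator, which works with learned representations; it is a minimal toy example under stated assumptions. The binary variables `U` (hidden confounder) and `M` (mediator), and every probability table, are hypothetical values chosen only for illustration. It shows that the front-door estimate, computed from the observed joint over \( (X, M, Y) \) alone, recovers the interventional distribution, while the plain conditional \( p(y|x) \) does not.

```python
import numpy as np

# Hypothetical toy model (all numbers are made up for illustration):
# U -> X and U -> Y (hidden confounder), X -> M -> Y (mediated effect).
p_u = np.array([0.6, 0.4])                            # p(u)
p_x_given_u = np.array([[0.8, 0.2], [0.3, 0.7]])      # indexed [u, x]
p_m_given_x = np.array([[0.9, 0.1], [0.2, 0.8]])      # indexed [x, m]
p_y_given_mu = np.array([[[0.7, 0.3], [0.4, 0.6]],
                         [[0.5, 0.5], [0.1, 0.9]]])   # indexed [m, u, y]

# Full joint p(u, x, m, y), then marginalize out the unobserved U.
joint = np.einsum("u,ux,xm,muy->uxmy",
                  p_u, p_x_given_u, p_m_given_x, p_y_given_mu)
p_xmy = joint.sum(axis=0)                             # observed joint [x, m, y]


def front_door(p_xmy):
    """Front-door estimate of p(y | do(x)) from the observed joint only."""
    p_x = p_xmy.sum(axis=(1, 2))                      # p(x)
    p_xm = p_xmy.sum(axis=2)                          # p(x, m)
    p_m_x = p_xm / p_x[:, None]                       # p(m | x)
    p_y_xm = p_xmy / p_xm[:, :, None]                 # p(y | x, m)
    # Inner sum over x': sum_x' p(x') p(y | m, x'), indexed [m, y].
    inner = np.einsum("k,kmy->my", p_x, p_y_xm)
    # Outer sum over m: sum_m p(m | x) * inner, indexed [x, y].
    return np.einsum("xm,my->xy", p_m_x, inner)


def conditional(p_xmy):
    """Plain observational conditional p(y | x)."""
    p_xy = p_xmy.sum(axis=1)
    return p_xy / p_xy.sum(axis=1, keepdims=True)


# Ground truth p(y | do(x)), computed with access to the hidden U.
truth = np.einsum("xm,u,muy->xy", p_m_given_x, p_u, p_y_given_mu)

fd = front_door(p_xmy)
cond = conditional(p_xmy)
print("front-door matches truth:", np.allclose(fd, truth))
print("conditional matches truth:", np.allclose(cond, truth))
```

The comparison against `truth` is only possible here because the toy model exposes `U`; in practice the confounder is unobserved, which is the reason a front-door style estimator is of interest.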