Abstract: This paper investigates group distributionally robust optimization (GDRO) with the goal of learning a model that performs well over $m$ different distributions. First, we formulate GDRO as a stochastic convex-concave saddle-point problem, which is then solved by stochastic mirror descent (SMD) with $m$ samples in each iteration, and attains a nearly optimal sample complexity. To reduce the number of samples required in each round from $m$ to 1, we cast GDRO as a two-player game, where one player conducts SMD and the other executes an online algorithm for non-oblivious multi-armed bandits, maintaining the same sample complexity. Next, we extend GDRO to address scenarios involving imbalanced data and heterogeneous distributions. In the first scenario, we introduce a weighted variant of GDRO, enabling distribution-dependent convergence rates that rely on the number of samples from each distribution. We design two strategies to meet the sample budget: one integrates non-uniform sampling into SMD, and the other employs the stochastic mirror-prox algorithm with mini-batches, both of which deliver faster rates for distributions with more samples. In the second scenario, we propose to optimize the average top-$k$ risk instead of the maximum risk, thereby mitigating the impact of outlier distributions. Similar to the case of vanilla GDRO, we develop two stochastic approaches: one uses $m$ samples per iteration via SMD, and the other consumes $k$ samples per iteration through an online algorithm for non-oblivious combinatorial semi-bandits.
What problem does this paper attempt to address?
The paper addresses the problem of learning a model that is robust across multiple distributions, that is, a single model that performs well on each of \( m \) different distributions. To this end, the authors study Group Distributionally Robust Optimization (GDRO) and propose a series of methods based on stochastic approximation to solve it.
### Description of the Main Problem
1. **Limitations in Classical Machine Learning**:
- In classical statistical machine learning, the goal is to minimize the risk with respect to a fixed distribution \( P_0 \):
\[
\min_{w \in W} R_0(w)=\mathbb{E}_{z \sim P_0}[\ell(w;z)]
\]
- However, a model trained on this single distribution may suffer from the following problems:
  - Higher error on minority subsets.
  - Significantly degraded performance on other distributions.
2. **Introduction of Distributionally Robust Optimization (DRO)**:
- DRO improves the robustness of the model by minimizing the worst-case risk:
\[
\min_{w \in W} \sup_{P \in S(P_0)} \mathbb{E}_{z \sim P}[\ell(w;z)]
\]
- Here, \( S(P_0) \) is an uncertainty set around \( P_0 \).
3. **Group Distributionally Robust Optimization (GDRO)**:
- GDRO considers a finite number of distributions \( P_1, P_2,\ldots, P_m \) and minimizes the maximum risk over them:
\[
\min_{w \in W} \left\{ L_{\max}(w)=\max_{i \in [m]} R_i(w) \right\}, \quad \text{where } R_i(w)=\mathbb{E}_{z \sim P_i}[\ell(w;z)]
\]
- Mathematically, GDRO can be written as a stochastic convex-concave saddle-point problem:
\[
\min_{w \in W} \max_{q \in \Delta^m} \left\{ \phi(w,q)=\sum_{i = 1}^m q_i R_i(w) \right\}
\]
where \( \Delta^m=\{q \in \mathbb{R}^m|q \geq 0,\sum_{i = 1}^m q_i = 1\} \) is the \((m - 1)\)-dimensional simplex.
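The equivalence between the maximum-risk objective and the saddle-point form rests on a standard fact, sketched here as a reader aid rather than taken from the paper: a linear function of \( q \) over the simplex attains its maximum at a vertex, so for any fixed \( w \),
\[
\max_{q \in \Delta^m} \sum_{i = 1}^m q_i R_i(w)=\max_{i \in [m]} R_i(w),
\]
with the maximum attained by placing all weight on a distribution with the largest risk. As a hypothetical example, with risks \( (0.2, 0.5, 0.3) \) the maximizing weight vector is \( q=(0,1,0) \), and both sides equal \( 0.5 \).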
### Solutions
1. **Stochastic Mirror Descent (SMD)**:
- The authors first cast GDRO as a stochastic convex-concave saddle-point problem and solve it with SMD. In each iteration, one sample is drawn from each of the \( m \) distributions to construct unbiased stochastic gradients, which are used to update both the model parameters \( w \) and the weights \( q \).
- This method achieves a convergence rate of \( O(\sqrt{\frac{\log m}{T}}) \), both in expectation and with high probability. Reaching accuracy \( \epsilon \) therefore takes \( T=O(\log m/\epsilon^2) \) iterations of \( m \) samples each, for a sample complexity of \( O(m\log m/\epsilon^2) \).
2. **Reducing the Number of Samples per Round**:
- To reduce the number of samples required per round from \( m \) to 1, the authors cast GDRO as a two-player game:
- One player updates \( w \) by SMD, while the other updates the weights \( q \) with Exp3-IX, an online algorithm for non-oblivious multi-armed bandits, so that only one sample is drawn in each round.
- This method achieves a convergence rate of \( O(\sqrt{\frac{m\log m}{T}}) \). Because each round uses a single sample, the sample complexity remains \( O(m\log m/\epsilon^2) \).
3. **Extension to Imbalanced Data and Heterogeneous Distributions**:
- For imbalanced data, the authors introduce a weighted variant of GDRO whose convergence rates depend on the sample budget of each distribution, so that distributions with more samples obtain faster rates.
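
The basic SMD procedure from item 1 (one sample from each of the \( m \) distributions per iteration) can be illustrated with a minimal Python sketch. This is not the authors' implementation. The function name `smd_gdro`, the scalar parameter \( w \), the Gaussian toy distributions, the squared loss, the interval domain, and the step sizes are all assumptions chosen for illustration. The sketch uses a projected gradient step for \( w \) and an entropic (multiplicative-weights) step for \( q \), and returns the averaged iterates.

```python
import numpy as np

def smd_gdro(sample, loss, grad, m, T, eta_w, eta_q, w0, radius, seed=0):
    """Illustrative SMD sketch for min_w max_q sum_i q_i R_i(w) with scalar w.

    sample(rng) -> array of m samples, one per distribution (hypothetical helper)
    loss(w, z)  -> array of m per-distribution losses
    grad(w, z)  -> array of m per-distribution gradients with respect to w
    """
    rng = np.random.default_rng(seed)
    w = float(w0)
    log_q = np.zeros(m)              # log-weights; q starts as the uniform distribution
    w_sum, q_sum = 0.0, np.zeros(m)
    for _ in range(T):
        q = np.exp(log_q - log_q.max())   # numerically stable normalization
        q /= q.sum()
        w_sum += w
        q_sum += q
        z = sample(rng)                   # m samples per iteration
        g_w = np.dot(q, grad(w, z))       # stochastic gradient of phi(w, q) in w
        g_q = loss(w, z)                  # stochastic gradient of phi(w, q) in q
        # Descent step in w, projected onto the assumed domain [-radius, radius].
        w = float(np.clip(w - eta_w * g_w, -radius, radius))
        # Entropic mirror ascent step in q (multiplicative weights in log space).
        log_q += eta_q * g_q
    return w_sum / T, q_sum / T           # averaged iterates

# Toy problem (an assumption): three Gaussian distributions with different means
# and the loss 0.5 * (w - z)^2, so R_i(w) = 0.5 * (w - mean_i)^2 + const.
means = np.array([-1.0, 0.0, 3.0])
sample = lambda rng: rng.normal(means, 0.1)
loss = lambda w, z: 0.5 * (w - z) ** 2
grad = lambda w, z: w - z

w_bar, q_bar = smd_gdro(sample, loss, grad, m=3, T=20000,
                        eta_w=0.02, eta_q=0.02, w0=0.0, radius=5.0)
print("averaged w:", w_bar)
print("averaged q:", q_bar)
```

In this toy setup the maximum risk is minimized at the midpoint of the two extreme means, \( w=1 \), so the averaged \( w \) should land near 1 and the averaged \( q \) should put most of its weight on the first and third distributions. The step sizes here are fixed by hand; the rates quoted above depend on the step sizes chosen in the paper's analysis.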