Trompt: Towards a Better Deep Neural Network for Tabular Data

Kuan-Yu Chen, Ping-Han Chiang, Hsin-Rung Chou, Ting-Wei Chen, Tien-Hao Chang
2023-05-31
Abstract: Tabular data is arguably one of the most commonly used data structures in various practical domains, including finance, healthcare and e-commerce. The inherent heterogeneity allows tabular data to store rich information. However, a recently published tabular benchmark shows that deep neural networks still fall behind tree-based models on tabular datasets. In this paper, we propose Trompt (short for Tabular Prompt), a novel architecture inspired by prompt learning of language models. The essence of prompt learning is to adjust a large pre-trained model through a set of prompts outside the model without directly modifying the model. Based on this idea, Trompt separates the learning strategy of tabular data into two parts. The first part, analogous to pre-trained models, focuses on learning the intrinsic information of a table. The second part, analogous to prompts, focuses on learning the variations among samples. Trompt is evaluated with the benchmark mentioned above. The experimental results demonstrate that Trompt outperforms state-of-the-art deep neural networks and is comparable to tree-based models.
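The two-part split described in the abstract can be illustrated with a toy sketch. This is not the Trompt architecture from the paper; it only assumes, for illustration, that the "intrinsic" part is a set of column embeddings shared by every sample, and that the "prompt" part is a set of per-sample prompt vectors, obtained here by shifting shared prompt embeddings by the sample's mean feature embedding (a simplification chosen for this sketch). Each prompt scores every column with a dot product followed by a softmax, giving sample-specific column importances that are then used to pool the feature embeddings. All names (`column_embeddings`, `sample_prompts`, `column_importance`, `prompt_outputs`) and sizes are hypothetical, and the weights are random rather than trained.

```python
import math
import random

# Hypothetical sizes chosen for illustration; not values from the paper.
NUM_COLUMNS = 4  # C: columns in the table
NUM_PROMPTS = 3  # P: prompt vectors
DIM = 8          # d: embedding width

rng = random.Random(0)  # fixed seed so the toy weights are reproducible


def random_matrix(rows, cols):
    """Small random weights standing in for trained parameters."""
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Part 1 ("pre-trained model" analogue): shared by every sample.
column_embeddings = random_matrix(NUM_COLUMNS, DIM)  # one vector per column
prompt_embeddings = random_matrix(NUM_PROMPTS, DIM)  # base prompt vectors
feature_weights = random_matrix(NUM_COLUMNS, DIM)    # scalar value -> vector


def embed_features(row):
    """Embed each scalar cell by scaling its column's weight vector."""
    if len(row) != NUM_COLUMNS:
        raise ValueError(f"expected {NUM_COLUMNS} values, got {len(row)}")
    return [[v * w for w in feature_weights[c]] for c, v in enumerate(row)]


def sample_prompts(row):
    """Part 2 ("prompt" analogue): shift the shared prompts by this sample's
    mean feature embedding, so the prompts vary from row to row
    (a simplification chosen for this sketch)."""
    feats = embed_features(row)
    mean = [sum(f[k] for f in feats) / NUM_COLUMNS for k in range(DIM)]
    return [[p[k] + mean[k] for k in range(DIM)] for p in prompt_embeddings]


def column_importance(row):
    """P x C matrix: each prompt's softmax over the shared column embeddings."""
    return [softmax([dot(p, col) for col in column_embeddings])
            for p in sample_prompts(row)]


def prompt_outputs(row):
    """P x d matrix: feature embeddings pooled with each prompt's importances."""
    feats = embed_features(row)
    return [[sum(imp[c] * feats[c][k] for c in range(NUM_COLUMNS))
             for k in range(DIM)]
            for imp in column_importance(row)]


if __name__ == "__main__":
    row = [0.5, -1.2, 3.0, 0.0]
    for i, weights in enumerate(column_importance(row)):
        print(f"prompt {i}: " + ", ".join(f"{w:.3f}" for w in weights))
```

In this sketch only `sample_prompts` depends on the input row; `column_embeddings`, `prompt_embeddings`, and `feature_weights` stay fixed, so two different rows yield different column importances while sharing the same table-level parameters. In a trained model these weight matrices would be learned rather than drawn at random.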
Machine Learning
What problem does this paper attempt to address?
### Problems the Paper Aims to Solve

This paper addresses the fact that deep neural networks still perform worse than tree-based models on tabular data. Specifically:

1. **Background and Motivation**:
   - Tabular data is widely used in fields such as finance, healthcare, and e-commerce.
   - Deep neural networks have performed very well in tasks such as computer vision and natural language processing, but on tabular data they still lag behind tree-based models.
   - Researchers have tried to apply deep learning to tabular data, but previous attempts have still performed poorly on some datasets.
2. **Proposed Method**:
   - A new architecture named Trompt (Tabular Prompt) is proposed, inspired by prompt learning in natural language processing.
   - Trompt divides the learning of tabular data into two parts: the intrinsic information of the table and the variations among samples.
   - Experimental results show that Trompt outperforms existing deep neural networks on a standard benchmark and performs close to tree-based models.
3. **Main Contributions**:
   - Experiments were conducted on Grinsztajn45, a recognized tabular data benchmark.
   - Trompt achieves state-of-the-art performance among deep neural networks and narrows the performance gap between deep neural networks and tree-based models.
   - Extensive empirical studies and ablation experiments validate the design of Trompt and provide insights for future research directions.

Through this work, the paper aims to improve the performance of deep neural networks on tabular data and make them comparable to tree-based models.