PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning

Weihua Hu, Yiwen Yuan, Zecheng Zhang, Akihiro Nitta, Kaidi Cao, Vid Kocijan, Jure Leskovec, Matthias Fey
2024-04-01
Abstract: We present PyTorch Frame, a PyTorch-based framework for deep learning over multi-modal tabular data. PyTorch Frame makes tabular deep learning easy by providing a PyTorch-based data structure to handle complex tabular data, introducing a model abstraction to enable modular implementation of tabular models, and allowing external foundation models to be incorporated to handle complex columns (e.g., LLMs for text columns). We demonstrate the usefulness of PyTorch Frame by implementing diverse tabular models in a modular way, successfully applying these models to complex multi-modal tabular data, and integrating our framework with PyTorch Geometric, a PyTorch library for Graph Neural Networks (GNNs), to perform end-to-end learning over relational databases.
Machine Learning, Databases
What problem does this paper attempt to address?
The paper addresses the difficulty of applying deep learning to multi-modal tabular data. Gradient Boosting Decision Trees (GBDTs) perform well on numerical and categorical features, but they struggle to process raw multi-modal features such as text, sequences, and images, and they cannot be trained end to end with downstream deep learning models such as Graph Neural Networks (GNNs).

To address these issues, the paper introduces PyTorch Frame, a PyTorch-based framework designed for multi-modal tabular learning. It has three key features:

1. It introduces Tensor Frame, an efficient data structure for handling the columns of complex tabular data.
2. It provides a model abstraction for implementing tabular models modularly, which makes it easier to plug in external foundation models (e.g., LLMs) to handle complex columns.
3. It can be integrated with PyTorch Geometric to enable end-to-end learning over relational databases.

With this framework, researchers can implement a variety of tabular models and apply them to complex multi-modal tabular data. Experiments show that PyTorch Frame performs well on conventional numerical/categorical datasets as well as on modern tabular data with text columns and relational structure. Although deep learning models approach GBDT performance on some tasks, GBDTs remain a practical choice for conventional tabular learning because of their simplicity and efficiency.
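To make the idea of a type-aware table container more concrete, the following toy sketch groups the columns of a small table by a declared semantic type and stores each group as one row-major block. It uses only the Python standard library and is not PyTorch Frame's API: the names `ToyFrame`, `build_toy_frame`, `NUMERICAL`, and `CATEGORICAL` are hypothetical, and the layout shown is an assumption made for illustration rather than a description of how Tensor Frame is implemented.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

# Hypothetical semantic-type labels (illustrative only, not PyTorch Frame identifiers).
NUMERICAL = "numerical"
CATEGORICAL = "categorical"


@dataclass
class ToyFrame:
    """Columns grouped by semantic type; each group is a [num_rows][num_cols] block."""
    col_names: Dict[str, List[str]]
    data: Dict[str, List[List[Any]]]

    @property
    def num_rows(self) -> int:
        # All blocks share the same number of rows, so inspect any one of them.
        return len(next(iter(self.data.values())))


def build_toy_frame(rows: List[Dict[str, Any]], col_types: Dict[str, str]) -> ToyFrame:
    # Group column names by their declared semantic type, preserving order.
    col_names: Dict[str, List[str]] = {}
    for name, stype in col_types.items():
        col_names.setdefault(stype, []).append(name)

    # Map each categorical value to an integer index (sorted for determinism).
    vocab = {
        name: {v: i for i, v in enumerate(sorted({row[name] for row in rows}))}
        for name in col_names.get(CATEGORICAL, [])
    }

    # Build one row-major block per semantic type.
    data: Dict[str, List[List[Any]]] = {}
    for stype, names in col_names.items():
        if stype == CATEGORICAL:
            data[stype] = [[vocab[n][row[n]] for n in names] for row in rows]
        else:
            data[stype] = [[float(row[n]) for n in names] for row in rows]
    return ToyFrame(col_names=col_names, data=data)


rows = [
    {"age": 31, "income": 52000.0, "city": "Oslo"},
    {"age": 45, "income": 61000.0, "city": "Lima"},
    {"age": 28, "income": 48000.0, "city": "Oslo"},
]
col_types = {"age": NUMERICAL, "income": NUMERICAL, "city": CATEGORICAL}
frame = build_toy_frame(rows, col_types)
print(frame.col_names)
print(frame.data[NUMERICAL])
print(frame.data[CATEGORICAL])
```

Grouping columns by type in this way lets a model apply one encoder per type to a whole block at once, which is one plausible reading of why a type-aware container suits modular tabular models. Text or image columns would need their own type and an external encoder, which this sketch deliberately leaves out.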