High-Performance Tensor-Train Primitives Using GPU Tensor Cores

Xiao-Yang Liu, Hao Hong, Zeliang Zhang, Weiqin Tong, Jean Kossaifi, Xiaodong Wang, Anwar Walid
DOI: https://doi.org/10.1109/tc.2024.3441831
IF: 3.183
2024-10-12
IEEE Transactions on Computers
Abstract: Learning tensor-train (TT) structure (a.k.a. matrix product state (MPS) representation) from large-scale high-dimensional data is a common task in big data analysis, deep learning, and quantum machine learning. However, tensor-train algorithms are compute-intensive, which hinders their real-world applications. In this paper, we present high-performance tensor-train primitives using GPU tensor cores and demonstrate three applications. First, we use GPU tensor cores to optimize tensor-train primitives, including tensor contraction, singular value decomposition, and data transfer and computing. Second, we utilize the optimized primitives to accelerate tensor-train decomposition algorithms for big data analysis. Further, we propose a shard mode for high-order tensor computations on multiple GPUs. Third, we apply the optimized primitives to accelerate the tensor-train layer for compressing deep neural networks. Last, we utilize the optimized primitives to accelerate a quantum machine learning algorithm called Density Matrix Renormalization Group (DMRG). In performance evaluations, our third-order TT decomposition achieves speedups of up to 3.34× and 6.91× over two popular libraries (T3F and tntorch, respectively) on an A100 GPU. The proposed sixth-order tensor-train decomposition achieves a speedup of up to 5.01× over T3F on multiple A100 GPUs. Our tensor-train layer for a fully connected neural network achieves a compression ratio of 65.3× at the cost of a 0.3% drop in accuracy and a speedup of 1.53× over a PyTorch implementation on CUDA cores. The optimized DMRG algorithm achieves a speedup of up to 14.0× over TensorNetwork, indicating the potential of the optimized tensor primitives for the classical simulation of quantum machine learning algorithms.
Engineering, Electrical & Electronic; Computer Science, Hardware & Architecture
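
For readers unfamiliar with the tensor-train decomposition named in the abstract, the following is a minimal CPU sketch of the textbook TT-SVD procedure (sequential unfolding plus truncated SVD). It is not the paper's GPU tensor-core implementation and uses none of its optimizations. The function names `tt_svd` and `tt_to_full` and the `max_rank` parameter are hypothetical, chosen only for illustration.

```python
import numpy as np


def tt_svd(tensor, max_rank):
    """Split `tensor` into TT cores using sequential truncated SVDs.

    Illustrative sketch only: each core has shape (r_prev, n_k, r_next),
    with the boundary ranks fixed to 1 and every rank capped at `max_rank`.
    """
    dims = tensor.shape
    cores = []
    rank = 1
    mat = tensor
    for n in dims[:-1]:
        # Unfold so that rows combine the incoming rank with the current mode.
        mat = mat.reshape(rank * n, -1)
        u, s, vt = np.linalg.svd(mat, full_matrices=False)
        new_rank = min(max_rank, len(s))
        # The left singular vectors become the core for this mode.
        cores.append(u[:, :new_rank].reshape(rank, n, new_rank))
        # Carry singular values times right singular vectors to the next step.
        mat = s[:new_rank, None] * vt[:new_rank]
        rank = new_rank
    # Whatever remains is the last core, with trailing rank 1.
    cores.append(mat.reshape(rank, dims[-1], 1))
    return cores


def tt_to_full(cores):
    """Contract TT cores back into a dense tensor (for checking only)."""
    full = cores[0]
    for core in cores[1:]:
        # Contract the trailing rank index with the next core's leading rank.
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full.reshape([core.shape[1] for core in cores])
```

Calling `tt_svd(x, max_rank=r)` on a third-order array `x` returns three cores. With a sufficiently large `max_rank` the contraction `tt_to_full(cores)` reproduces `x` up to floating-point error, while a smaller `max_rank` trades accuracy for fewer stored parameters. The contractions and SVDs in this loop correspond to the kinds of primitives the abstract says are optimized on tensor cores.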