DMTG: One-Shot Differentiable Multi-Task Grouping

Yuan Gao, Shuguo Jiang, Moran Li, Jin-Gang Yu, Gui-Song Xia
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:14747-14762, 2024.

Abstract

We aim to address Multi-Task Learning (MTL) with a large number of tasks by Multi-Task Grouping (MTG). Given $N$ tasks, we propose to identify the best task groups from $2^N$ candidates and train the model weights simultaneously, in one shot, while fully exploiting high-order task affinity. This is distinct from the pioneering methods, which identify the groups and train the model weights sequentially, and whose group identification often relies on heuristics. As a result, our method not only improves training efficiency but also mitigates the objective bias introduced by sequential procedures, which can lead to a suboptimal solution. Specifically, we formulate MTG as a fully differentiable pruning problem on an adaptive network architecture determined by an unknown Categorical distribution. To categorize $N$ tasks into $K$ groups (represented by $K$ encoder branches), we initially set up $KN$ task heads, where each branch connects to all $N$ task heads to exploit high-order task affinity. Then, we gradually prune the $KN$ heads down to $N$ by learning a relaxed, differentiable Categorical distribution, ensuring that each task is assigned to exactly one branch. Extensive experiments on the CelebA and Taskonomy datasets, with detailed ablations, show the promising performance and efficiency of our method. The code is available at https://github.com/ethanygao/DMTG.
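The grouping mechanism described in the abstract can be illustrated with a small, self-contained sketch. It assumes that the relaxed Categorical is the Gumbel-softmax (Concrete) distribution over the $K$ branches for each task, which is one standard way to relax a Categorical; the paper's exact parameterization and temperature schedule may differ. All function names (`gumbel_softmax`, `head_weights`, `grouped_loss`, `prune_to_groups`, `groups_by_branch`), the example logits, the per-head losses and the temperatures are hypothetical and are not taken from the DMTG code. Learning the logits by backpropagation is omitted; only the forward relaxation over the $KN$ heads and the final pruning to $N$ heads are shown.

```python
import math
import random


def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample from a Categorical over len(logits) branches.

    Assumption: Gumbel-softmax is used as the relaxation. Adds Gumbel(0, 1)
    noise to each logit, then applies a softmax at temperature tau; smaller
    tau pushes the sample closer to one-hot.
    """
    scores = []
    for logit in logits:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        scores.append((logit + gumbel) / tau)
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def head_weights(task_logits, tau, rng):
    """N x K soft weights: row n is task n's relaxed assignment over K branches.

    Each of the K*N entries gates one (task, branch) head.
    """
    return [gumbel_softmax(row, tau, rng) for row in task_logits]


def grouped_loss(weights, head_losses):
    """Mix the K*N per-head losses using the relaxed assignment weights."""
    return sum(
        w * loss
        for w_row, loss_row in zip(weights, head_losses)
        for w, loss in zip(w_row, loss_row)
    )


def prune_to_groups(task_logits):
    """Keep one head per task (K*N -> N heads): the branch with the largest logit."""
    return [max(range(len(row)), key=row.__getitem__) for row in task_logits]


def groups_by_branch(assignment, num_branches):
    """Invert a task -> branch assignment into branch -> list of tasks."""
    groups = {k: [] for k in range(num_branches)}
    for task, branch in enumerate(assignment):
        groups[branch].append(task)
    return groups


# Hypothetical example: N = 4 tasks, K = 2 branches (8 heads in total).
rng = random.Random(0)
logits = [[2.0, -1.0], [1.5, 0.0], [-0.5, 1.0], [0.0, 3.0]]
head_losses = [[0.2, 0.9], [0.3, 0.8], [0.7, 0.4], [0.6, 0.1]]
for tau in (5.0, 1.0, 0.1):  # assumed annealing schedule
    weights = head_weights(logits, tau, rng)
    print(tau, round(grouped_loss(weights, head_losses), 3))
assignment = prune_to_groups(logits)
print(groups_by_branch(assignment, 2))  # → {0: [0, 1], 1: [2, 3]}
```

As `tau` decreases, each row of `head_weights` approaches a one-hot vector, so the weighted loss concentrates on one head per task; pruning then keeps exactly $N$ of the $KN$ heads, with every task in exactly one branch.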

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-gao24h,
  title = {{DMTG}: One-Shot Differentiable Multi-Task Grouping},
  author = {Gao, Yuan and Jiang, Shuguo and Li, Moran and Yu, Jin-Gang and Xia, Gui-Song},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages = {14747--14762},
  year = {2024},
  editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume = {235},
  series = {Proceedings of Machine Learning Research},
  month = {21--27 Jul},
  publisher = {PMLR},
  pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/gao24h/gao24h.pdf},
  url = {https://proceedings.mlr.press/v235/gao24h.html},
  abstract = {We aim to address Multi-Task Learning (MTL) with a large number of tasks by Multi-Task Grouping (MTG). Given $N$ tasks, we propose to simultaneously identify the best task groups from $2^N$ candidates and train the model weights simultaneously in one-shot, with the high-order task-affinity fully exploited. This is distinct from the pioneering methods which sequentially identify the groups and train the model weights, where the group identification often relies on heuristics. As a result, our method not only improves the training efficiency, but also mitigates the objective bias introduced by the sequential procedures that potentially leads to a suboptimal solution. Specifically, we formulate MTG as a fully differentiable pruning problem on an adaptive network architecture determined by an unknown Categorical distribution. To categorize $N$ tasks into $K$ groups (represented by $K$ encoder branches), we initially set up $KN$ task heads, where each branch connects to all $N$ task heads to exploit the high-order task-affinity. Then, we gradually prune the $KN$ heads down to $N$ by learning a relaxed differentiable Categorical distribution, ensuring that each task is exclusively and uniquely categorized into only one branch. Extensive experiments on CelebA and Taskonomy datasets with detailed ablations show the promising performance and efficiency of our method. The codes are available at https://github.com/ethanygao/DMTG.}
}
Endnote
%0 Conference Paper
%T DMTG: One-Shot Differentiable Multi-Task Grouping
%A Yuan Gao
%A Shuguo Jiang
%A Moran Li
%A Jin-Gang Yu
%A Gui-Song Xia
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-gao24h
%I PMLR
%P 14747--14762
%U https://proceedings.mlr.press/v235/gao24h.html
%V 235
%X We aim to address Multi-Task Learning (MTL) with a large number of tasks by Multi-Task Grouping (MTG). Given $N$ tasks, we propose to simultaneously identify the best task groups from $2^N$ candidates and train the model weights simultaneously in one-shot, with the high-order task-affinity fully exploited. This is distinct from the pioneering methods which sequentially identify the groups and train the model weights, where the group identification often relies on heuristics. As a result, our method not only improves the training efficiency, but also mitigates the objective bias introduced by the sequential procedures that potentially leads to a suboptimal solution. Specifically, we formulate MTG as a fully differentiable pruning problem on an adaptive network architecture determined by an unknown Categorical distribution. To categorize $N$ tasks into $K$ groups (represented by $K$ encoder branches), we initially set up $KN$ task heads, where each branch connects to all $N$ task heads to exploit the high-order task-affinity. Then, we gradually prune the $KN$ heads down to $N$ by learning a relaxed differentiable Categorical distribution, ensuring that each task is exclusively and uniquely categorized into only one branch. Extensive experiments on CelebA and Taskonomy datasets with detailed ablations show the promising performance and efficiency of our method. The codes are available at https://github.com/ethanygao/DMTG.
APA
Gao, Y., Jiang, S., Li, M., Yu, J.-G., & Xia, G.-S. (2024). DMTG: One-Shot Differentiable Multi-Task Grouping. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:14747-14762. Available from https://proceedings.mlr.press/v235/gao24h.html.
