Jaxpruner: A Concise Library for Sparsity Research

Joo Hyung Lee, Wonpyo Park, Nicole Elyse Mitchell, Jonathan Pilault, Johan Samir Obando Ceron, Han-Byul Kim, Namhoon Lee, Elias Frantar, Yun Long, Amir Yazdanbakhsh, Woohyun Han, Shivani Agrawal, Suvinay Subramanian, Xin Wang, Sheng-Chun Kao, Xingyao Zhang, Trevor Gale, Aart J.C. Bik, Milen Ferev, Zhonglin Han, Hong-Seok Kim, Yann Dauphin, Gintare Karolina Dziugaite, Pablo Samuel Castro, Utku Evci
Conference on Parsimony and Learning, PMLR 234:515-528, 2024.

Abstract

This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which, in turn, enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration by providing examples in four different codebases: Scenic, t5x, Dopamine, and FedJAX, and provide baseline experiments on popular benchmarks. JaxPruner is hosted at github.com/google-research/jaxpruner.
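
To illustrate the Optax integration pattern the abstract refers to, the sketch below expresses magnitude pruning as an optax.GradientTransformation and chains it with a standard optimizer. This is a minimal illustration of the general pattern, not JaxPruner's actual API; the magnitude_pruning helper and the 50% sparsity target are hypothetical choices made for the example.

import jax
import jax.numpy as jnp
import optax


def magnitude_pruning(sparsity: float) -> optax.GradientTransformation:
  """Illustrative sketch: zero out the smallest-magnitude fraction of each parameter."""

  def init_fn(params):
    del params  # No extra optimizer state is needed for this sketch.
    return optax.EmptyState()

  def update_fn(updates, state, params):
    def mask_update(u, p):
      k = int(sparsity * p.size)
      if k == 0:
        return u
      score = jnp.abs(p + u)  # Magnitude after the incoming update.
      threshold = jnp.sort(score.ravel())[k - 1]
      # Kept weights receive the normal update; pruned weights are driven
      # to exactly zero by subtracting their current value.
      return jnp.where(score > threshold, u, -p)

    return jax.tree_util.tree_map(mask_update, updates, params), state

  return optax.GradientTransformation(init_fn, update_fn)


# Usage: chain the pruning step with any Optax optimizer.
optimizer = optax.chain(optax.adam(1e-3), magnitude_pruning(sparsity=0.5))
params = {'w': jnp.arange(16.0).reshape(4, 4)}
opt_state = optimizer.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Because the pruning step is just another gradient transformation, a training loop that already consumes an Optax optimizer can adopt it without structural changes, which is the integration property the abstract emphasizes.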

Cite this Paper


BibTeX
@InProceedings{pmlr-v234-lee24a,
  title     = {Jaxpruner: A Concise Library for Sparsity Research},
  author    = {Lee, Joo Hyung and Park, Wonpyo and Mitchell, Nicole Elyse and Pilault, Jonathan and Ceron, Johan Samir Obando and Kim, Han-Byul and Lee, Namhoon and Frantar, Elias and Long, Yun and Yazdanbakhsh, Amir and Han, Woohyun and Agrawal, Shivani and Subramanian, Suvinay and Wang, Xin and Kao, Sheng-Chun and Zhang, Xingyao and Gale, Trevor and Bik, Aart J.C. and Ferev, Milen and Han, Zhonglin and Kim, Hong-Seok and Dauphin, Yann and Dziugaite, Gintare Karolina and Castro, Pablo Samuel and Evci, Utku},
  booktitle = {Conference on Parsimony and Learning},
  pages     = {515--528},
  year      = {2024},
  editor    = {Chi, Yuejie and Dziugaite, Gintare Karolina and Qu, Qing and Wang, Atlas Wang and Zhu, Zhihui},
  volume    = {234},
  series    = {Proceedings of Machine Learning Research},
  month     = {03--06 Jan},
  publisher = {PMLR},
  pdf       = {https://proceedings.mlr.press/v234/lee24a/lee24a.pdf},
  url       = {https://proceedings.mlr.press/v234/lee24a.html},
  abstract  = {This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which, in turn, enables easy integration with existing JAX based libraries. We demonstrate this ease of integration by providing examples in four different codebases: Scenic, t5x, Dopamine and FedJAX and provide baseline experiments on popular benchmarks. Jaxpruner is hosted at github.com/google-research/jaxpruner}
}
Endnote
%0 Conference Paper
%T Jaxpruner: A Concise Library for Sparsity Research
%A Joo Hyung Lee
%A Wonpyo Park
%A Nicole Elyse Mitchell
%A Jonathan Pilault
%A Johan Samir Obando Ceron
%A Han-Byul Kim
%A Namhoon Lee
%A Elias Frantar
%A Yun Long
%A Amir Yazdanbakhsh
%A Woohyun Han
%A Shivani Agrawal
%A Suvinay Subramanian
%A Xin Wang
%A Sheng-Chun Kao
%A Xingyao Zhang
%A Trevor Gale
%A Aart J.C. Bik
%A Milen Ferev
%A Zhonglin Han
%A Hong-Seok Kim
%A Yann Dauphin
%A Gintare Karolina Dziugaite
%A Pablo Samuel Castro
%A Utku Evci
%B Conference on Parsimony and Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Yuejie Chi
%E Gintare Karolina Dziugaite
%E Qing Qu
%E Atlas Wang Wang
%E Zhihui Zhu
%F pmlr-v234-lee24a
%I PMLR
%P 515--528
%U https://proceedings.mlr.press/v234/lee24a.html
%V 234
%X This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which, in turn, enables easy integration with existing JAX based libraries. We demonstrate this ease of integration by providing examples in four different codebases: Scenic, t5x, Dopamine and FedJAX and provide baseline experiments on popular benchmarks. Jaxpruner is hosted at github.com/google-research/jaxpruner
APA
Lee, J.H., Park, W., Mitchell, N.E., Pilault, J., Ceron, J.S.O., Kim, H., Lee, N., Frantar, E., Long, Y., Yazdanbakhsh, A., Han, W., Agrawal, S., Subramanian, S., Wang, X., Kao, S., Zhang, X., Gale, T., Bik, A.J., Ferev, M., Han, Z., Kim, H., Dauphin, Y., Dziugaite, G.K., Castro, P.S. & Evci, U. (2024). Jaxpruner: A Concise Library for Sparsity Research. Conference on Parsimony and Learning, in Proceedings of Machine Learning Research 234:515-528. Available from https://proceedings.mlr.press/v234/lee24a.html.
