Always compute the mean along the first dimension

(cherry picked from commit 2d371b1997f5fa07fcbbf47e5923d7817d07c6b9)
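The commit does not say which array library or function is involved. The sketch below assumes a NumPy-style array and uses a hypothetical helper, `mean_first_dim`. It contrasts a mean reduced along the first dimension (axis 0), which keeps the remaining dimensions, with a mean taken over every element, which collapses the array to a scalar.

```python
import numpy as np

def mean_first_dim(x):
    """Hypothetical helper: average over axis 0 only, keeping the other dimensions."""
    x = np.asarray(x, dtype=float)
    return x.mean(axis=0)

batch = np.array([[1.0, 2.0, 3.0],
                  [3.0, 4.0, 5.0]])

print(mean_first_dim(batch))  # [2. 3. 4.]  (one mean per column)
print(batch.mean())           # 3.0  (a single scalar over all six elements)
```

Passing `axis=0` explicitly makes the reduction independent of how many dimensions the input has. For example, an input of shape `(4, 5, 6)` always yields a result of shape `(5, 6)`.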