Skip to content

Commit

Permalink
internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 709855044
  • Loading branch information
Pathways-on-Cloud Team authored and copybara-github committed Dec 26, 2024
1 parent 2a494e4 commit f3f4299
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions pathwaysutils/test/google_internal/persistence_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Persistence tests that can only run in google3."""

from absl import flags
import jax

from google3.learning.pathways.ifrt.proxy.jax.tests import register_jax_grpc_backend_for_testing # pylint: disable=unused-import
from absl.testing import absltest


_JAX_BACKEND_TARGET = flags.DEFINE_string(
"jax_backend_target",
"ifrt_pathways",
"Jax backend target to use.",
)

_JAX_PLATFORMS = flags.DEFINE_string(
"jax_platforms",
"proxy",
"Jax platforms to use.",
)

# set JAX_ALLOW_UNUSED_TPUS to avoid AssertionError: The host has 4 TPU chips
# but TPU support is not linked into JAX. You should add a BUILD dependency
# on //learning/brain/research/jax:tpu_support.
#
# This error happens because we are
# //learning/pathways/data_parallel:tpu_support instead of
# //learning/brain/research/jax:tpu_support
flags.FLAGS.jax_allow_unused_tpus = True


class PersistenceTest(absltest.TestCase):

def test_devices_can_be_fetched_from_proxy_backend(self):
devices = jax.devices("proxy")
self.assertNotEmpty(devices)


if __name__ == "__main__":
absltest.main()

0 comments on commit f3f4299

Please sign in to comment.