Skip to content

Commit

Permalink
Uses the default python logging package instead of absl.logging
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698517173
  • Loading branch information
lukebaumann authored and copybara-github committed Nov 20, 2024
1 parent 15441fe commit 2a494e4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 17 deletions.
10 changes: 6 additions & 4 deletions pathwaysutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
"""Package of Pathways-on-Cloud utilities."""

import datetime
import logging
import os

from absl import logging
import jax
from pathwaysutils import cloud_logging
from pathwaysutils import profiling
from pathwaysutils import proxy_backend
from pathwaysutils.persistence import pathways_orbax_handler


logger = logging.getLogger(__name__)

# A new PyPI release will be pushed every time `__version__` is increased.
# When changing this, also update the CHANGELOG.md.
__version__ = "v0.0.7"
Expand All @@ -50,7 +52,7 @@ def _is_persistence_enabled():


if _is_pathways_used():
logging.debug(
logger.debug(
"pathwaysutils: Detected Pathways-on-Cloud backend. Applying changes."
)
proxy_backend.register_backend_factory()
Expand All @@ -68,9 +70,9 @@ def _is_persistence_enabled():
try:
cloud_logging.setup()
except OSError as e:
logging.debug("pathwaysutils: Failed to set up cloud logging.")
logger.debug("pathwaysutils: Failed to set up cloud logging.")
else:
logging.debug(
logger.debug(
"pathwaysutils: Did not detect Pathways-on-Cloud backend. No changes"
" applied."
)
9 changes: 6 additions & 3 deletions pathwaysutils/persistence/pathways_orbax_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
import collections
import datetime
import functools
import logging
import typing
from typing import Optional, Sequence

from absl import logging
import jax
from orbax.checkpoint import future
from orbax.checkpoint import type_handlers
from pathwaysutils.persistence import helper


logger = logging.getLogger(__name__)

ParamInfo = type_handlers.ParamInfo
SaveArgs = type_handlers.SaveArgs
RestoreArgs = type_handlers.RestoreArgs
Expand Down Expand Up @@ -121,7 +124,7 @@ async def deserialize(
mesh_axes.append(sharding.spec)
shardings.append(sharding)
if arg.global_shape is None or arg.dtype is None:
logging.warning(
logger.warning(
'Shape or dtype not provided for restoration. Provide these'
' properties for improved performance.'
)
Expand Down Expand Up @@ -180,7 +183,7 @@ def register_pathways_handlers(
read_timeout: Optional[datetime.timedelta] = None,
):
"""Function that must be called before saving or restoring with Pathways."""
logging.debug(
logger.debug(
'Registering CloudPathwaysArrayHandler (Pathways Persistence API).'
)
type_handlers.register_type_handler(
Expand Down
20 changes: 10 additions & 10 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
"""Profiling utilites."""

import dataclasses
import logging
import threading
import time

from absl import logging
import fastapi
import jax
from jax import numpy as jnp
from pathwaysutils import plugin_executable
import uvicorn

logger = logging.getLogger(__name__)


class _ProfileState:
def __init__(self):
Expand Down Expand Up @@ -100,7 +102,7 @@ def start_server(port: int):
port : The port to start the server on.
"""
def server_loop(port: int):
logging.debug("Starting JAX profiler server on port %s", port)
logger.debug("Starting JAX profiler server on port %s", port)
app = fastapi.FastAPI()

@dataclasses.dataclass
Expand All @@ -110,8 +112,8 @@ class ProfilingConfig:

@app.post("/profiling")
async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
logging.debug("Capturing profiling data for %s ms", pc.duration_ms)
logging.debug("Writing profiling data to %s", pc.repository_path)
logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
logger.debug("Writing profiling data to %s", pc.repository_path)
jax.profiler.start_trace(pc.repository_path)
time.sleep(pc.duration_ms / 1e3)
jax.profiler.stop_trace()
Expand Down Expand Up @@ -156,27 +158,25 @@ def start_trace_patch(
create_perfetto_link: bool = False, # pylint: disable=unused-argument
create_perfetto_trace: bool = False, # pylint: disable=unused-argument
) -> None:
logging.debug("jax.profile.start_trace patched with pathways' start_trace")
logger.debug("jax.profile.start_trace patched with pathways' start_trace")
return start_trace(log_dir)

jax.profiler.start_trace = start_trace_patch

def stop_trace_patch() -> None:
logging.debug("jax.profile.stop_trace patched with pathways' stop_trace")
logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
return stop_trace()

jax.profiler.stop_trace = stop_trace_patch

def start_server_patch(port: int):
logging.debug(
"jax.profile.start_server patched with pathways' start_server"
)
logger.debug("jax.profile.start_server patched with pathways' start_server")
return start_server(port)

jax.profiler.start_server = start_server_patch

def stop_server_patch():
logging.debug("jax.profile.stop_server patched with pathways' stop_server")
logger.debug("jax.profile.stop_server patched with pathways' stop_server")
return stop_server()

jax.profiler.stop_server = stop_server_patch

0 comments on commit 2a494e4

Please sign in to comment.