Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show a more helpful error message when a restart is required for %view #5536

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

import comm
from IPython.core.error import UsageError

from .access_keys import decode_access_key
from .data_explorer_comm import (
Expand Down Expand Up @@ -96,7 +97,14 @@
TextSearchType,
)
from .positron_comm import CommMessage, PositronComm
from .third_party import np_, pd_, pl_
from .third_party import (
RestartRequiredError,
import_pandas,
import_polars,
np_,
pd_,
pl_,
)
from .utils import BackgroundJobQueue, guid

if TYPE_CHECKING:
Expand Down Expand Up @@ -312,6 +320,7 @@ def _match_text_search(params: FilterTextSearch):

def matches(x):
return term in x.lower()

else:

def matches(x):
Expand Down Expand Up @@ -2581,11 +2590,31 @@ class PyArrowView(DataExplorerTableView):


def _is_pandas(table):
return pd_ is not None and isinstance(table, (pd_.DataFrame, pd_.Series))
pandas = import_pandas()
if pandas is not None and isinstance(table, (pandas.DataFrame, pandas.Series)):
# If pandas was installed after the kernel was started, pd_ will still be None.
# Raise an error to inform the user to restart the kernel.
if pd_ is None:
raise RestartRequiredError(
"Pandas was installed after the session started. Please restart the session to "
+ "view the table in the Data Explorer."
)
return True
return False


def _is_polars(table):
return pl_ is not None and isinstance(table, (pl_.DataFrame, pl_.Series))
polars = import_polars()
if polars is not None and isinstance(table, (polars.DataFrame, polars.Series)):
# If polars was installed after the kernel was started, pl_ will still be None.
# Raise an error to inform the user to restart the kernel.
if pl_ is None:
raise RestartRequiredError(
"Polars was installed after the session started. Please restart the session to "
+ "view the table."
)
return True
return False


def _get_table_view(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .patch.holoviews import set_holoviews_extension
from .plots import PlotsService
from .session_mode import SessionMode
from .third_party import RestartRequiredError
from .ui import UiService
from .utils import BackgroundJobQueue, JsonRecord, get_qualname
from .variables import VariablesService
Expand Down Expand Up @@ -168,6 +169,8 @@ def view(self, line: str) -> None:
)
except TypeError:
raise UsageError(f"cannot view object of type '{get_qualname(obj)}'")
except RestartRequiredError as error:
raise UsageError(*error.args)

@magic_arguments.magic_arguments()
@magic_arguments.argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
from datetime import datetime
from decimal import Decimal
from io import StringIO
from typing import Any, Dict, List, Optional, Type, cast
from typing import Any, Dict, List, Optional, Type, Union, cast

import numpy as np
import pandas as pd
import polars as pl
import pytest
import pytz

from .. import data_explorer
from .._vendor.pydantic import BaseModel
from ..access_keys import encode_access_key
from ..data_explorer import (
Expand Down Expand Up @@ -49,6 +50,7 @@
RowFilterTypeSupportStatus,
SupportStatus,
)
from ..third_party import RestartRequiredError
from ..utils import guid
from .conftest import DummyComm, PositronShell
from .test_variables import BIG_ARRAY_LENGTH
Expand Down Expand Up @@ -295,6 +297,30 @@ def test_register_table_with_variable_path(de_service: DataExplorerService):
assert table_view.state.name == title


@pytest.mark.parametrize(
("table", "import_name", "title"),
[(pd.DataFrame({}), "pd_", "Pandas"), (pl.DataFrame({}), "pl_", "Polars")],
)
def test_register_table_after_installing_dependency(
table: Union[pd.DataFrame, pl.DataFrame],
import_name: str,
title: str,
de_service: DataExplorerService,
monkeypatch,
):
# Patch the module (e.g. third_party.pd_) to None. Since these packages are really is installed
# during tests, this simulates the case where the user installs the package after the kernel
# starts, therefore the third_party attribute (e.g. pd_) is None but the corresponding import
# function (third_party.import_pandas()) returns the module.
# See https://github.com/posit-dev/positron/issues/5535.
monkeypatch.setattr(data_explorer, import_name, None)

with pytest.raises(
RestartRequiredError, match=f"^{title} was installed after the session started."
):
de_service.register_table(table, "test_table")


def test_shutdown(de_service: DataExplorerService):
df = pd.DataFrame({"a": [1, 2, 3, 4, 5]})
de_service.register_table(df, "t1", comm_id=guid())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,47 +9,53 @@
# checking.


def _get_numpy():
class RestartRequiredError(Exception):
"""Raised when a restart is required to load a third party package."""

pass


def import_numpy():
try:
import numpy
except ImportError:
numpy = None
return numpy


def _get_pandas():
def import_pandas():
try:
import pandas
except ImportError:
pandas = None
return pandas


def _get_polars():
def import_polars():
try:
import polars
except ImportError:
polars = None
return polars


def _get_torch():
def import_torch():
try:
import torch # type: ignore [reportMissingImports] for 3.12
except ImportError:
torch = None
return torch


def _get_pyarrow():
def import_pyarrow():
try:
import pyarrow # type: ignore [reportMissingImports] for 3.12
except ImportError:
pyarrow = None
return pyarrow


def _get_sqlalchemy():
def import_sqlalchemy():
try:
import sqlalchemy
except ImportError:
Expand All @@ -59,11 +65,12 @@ def _get_sqlalchemy():

# Currently, pyright only correctly infers the types below as `Optional` if we set their values
# using function calls.
np_ = _get_numpy()
pa_ = _get_pyarrow()
pd_ = _get_pandas()
pl_ = _get_polars()
torch_ = _get_torch()
sqlalchemy_ = _get_sqlalchemy()
np_ = import_numpy()
pa_ = import_pyarrow()
pd_ = import_pandas()
pl_ = import_polars()
torch_ = import_torch()
sqlalchemy_ = import_sqlalchemy()


__all__ = ["np_", "pa_", "pd_", "pl_", "torch_", "sqlalchemy_"]
Loading