diff --git a/AUTHORS b/AUTHORS index 922c241f1..76e44ef1c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -132,6 +132,7 @@ Contributors: * Sharon Yogev (sharonyogev) * Hollis Wu (holi0317) * Antonio Aguilar (crazybolillo) + * Andrew M. MacFie (amacfie) Creator: -------- diff --git a/changelog.rst b/changelog.rst index cb70bb8d6..d9c9d50ac 100644 --- a/changelog.rst +++ b/changelog.rst @@ -4,6 +4,7 @@ Upcoming Features: --------- * Support `PGAPPNAME` as an environment variable and `--application-name` as a command line argument. +* Show Postgres notifications Bug fixes: ---------- diff --git a/pgcli/main.py b/pgcli/main.py index cfa1c970a..bb9c9c8ec 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -73,7 +73,7 @@ from getpass import getuser -from psycopg import OperationalError, InterfaceError +from psycopg import OperationalError, InterfaceError, Notify from psycopg.conninfo import make_conninfo, conninfo_to_dict from collections import namedtuple @@ -128,6 +128,15 @@ class PgCliQuitError(Exception): pass +def notify_callback(notify: Notify): + click.secho( + 'Notification received on channel "{}" (PID {}):\n{}'.format( + notify.channel, notify.pid, notify.payload + ), + fg="green", + ) + + class PGCli: default_prompt = "\\u@\\h:\\d> " max_len_prompt = 30 @@ -660,7 +669,16 @@ def should_ask_for_password(exc): # prompt for a password (no -w flag), prompt for a passwd and try again. try: try: - pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) + pgexecute = PGExecute( + database, + user, + passwd, + host, + port, + dsn, + notify_callback, + **kwargs, + ) except (OperationalError, InterfaceError) as e: if should_ask_for_password(e): passwd = click.prompt( @@ -670,7 +688,14 @@ def should_ask_for_password(exc): type=str, ) pgexecute = PGExecute( - database, user, passwd, host, port, dsn, **kwargs + database, + user, + passwd, + host, + port, + dsn, + notify_callback, + **kwargs, ) else: raise e diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index f7eb6f865..e09175728 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -167,6 +167,7 @@ def __init__( host=None, port=None, dsn=None, + notify_callback=None, **kwargs, ): self._conn_params = {} @@ -179,6 +180,7 @@ def __init__( self.port = None self.server_version = None self.extra_args = None + self.notify_callback = notify_callback self.connect(database, user, password, host, port, dsn, **kwargs) self.reset_expanded = None @@ -237,6 +239,9 @@ def connect( self.conn = conn self.conn.autocommit = True + if self.notify_callback is not None: + self.conn.add_notify_handler(self.notify_callback) + # When we connect using a DSN, we don't really know what db, # user, etc. we connected to. Let's read it. # Note: moved this after setting autocommit because of #664. diff --git a/tests/conftest.py b/tests/conftest.py index 33cddf247..e50f1fe07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ db_connection, drop_tables, ) +import pgcli.main import pgcli.pgexecute @@ -37,6 +38,7 @@ def executor(connection): password=POSTGRES_PASSWORD, port=POSTGRES_PORT, dsn=None, + notify_callback=pgcli.main.notify_callback, ) diff --git a/tests/test_main.py b/tests/test_main.py index 0aeba80ea..de62263a9 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,5 +1,6 @@ import os import platform +import re from unittest import mock import pytest @@ -13,6 +14,7 @@ obfuscate_process_password, duration_in_words, format_output, + notify_callback, PGCli, OutputSettings, COLOR_CODE_REGEX, @@ -432,6 +434,7 @@ def test_pg_service_file(tmpdir): "b_host", "5435", "", + notify_callback, application_name="pgcli", ) del os.environ["PGPASSWORD"] @@ -487,7 +490,7 @@ def test_application_name_db_uri(tmpdir): cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri("postgres://bar@baz.com/?application_name=cow") mock_pgexecute.assert_called_with( - "bar", "bar", "", "baz.com", "", "", application_name="cow" + "bar", "bar", "", "baz.com", "", "", notify_callback, application_name="cow" ) @@ -514,3 +517,23 @@ def test_application_name_db_uri(tmpdir): ) def test_duration_in_words(duration_in_seconds, words): assert duration_in_words(duration_in_seconds) == words + + +@dbtest +def test_notifications(executor): + run(executor, "listen chan1") + + with mock.patch("pgcli.main.click.secho") as mock_secho: + run(executor, "notify chan1, 'testing1'") + mock_secho.assert_called() + arg = mock_secho.call_args_list[0].args[0] + assert re.match( + r'Notification received on channel "chan1" \(PID \d+\):\ntesting1', + arg, + ) + + run(executor, "unlisten chan1") + + with mock.patch("pgcli.main.click.secho") as mock_secho: + run(executor, "notify chan1, 'testing2'") + mock_secho.assert_not_called() diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py index ae865f4ab..983212b8a 100644 --- a/tests/test_ssh_tunnel.py +++ b/tests/test_ssh_tunnel.py @@ -6,7 +6,7 @@ from click.testing import CliRunner from sshtunnel import SSHTunnelForwarder -from pgcli.main import cli, PGCli +from pgcli.main import cli, notify_callback, PGCli from pgcli.pgexecute import PGExecute @@ -61,6 +61,7 @@ def test_ssh_tunnel( "127.0.0.1", pgcli.ssh_tunnel.local_bind_ports[0], "", + notify_callback, ) mock_ssh_tunnel_forwarder.reset_mock() mock_pgexecute.reset_mock() @@ -96,6 +97,7 @@ def test_ssh_tunnel( "127.0.0.1", pgcli.ssh_tunnel.local_bind_ports[0], "", + notify_callback, ) mock_ssh_tunnel_forwarder.reset_mock() mock_pgexecute.reset_mock()