diff --git a/AUTHORS b/AUTHORS index 76e44ef1c..9f33ff551 100644 --- a/AUTHORS +++ b/AUTHORS @@ -133,6 +133,7 @@ Contributors: * Hollis Wu (holi0317) * Antonio Aguilar (crazybolillo) * Andrew M. MacFie (amacfie) + * saucoide Creator: -------- diff --git a/changelog.rst b/changelog.rst index e09a3251e..b915f98e5 100644 --- a/changelog.rst +++ b/changelog.rst @@ -8,6 +8,8 @@ Features: displaying of all Postgres error fields received. * Show Postgres notifications. * Support sqlparse 0.5.x +* Add `--log-file [filename]` cli argument and `\log-file [filename]` special commands to + log to an external file in addition to the normal output Bug fixes: ---------- diff --git a/pgcli/main.py b/pgcli/main.py index 452706697..056a9403f 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -13,6 +13,7 @@ import functools import datetime as dt import itertools +import pathlib import platform from time import time, sleep from typing import Optional @@ -182,6 +183,7 @@ def __init__( auto_vertical_output=False, warn=None, ssh_tunnel_url: Optional[str] = None, + log_file: Optional[str] = None, ): self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt @@ -310,6 +312,11 @@ def __init__( self.ssh_tunnel_url = ssh_tunnel_url self.ssh_tunnel = None + if log_file: + with open(log_file, "a+"): + pass # ensure writeable + self.log_file = log_file + # formatter setup self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) register_new_formatter(self.formatter) @@ -369,6 +376,12 @@ def register_special_commands(self): "\\o [filename]", "Send all query results to file.", ) + self.pgspecial.register( + self.write_to_logfile, + "\\log-file", + "\\log-file [filename]", + "Log all query results to a logfile, in addition to the normal output destination.", + ) self.pgspecial.register( self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" ) @@ -508,6 +521,26 @@ def execute_from_file(self, pattern, **_): explain_mode=self.explain_mode, ) + def write_to_logfile(self, pattern, **_): + if not pattern: + self.log_file = None + message = "Logfile capture disabled" + return [(None, None, None, message, "", True, True)] + + log_file = pathlib.Path(pattern).expanduser().absolute() + + try: + with open(log_file, "a+"): + pass # ensure writeable + except OSError as e: + self.log_file = None + message = str(e) + "\nLogfile capture disabled" + return [(None, None, None, message, "", False, True)] + + self.log_file = str(log_file) + message = 'Writing to file "%s"' % self.log_file + return [(None, None, None, message, "", True, True)] + def write_to_file(self, pattern, **_): if not pattern: self.output_file = None @@ -826,7 +859,7 @@ def execute_command(self, text, handle_closed_connection=True): else: try: if self.output_file and not text.startswith( - ("\\o ", "\\? ", "\\echo ") + ("\\o ", "\\log-file", "\\? ", "\\echo ") ): try: with open(self.output_file, "a", encoding="utf-8") as f: @@ -838,6 +871,23 @@ def execute_command(self, text, handle_closed_connection=True): else: if output: self.echo_via_pager("\n".join(output)) + + # Log to file in addition to normal output + if ( + self.log_file + and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo ")) + and not text.strip() == "" + ): + try: + with open(self.log_file, "a", encoding="utf-8") as f: + click.echo( + dt.datetime.now().isoformat(), file=f + ) # timestamp log + click.echo(text, file=f) + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline + except OSError as e: + click.secho(str(e), err=True, fg="red") except KeyboardInterrupt: pass @@ -1428,6 +1478,11 @@ def echo_via_pager(self, text, color=None): default=None, help="Open an SSH tunnel to the given address and connect to the database from it.", ) +@click.option( + "--log-file", + default=None, + help="Write all queries & output into a file, in addition to the normal output destination.", +) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) def cli( @@ -1453,6 +1508,7 @@ def cli( list_dsn, warn, ssh_tunnel: str, + log_file: str, ): if version: print("Version:", __version__) @@ -1511,6 +1567,7 @@ def cli( auto_vertical_output=auto_vertical_output, warn=warn, ssh_tunnel_url=ssh_tunnel, + log_file=log_file, ) # Choose which ever one has a valid value. diff --git a/tests/test_main.py b/tests/test_main.py index 4ff4e4cfa..3683d491f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,8 @@ import os import platform import re +import tempfile +import datetime from unittest import mock import pytest @@ -333,6 +335,34 @@ def test_qecho_works(executor): assert result == ["asdf"] +@dbtest +def test_logfile_works(executor): + with tempfile.TemporaryDirectory() as tmpdir: + log_file = f"{tmpdir}/tempfile.log" + cli = PGCli(pgexecute=executor, log_file=log_file) + statement = r"\qecho hello!" + cli.execute_command(statement) + with open(log_file, "r") as f: + log_contents = f.readlines() + assert datetime.datetime.fromisoformat(log_contents[0].strip()) + assert log_contents[1].strip() == r"\qecho hello!" + assert log_contents[2].strip() == "hello!" + + +@dbtest +def test_logfile_unwriteable_file(executor): + cli = PGCli(pgexecute=executor) + statement = r"\log-file forbidden.log" + with mock.patch("builtins.open") as mock_open: + mock_open.side_effect = PermissionError( + "[Errno 13] Permission denied: 'forbidden.log'" + ) + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == [ + "[Errno 13] Permission denied: 'forbidden.log'\nLogfile capture disabled" + ] + + @dbtest def test_watch_works(executor): cli = PGCli(pgexecute=executor)