Skip to content

Commit

Permalink
Add --log-file & \log-file option to always capture output (#1461)
Browse files Browse the repository at this point in the history
* Add --log-file & \log-file option to always capture output

Currently outputting to a file via \o disables the console output
This patch adds a `--log-file` cli arg with similar behavior to `psql`* and a \log-file
special command to enable/disable it from the console

*https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-OPTION-LOG-FILE

* switch to use context manager

Co-authored-by: Irina Truong <[email protected]>

* switch to use context manager

Co-authored-by: Irina Truong <[email protected]>

* use isoformat explicitly

Co-authored-by: Irina Truong <[email protected]>

* change test to use a mock, update changelog & authors

* reformat

* black

---------

Co-authored-by: Irina Truong <[email protected]>
  • Loading branch information
saucoide and j-bennet authored May 1, 2024
1 parent c6c5f04 commit ce7f76a
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 1 deletion.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ Contributors:
* Hollis Wu (holi0317)
* Antonio Aguilar (crazybolillo)
* Andrew M. MacFie (amacfie)
* saucoide

Creator:
--------
Expand Down
2 changes: 2 additions & 0 deletions changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Features:
displaying of all Postgres error fields received.
* Show Postgres notifications.
* Support sqlparse 0.5.x
* Add `--log-file [filename]` cli argument and `\log-file [filename]` special commands to
log to an external file in addition to the normal output

Bug fixes:
----------
Expand Down
59 changes: 58 additions & 1 deletion pgcli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import functools
import datetime as dt
import itertools
import pathlib
import platform
from time import time, sleep
from typing import Optional
Expand Down Expand Up @@ -182,6 +183,7 @@ def __init__(
auto_vertical_output=False,
warn=None,
ssh_tunnel_url: Optional[str] = None,
log_file: Optional[str] = None,
):
self.force_passwd_prompt = force_passwd_prompt
self.never_passwd_prompt = never_passwd_prompt
Expand Down Expand Up @@ -310,6 +312,11 @@ def __init__(
self.ssh_tunnel_url = ssh_tunnel_url
self.ssh_tunnel = None

if log_file:
with open(log_file, "a+"):
pass # ensure writeable
self.log_file = log_file

# formatter setup
self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
register_new_formatter(self.formatter)
Expand Down Expand Up @@ -369,6 +376,12 @@ def register_special_commands(self):
"\\o [filename]",
"Send all query results to file.",
)
self.pgspecial.register(
self.write_to_logfile,
"\\log-file",
"\\log-file [filename]",
"Log all query results to a logfile, in addition to the normal output destination.",
)
self.pgspecial.register(
self.info_connection, "\\conninfo", "\\conninfo", "Get connection details"
)
Expand Down Expand Up @@ -508,6 +521,26 @@ def execute_from_file(self, pattern, **_):
explain_mode=self.explain_mode,
)

def write_to_logfile(self, pattern, **_):
if not pattern:
self.log_file = None
message = "Logfile capture disabled"
return [(None, None, None, message, "", True, True)]

log_file = pathlib.Path(pattern).expanduser().absolute()

try:
with open(log_file, "a+"):
pass # ensure writeable
except OSError as e:
self.log_file = None
message = str(e) + "\nLogfile capture disabled"
return [(None, None, None, message, "", False, True)]

self.log_file = str(log_file)
message = 'Writing to file "%s"' % self.log_file
return [(None, None, None, message, "", True, True)]

def write_to_file(self, pattern, **_):
if not pattern:
self.output_file = None
Expand Down Expand Up @@ -826,7 +859,7 @@ def execute_command(self, text, handle_closed_connection=True):
else:
try:
if self.output_file and not text.startswith(
("\\o ", "\\? ", "\\echo ")
("\\o ", "\\log-file", "\\? ", "\\echo ")
):
try:
with open(self.output_file, "a", encoding="utf-8") as f:
Expand All @@ -838,6 +871,23 @@ def execute_command(self, text, handle_closed_connection=True):
else:
if output:
self.echo_via_pager("\n".join(output))

# Log to file in addition to normal output
if (
self.log_file
and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo "))
and not text.strip() == ""
):
try:
with open(self.log_file, "a", encoding="utf-8") as f:
click.echo(
dt.datetime.now().isoformat(), file=f
) # timestamp log
click.echo(text, file=f)
click.echo("\n".join(output), file=f)
click.echo("", file=f) # extra newline
except OSError as e:
click.secho(str(e), err=True, fg="red")
except KeyboardInterrupt:
pass

Expand Down Expand Up @@ -1428,6 +1478,11 @@ def echo_via_pager(self, text, color=None):
default=None,
help="Open an SSH tunnel to the given address and connect to the database from it.",
)
@click.option(
"--log-file",
default=None,
help="Write all queries & output into a file, in addition to the normal output destination.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli(
Expand All @@ -1453,6 +1508,7 @@ def cli(
list_dsn,
warn,
ssh_tunnel: str,
log_file: str,
):
if version:
print("Version:", __version__)
Expand Down Expand Up @@ -1511,6 +1567,7 @@ def cli(
auto_vertical_output=auto_vertical_output,
warn=warn,
ssh_tunnel_url=ssh_tunnel,
log_file=log_file,
)

# Choose which ever one has a valid value.
Expand Down
30 changes: 30 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import platform
import re
import tempfile
import datetime
from unittest import mock

import pytest
Expand Down Expand Up @@ -333,6 +335,34 @@ def test_qecho_works(executor):
assert result == ["asdf"]


@dbtest
def test_logfile_works(executor):
with tempfile.TemporaryDirectory() as tmpdir:
log_file = f"{tmpdir}/tempfile.log"
cli = PGCli(pgexecute=executor, log_file=log_file)
statement = r"\qecho hello!"
cli.execute_command(statement)
with open(log_file, "r") as f:
log_contents = f.readlines()
assert datetime.datetime.fromisoformat(log_contents[0].strip())
assert log_contents[1].strip() == r"\qecho hello!"
assert log_contents[2].strip() == "hello!"


@dbtest
def test_logfile_unwriteable_file(executor):
cli = PGCli(pgexecute=executor)
statement = r"\log-file forbidden.log"
with mock.patch("builtins.open") as mock_open:
mock_open.side_effect = PermissionError(
"[Errno 13] Permission denied: 'forbidden.log'"
)
result = run(executor, statement, pgspecial=cli.pgspecial)
assert result == [
"[Errno 13] Permission denied: 'forbidden.log'\nLogfile capture disabled"
]


@dbtest
def test_watch_works(executor):
cli = PGCli(pgexecute=executor)
Expand Down

0 comments on commit ce7f76a

Please sign in to comment.