Skip to content

Commit

Permalink
pyfmt
Browse files Browse the repository at this point in the history
  • Loading branch information
markokr committed Oct 24, 2023
1 parent 0b846b6 commit 4a435ed
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 50 deletions.
9 changes: 5 additions & 4 deletions skytools/adminscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import inspect
import sys
from typing import Sequence, Optional, Any, Mapping, Callable
from typing import Any, Callable, Mapping, Optional, Sequence

import skytools

Expand Down Expand Up @@ -53,8 +53,8 @@ def work(self) -> Optional[int]:

# check if correct number of arguments
(
args, varargs, ___varkw, ___defaults,
___kwonlyargs, __kwonlydefaults, ___annotations,
args, varargs, ___varkw, ___defaults,
___kwonlyargs, __kwonlydefaults, ___annotations,
) = inspect.getfullargspec(fn)
n_args = len(args) - 1 # drop 'self'
if varargs is None and n_args != len(cmdargs):
Expand Down Expand Up @@ -82,7 +82,8 @@ def fetch_list(self, db: Connection, sql: str, args: ExecuteParams, keycol: Opti
res = [r[keycol] for r in rows]
return res

def display_table(self, db: Connection, desc: str, sql: str, args: ExecuteParams = (), fields: Sequence[str] = (), fieldfmt: Optional[Mapping[str, Callable[[Any], str]]]=None) -> int:
def display_table(self, db: Connection, desc: str, sql: str, args: ExecuteParams = (),
fields: Sequence[str] = (), fieldfmt: Optional[Mapping[str, Callable[[Any], str]]] = None) -> int:
"""Display multirow query as a table."""

self.log.debug("display_table: %s", skytools.quote_statement(sql, args))
Expand Down
2 changes: 1 addition & 1 deletion skytools/apipkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import os
import sys
from typing import List
from types import ModuleType
from typing import List

__version__ = "1.5"

Expand Down
19 changes: 11 additions & 8 deletions skytools/basetypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

import abc
import io
import typing
import types

import typing
from typing import (
IO, Any, Mapping, Optional, Sequence, Tuple, Type, Union,
KeysView, ValuesView, ItemsView, Iterator,
IO, Any, ItemsView, Iterator, KeysView, Mapping,
Optional, Sequence, Tuple, Type, Union, ValuesView,
)

try:
Expand Down Expand Up @@ -51,7 +50,8 @@ def execute(self, sql: str, params: Optional[ExecuteParams] = None) -> None: rai
def fetchall(self) -> Sequence[DictRow]: raise NotImplementedError
def fetchone(self) -> DictRow: raise NotImplementedError
def __enter__(self) -> "Cursor": raise NotImplementedError
def __exit__(self, typ: Optional[Type[BaseException]], exc: Optional[BaseException], tb: Optional[types.TracebackType]) -> None:
def __exit__(self, typ: Optional[Type[BaseException]], exc: Optional[BaseException],
tb: Optional[types.TracebackType]) -> None:
raise NotImplementedError
def copy_expert(
self, sql: str,
Expand All @@ -61,7 +61,8 @@ def copy_expert(
raise NotImplementedError
def fileno(self) -> int: raise NotImplementedError
@property
def description(self) -> Sequence[Tuple[str, int, int, int, Optional[int], Optional[int], None]]: raise NotImplementedError
def description(self) -> Sequence[Tuple[str, int, int, int, Optional[int],
Optional[int], None]]: raise NotImplementedError
@property
def connection(self) -> "Connection": raise NotImplementedError

Expand All @@ -80,9 +81,11 @@ def set_client_encoding(self, encoding: str) -> None: raise NotImplementedError
@property
def server_version(self) -> int: raise NotImplementedError
def __enter__(self) -> "Connection": raise NotImplementedError
def __exit__(self, typ: Optional[Type[BaseException]], exc: Optional[BaseException], tb: Optional[types.TracebackType]) -> None:
def __exit__(self, typ: Optional[Type[BaseException]], exc: Optional[BaseException],
tb: Optional[types.TracebackType]) -> None:
raise NotImplementedError


class Runnable(Protocol):
def run(self) -> None: raise NotImplementedError

Expand All @@ -98,7 +101,7 @@ def fileno(self) -> int: raise NotImplementedError
from typing_extensions import Buffer # type: ignore
except ImportError:
if typing.TYPE_CHECKING:
from _typeshed import Buffer # type: ignore
from _typeshed import Buffer # type: ignore
else:
try:
from collections.abc import Buffer # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion skytools/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
import sys
import time
from typing import IO, List, Optional, Sequence, Tuple, Dict, cast, Mapping, Any
from typing import IO, Any, Dict, List, Mapping, Optional, Sequence, Tuple, cast

import skytools

Expand Down
6 changes: 4 additions & 2 deletions skytools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ExtendedInterpolation, Interpolation, InterpolationDepthError,
InterpolationError, NoOptionError, NoSectionError, RawConfigParser,
)
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, MutableMapping, Set
from typing import Dict, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple

import skytools

Expand Down Expand Up @@ -280,6 +280,7 @@ def items(self) -> Sequence[Tuple[str, str]]:
#ParserState = ConfigParser
ParserLoop = Set[Tuple[str, str]]


class ExtendedInterpolationCompat(Interpolation):
_EXT_VAR_RX = r'\$\$|\$\{[^(){}]+\}'
_OLD_VAR_RX = r'%%|%\([^(){}]+\)s'
Expand All @@ -297,7 +298,8 @@ def before_set(self, parser: ParserState, section: str, option: str, value: str)
raise ValueError("invalid interpolation syntax in %r" % value)
return value

def _interpolate_ext(self, dst: List[str], parser: ParserState, section: str, option: str, rawval: str, defaults: ParserSection, loop_detect: ParserLoop) -> None:
def _interpolate_ext(self, dst: List[str], parser: ParserState, section: str, option: str,
rawval: str, defaults: ParserSection, loop_detect: ParserLoop) -> None:
if not rawval:
return

Expand Down
16 changes: 10 additions & 6 deletions skytools/dbservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""

import logging
from typing import List, Optional, Sequence, Any, Dict, Union, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import skytools
from skytools import dbdict
Expand Down Expand Up @@ -202,7 +202,8 @@ def __init__(self, context: str, global_dict: Optional[Dict[str, Any]] = None) -

# error and message handling

def tell_user(self, severity: str, code: str, message: str, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> None:
def tell_user(self, severity: str, code: str, message: str,
params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> None:
""" Adds another message to the set of messages to be sent back to user
If error message then can_save is set false
If fatal message then error or found errors are raised at once
Expand All @@ -229,7 +230,8 @@ def raise_if_errors(self) -> None:

# run sql meant mostly for select but not limited to

def create_query(self, sql: str, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> skytools.PLPyQueryBuilder:
def create_query(self, sql: str, params: Optional[Dict[str, Any]]
= None, **kvargs: Any) -> skytools.PLPyQueryBuilder:
""" Returns initialized querybuilder object for building complex dynamic queries
"""
params = params or kvargs
Expand Down Expand Up @@ -292,13 +294,15 @@ def return_next(self, rows: List[dbdict], res_name: str, severity: Optional[str]
self.tell_user(severity, "dbsXXXX", "No matching records found")
return rows

def return_next_sql(self, sql: str, params: Optional[Dict[str, Any]], res_name: str, severity: Optional[str] = None) -> List[dbdict]:
def return_next_sql(self, sql: str, params: Optional[Dict[str, Any]],
res_name: str, severity: Optional[str] = None) -> List[dbdict]:
""" Exectes query and adds recors resultset
"""
rows = self.run_query(sql, params)
return self.return_next(rows, res_name, severity)

def retval(self, service_name: Optional[str] = None, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> List[Tuple[str, str, str]]:
def retval(self, service_name: Optional[str] = None, params: Optional[Dict[str, Any]]
= None, **kvargs: Any) -> List[Tuple[str, str, str]]:
""" Return collected resultsets and append to the end messages to the users
Method is called usually as last statement in dbservice to return the results
Also converts results into desired format
Expand All @@ -308,7 +312,7 @@ def retval(self, service_name: Optional[str] = None, params: Optional[Dict[str,
if len(self.messages):
self.return_next(self.messages, "_status") # type: ignore
if self.sqls is not None and len(self.sqls):
self.return_next(self.sqls, "_sql") # type: ignore
self.return_next(self.sqls, "_sql") # type: ignore
results: List[Tuple[str, str, str]] = []
for r in self._retval:
res_name = r[0]
Expand Down
10 changes: 6 additions & 4 deletions skytools/dbstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# pylint:disable=arguments-renamed

import re
from typing import List, Optional, Type, Tuple, TypeVar, Any
from logging import Logger
from typing import Any, List, Optional, Tuple, Type, TypeVar

import skytools
from skytools import quote_fqident, quote_ident

from skytools.basetypes import Cursor, DictRow

__all__ = (
Expand Down Expand Up @@ -40,6 +39,7 @@
# Utility functions
#


def find_new_name(curs: Optional[Cursor], name: str) -> str:
"""Create new object name for case the old exists.
Expand Down Expand Up @@ -533,7 +533,8 @@ class TTable(TElem):
col_list: List[TColumn]
dist_key_list: Optional[List[TGPDistKey]]

def __init__(self, table_name: str, col_list: List[TColumn], dist_key_list: Optional[List[TGPDistKey]] = None) -> None:
def __init__(self, table_name: str, col_list: List[TColumn],
dist_key_list: Optional[List[TGPDistKey]] = None) -> None:
self.name = table_name
self.col_list = col_list
self.dist_key_list = dist_key_list
Expand Down Expand Up @@ -668,7 +669,8 @@ def _load_elem(self, curs: Cursor, name: str, args: Any, eclass: Type[T]) -> Lis
elem_list.append(eclass(name, row))
return elem_list

def create(self, curs: Cursor, objs: int, new_table_name: Optional[str] = None, log: Optional[Logger] = None) -> None:
def create(self, curs: Cursor, objs: int,
new_table_name: Optional[str] = None, log: Optional[Logger] = None) -> None:
"""Issues CREATE statements for requested set of objects.
If new_table_name is giver, creates table under that name
Expand Down
6 changes: 3 additions & 3 deletions skytools/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""

import re
from typing import Iterator, List, Optional, Sequence, Tuple, Dict, Union
from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union

import skytools

Expand Down Expand Up @@ -142,7 +142,7 @@ def _create_dbdict(self, fields: List[str], values: List[str]) -> skytools.dbdic
return skytools.dbdict(zip(fields2, values2))

def parse_sql(self, op: str, sql: str, pklist: Optional[Sequence[str]] = None, splitkeys: bool = False
) -> Union[skytools.dbdict, Tuple[skytools.dbdict, skytools.dbdict]]:
) -> Union[skytools.dbdict, Tuple[skytools.dbdict, skytools.dbdict]]:
"""Main entry point."""
if pklist is None:
self.pklist = []
Expand Down Expand Up @@ -255,7 +255,7 @@ def parse_tabbed_table(txt: str) -> List[Dict[str, str]]:
def sql_tokenizer(
sql: str, standard_quoting: bool = False, ignore_whitespace: bool = False,
fqident: bool = False, show_location: bool = False
) -> Iterator[Union[Tuple[str, str], Tuple[str, str, int]]]:
) -> Iterator[Union[Tuple[str, str], Tuple[str, str, int]]]:
r"""Parser SQL to tokens.
Iterator, returns (toktype, tokstr) tuples.
Expand Down
4 changes: 2 additions & 2 deletions skytools/plpy_applyrow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""PLPY helper module for applying row events from pgq.logutriga().
"""

from typing import Sequence, Optional
from typing import Optional, Sequence

import skytools

Expand All @@ -11,7 +11,7 @@
pass


## TODO: automatic fkey detection
# TODO: automatic fkey detection
# find FK columns
FK_SQL = """
SELECT (SELECT array_agg( (SELECT attname::text FROM pg_attribute
Expand Down
5 changes: 3 additions & 2 deletions skytools/querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import json
import re
from functools import lru_cache
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, Tuple, cast
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast

import skytools

Expand Down Expand Up @@ -141,7 +141,8 @@ def get_sql(self, param_type: int = PARAM_INLINE) -> str:
tmp = [str(part) for part in self._sql_parts]
return "".join(tmp)

def _add_expr(self, pfx: str, expr: str, params: Optional[Mapping[str, Any]], sql_type: str, required: bool) -> None:
def _add_expr(self, pfx: str, expr: str,
params: Optional[Mapping[str, Any]], sql_type: str, required: bool) -> None:
parts: List[Union[str, QArg]] = []
types: List[str] = []
values: List[Any] = []
Expand Down
28 changes: 16 additions & 12 deletions skytools/scripting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
import signal
import sys
import time
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, Callable, Type, cast
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast

import skytools
import skytools.skylog

from .basetypes import Connection, Runnable, Cursor, DictRow, ExecuteParams
from .basetypes import Connection, Cursor, DictRow, ExecuteParams, Runnable

try:
import skytools.installer_config
Expand Down Expand Up @@ -761,7 +761,7 @@ def shutdown(self) -> None:


##
## DBScript
# DBScript
##

#: how old connections need to be closed
Expand Down Expand Up @@ -819,8 +819,8 @@ def add_connect_string_profile(self, connstr: str, profile: Optional[str]) -> st
return connstr

def get_database(self, dbname: str, autocommit: int = 0, isolation_level: int = -1,
cache: Optional[str] = None, connstr: Optional[str] = None,
profile: Optional[str] = None) -> Connection:
cache: Optional[str] = None, connstr: Optional[str] = None,
profile: Optional[str] = None) -> Connection:
"""Load cached database connection.
User must not store it permanently somewhere,
Expand Down Expand Up @@ -948,7 +948,8 @@ def sleep(self, secs: float) -> None:
self.log.info('wait canceled')
return None

def _exec_cmd(self, curs: Cursor, sql: str, args: ExecuteParams, quiet: bool = False, prefix: Optional[str] = None) -> Tuple[bool, Sequence[DictRow]]:
def _exec_cmd(self, curs: Cursor, sql: str, args: ExecuteParams, quiet: bool = False,
prefix: Optional[str] = None) -> Tuple[bool, Sequence[DictRow]]:
"""Internal tool: Run SQL on cursor."""
if self.options.verbose:
self.log.debug("exec_cmd: %s", skytools.quote_statement(sql, args))
Expand Down Expand Up @@ -984,7 +985,8 @@ def _exec_cmd(self, curs: Cursor, sql: str, args: ExecuteParams, quiet: bool = F
ok = False
return (ok, rows)

def _exec_cmd_many(self, curs: Cursor, sql: str, baseargs: List[Any], extra_list: Sequence[Any], quiet:bool=False, prefix:Optional[str]=None) -> Tuple[bool, Sequence[DictRow]]:
def _exec_cmd_many(self, curs: Cursor, sql: str, baseargs: List[Any], extra_list: Sequence[Any],
quiet: bool = False, prefix: Optional[str] = None) -> Tuple[bool, Sequence[DictRow]]:
"""Internal tool: Run SQL on cursor multiple times."""
ok = True
rows: List[DictRow] = []
Expand All @@ -996,7 +998,7 @@ def _exec_cmd_many(self, curs: Cursor, sql: str, baseargs: List[Any], extra_list
return (ok, rows)

def exec_cmd(self, db_or_curs: Union[Connection, Cursor], q: str, args: ExecuteParams,
commit: bool = True, quiet: bool = False, prefix: Optional[str] = None) -> Sequence[DictRow]:
commit: bool = True, quiet: bool = False, prefix: Optional[str] = None) -> Sequence[DictRow]:
"""Run SQL on db with code/value error handling."""

db: Optional[Connection]
Expand All @@ -1022,7 +1024,7 @@ def exec_cmd(self, db_or_curs: Union[Connection, Cursor], q: str, args: ExecuteP
sys.exit(1)

def exec_cmd_many(self, db_or_curs: Union[Connection, Cursor], sql: str, baseargs: List[Any], extra_list: Sequence[Any],
commit: bool = True, quiet: bool = False, prefix: Optional[str] = None) -> Sequence[DictRow]:
commit: bool = True, quiet: bool = False, prefix: Optional[str] = None) -> Sequence[DictRow]:
"""Run SQL on db multiple times."""
if hasattr(db_or_curs, 'cursor'):
db = cast(Connection, db_or_curs)
Expand All @@ -1043,7 +1045,8 @@ def exec_cmd_many(self, db_or_curs: Union[Connection, Cursor], sql: str, basearg
# error is already logged
sys.exit(1)

def execute_with_retry(self, dbname: str, stmt: str, args: List[Any], exceptions: Optional[Sequence[Type[Exception]]] = None) -> Tuple[int, Cursor]:
def execute_with_retry(self, dbname: str, stmt: str,
args: List[Any], exceptions: Optional[Sequence[Type[Exception]]] = None) -> Tuple[int, Cursor]:
""" Execute SQL and retry if it fails.
Return number of retries and current valid cursor, or raise an exception.
"""
Expand Down Expand Up @@ -1133,7 +1136,8 @@ class DBCachedConn:
setup_func: Optional[SetupFunc]
listen_channel_list: Sequence[str]

def __init__(self, name: str, loc: str, max_age:int=DEF_CONN_AGE, verbose:bool=False, setup_func: Optional[SetupFunc] = None, channels: Sequence[str] = ()) -> None:
def __init__(self, name: str, loc: str, max_age: int = DEF_CONN_AGE, verbose: bool = False,
setup_func: Optional[SetupFunc] = None, channels: Sequence[str] = ()) -> None:
self.name = name
self.loc = loc
self.conn = None
Expand All @@ -1149,7 +1153,7 @@ def fileno(self) -> Optional[int]:
return None
return self.conn.cursor().fileno()

def get_connection(self, isolation_level:int=-1, listen_channel_list: Sequence[str]=()) -> Connection:
def get_connection(self, isolation_level: int = -1, listen_channel_list: Sequence[str] = ()) -> Connection:

# default isolation_level is READ COMMITTED
if isolation_level < 0:
Expand Down
Loading

0 comments on commit 4a435ed

Please sign in to comment.