Skip to content

Commit

Permalink
[resh][feat] Allow defining http headers for the shell via cmd line
Browse files Browse the repository at this point in the history
  • Loading branch information
aquamatthias committed Nov 22, 2023
1 parent edbe7a9 commit b9c57ac
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
10 changes: 5 additions & 5 deletions resotolib/resotolib/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from resotolib.logger import log
from resotolib.args import ArgumentParser
from urllib.parse import urlparse, ParseResult
from typing import Optional
from typing import Optional, Dict


class CLIEnvelope:
Expand Down Expand Up @@ -33,27 +33,27 @@ def add_args(arg_parser: ArgumentParser) -> None:
)


def resotocore_is_up(resotocore_uri: str, timeout: int = 5) -> bool:
def resotocore_is_up(resotocore_uri: str, timeout: int = 5, headers: Optional[Dict[str, str]] = None) -> bool:
ready_uri = f"{resotocore_uri}/system/ready"
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
response = requests.get(ready_uri, timeout=timeout, verify=False)
response = requests.get(ready_uri, timeout=timeout, verify=False, headers=headers)
if response.status_code == 200:
return True
except Exception:
pass
return False


def wait_for_resotocore(resotocore_uri: str, timeout: int = 300) -> None:
def wait_for_resotocore(resotocore_uri: str, timeout: int = 300, headers: Optional[Dict[str, str]] = None) -> None:
start_time = time.time()
core_up = False
wait_time: float = -1
remaining_wait: float = timeout
waitlog = log.info
while wait_time < timeout:
if resotocore_is_up(resotocore_uri):
if resotocore_is_up(resotocore_uri, headers=headers):
core_up = True
break
else:
Expand Down
15 changes: 13 additions & 2 deletions resotoshell/resotoshell/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from signal import SIGTERM
from threading import Event
from typing import Tuple

from prompt_toolkit.formatted_text import FormattedText
from resotoclient.async_client import ResotoClient
Expand All @@ -31,9 +32,10 @@ async def main_async() -> None:
jwt_add_args(arg_parser)
TLSData.add_args(arg_parser, ca_only=True)
args: Namespace = arg_parser.parse_args()
headers = dict(args.add_headers)

try:
wait_for_resotocore(resotocore.http_uri, timeout=args.resotocore_wait)
wait_for_resotocore(resotocore.http_uri, timeout=args.resotocore_wait, headers=headers)
except TimeoutError:
log.fatal(f"resotocore is not online at {resotocore.http_uri}")
sys.exit(1)
Expand Down Expand Up @@ -61,7 +63,7 @@ async def check_system_info() -> None:
cmds, kinds, props = await core_metadata(client)
history = ResotoHistory.default()
session = PromptSession(cmds=cmds, kinds=kinds, props=props, history=history)
shell = Shell(client, True, detect_color_system(args), history=history)
shell = Shell(client, True, detect_color_system(args), history=history, additional_headers=headers)
await repl(shell, session, args)

# update the eventually changed auth token
Expand Down Expand Up @@ -156,6 +158,12 @@ def detect_color_system(args: Namespace) -> str:


def add_args(arg_parser: ArgumentParser) -> None:
def header_value(s: str) -> Tuple[str, str]:
if ":" not in s:
raise ValueError("Header must be in the format key:value")
k, v = s.split(":", 1)
return k, v

arg_parser.add_argument(
"--resotocore-section",
help="All queries are interpreted with this section name. If not set, the server default is used.",
Expand Down Expand Up @@ -200,6 +208,9 @@ def add_args(arg_parser: ArgumentParser) -> None:
action="store_true",
default=False,
)
arg_parser.add_argument(
"--add-headers", help="Add a header to all requests. Format: key:value", nargs="*", type=header_value
)


def main() -> None:
Expand Down
13 changes: 10 additions & 3 deletions resotoshell/resotoshell/authorized_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ class AccessDeniedError(Exception):


async def new_client(args: Namespace) -> ResotoClient:
headers = dict(args.add_headers)
# if a PSK was defined on the command line, use it
if args.psk:
return ResotoClient(
url=resotocore.http_uri,
psk=args.psk,
custom_ca_cert_path=args.ca_cert,
verify=args.verify_certs,
additional_headers=headers,
)

# fetch ssl certificate
Expand All @@ -43,7 +45,12 @@ async def new_client(args: Namespace) -> ResotoClient:
try:
await fetch_auth_header(resotocore.http_uri, ssl=ssl)
# no authorization required
return ResotoClient(url=resotocore.http_uri, custom_ca_cert_path=args.ca_cert, verify=args.verify_certs)
return ResotoClient(
url=resotocore.http_uri,
custom_ca_cert_path=args.ca_cert,
verify=args.verify_certs,
additional_headers=headers,
)
except AccessDeniedError:
config = ReshConfig.default()
if creds := config.valid_credentials(resotocore.http_uri):
Expand All @@ -53,7 +60,7 @@ async def new_client(args: Namespace) -> ResotoClient:
url=resotocore.http_uri,
custom_ca_cert_path=args.ca_cert,
verify=args.verify_certs,
additional_headers={"Authorization": f"{method} {auth_token}"},
additional_headers={**headers, "Authorization": f"{method} {auth_token}"},
)
else:
# No valid credentials found in config file. Start authorization flow
Expand All @@ -70,7 +77,7 @@ async def new_client(args: Namespace) -> ResotoClient:
url=resotocore.http_uri,
custom_ca_cert_path=args.ca_cert,
verify=args.verify_certs,
additional_headers={"Authorization": f"{method} {token}"},
additional_headers={**headers, "Authorization": f"{method} {token}"},
)


Expand Down
4 changes: 3 additions & 1 deletion resotoshell/resotoshell/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
graph: Optional[str] = None,
section: Optional[str] = None,
history: Optional[ResotoHistory] = None,
additional_headers: Optional[Dict[str, str]] = None,
):
self.client = client
self.history = history
Expand All @@ -55,6 +56,7 @@ def __init__(
self.color_depth = color_system_to_color_depth.get(color_system) or ColorDepth.DEPTH_8_BIT
self.graph = graph
self.section = section
self.additional_headers = additional_headers or {}

async def handle_command(
self,
Expand All @@ -63,7 +65,7 @@ async def handle_command(
files: Optional[Dict[str, str]] = None,
no_history: bool = False,
) -> None:
headers: Dict[str, str] = {}
headers: Dict[str, str] = self.additional_headers.copy()
headers.update({"Accept": "text/plain"})
headers.update(additional_headers or {})

Expand Down

0 comments on commit b9c57ac

Please sign in to comment.