Skip to content

Commit

Permalink
Merge pull request #790 from doronz88/bugfix/tunneld-bugs
Browse files Browse the repository at this point in the history
tunneld: refactor implementation so each new address gets its own task
  • Loading branch information
doronz88 authored Jan 18, 2024
2 parents fcabc8f + 6b713b2 commit 56d74e8
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 139 deletions.
12 changes: 10 additions & 2 deletions pymobiledevice3/cli/cli_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,19 @@ def __init__(self, *args, **kwargs):
self.params[:0] = [
click.Option(('verbosity', '-v', '--verbose'), count=True, callback=set_verbosity, expose_value=False),
]


class BaseServiceProviderCommand(BaseCommand):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.params[:0] = [
click.Option(('verbosity', '-v', '--verbose'), count=True, callback=set_verbosity, expose_value=False),
]
self.service_provider = None
self.callback = choose_service_provider(self.callback)


class LockdownCommand(BaseCommand):
class LockdownCommand(BaseServiceProviderCommand):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.usbmux_address = None
Expand Down Expand Up @@ -182,7 +190,7 @@ def udid(self, ctx, param: str, value: str) -> Optional[LockdownClient]:
[create_using_usbmux(serial=device.serial, usbmux_address=self.usbmux_address) for device in devices])


class RSDCommand(BaseCommand):
class RSDCommand(BaseServiceProviderCommand):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.params[:0] = [
Expand Down
10 changes: 5 additions & 5 deletions pymobiledevice3/cli/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import click

from pymobiledevice3.cli.cli_common import RSDCommand, print_json, prompt_device_list, sudo_required
from pymobiledevice3.cli.cli_common import BaseCommand, RSDCommand, print_json, prompt_device_list, sudo_required
from pymobiledevice3.common import get_home_folder
from pymobiledevice3.exceptions import NoDeviceConnectedError
from pymobiledevice3.pair_records import PAIRING_RECORD_EXT, get_remote_pairing_record_filename
Expand Down Expand Up @@ -46,7 +46,7 @@ def remote_cli():
pass


@remote_cli.command('tunneld')
@remote_cli.command('tunneld', cls=BaseCommand)
@click.option('--host', default=TUNNELD_DEFAULT_ADDRESS[0])
@click.option('--port', type=click.INT, default=TUNNELD_DEFAULT_ADDRESS[1])
@click.option('-d', '--daemonize', is_flag=True)
Expand All @@ -73,7 +73,7 @@ def cli_tunneld(host: str, port: int, daemonize: bool, protocol: str):
tunneld_runner()


@remote_cli.command('browse')
@remote_cli.command('browse', cls=BaseCommand)
@click.option('--color/--no-color', default=True)
def browse(color: bool):
""" browse devices using bonjour """
Expand Down Expand Up @@ -156,7 +156,7 @@ def select_device(udid: str) -> RemoteServiceDiscoveryService:
return rsd


@remote_cli.command('start-tunnel')
@remote_cli.command('start-tunnel', cls=BaseCommand)
@click.option('--udid', help='UDID for a specific device to look for')
@click.option('--secrets', type=click.File('wt'), help='TLS keyfile for decrypting with Wireshark')
@click.option('--script-mode', is_flag=True,
Expand All @@ -176,7 +176,7 @@ def cli_start_tunnel(udid: str, secrets: TextIO, script_mode: bool, max_idle_tim
debug=True)


@remote_cli.command('delete-pair')
@remote_cli.command('delete-pair', cls=BaseCommand)
@click.option('--udid', help='UDID for a specific device to delete the pairing record of')
@sudo_required
def cli_delete_pair(udid: str):
Expand Down
18 changes: 9 additions & 9 deletions pymobiledevice3/remote/bonjour.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
from socket import AF_INET6, inet_ntop
from typing import List

from ifaddr import Adapter, get_adapters
from ifaddr import get_adapters
from zeroconf import ServiceBrowser, ServiceListener, Zeroconf
from zeroconf.const import _TYPE_AAAA

DEFAULT_BONJOUR_TIMEOUT = 1


class RemotedListener(ServiceListener):
def __init__(self, adapter: Adapter):
def __init__(self, ip: str):
super().__init__()
self.adapter = adapter
self.ip = ip
self.addresses: List[str] = []

def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
Expand All @@ -22,7 +22,7 @@ def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
entries_with_name = zc.cache.async_entries_with_name(service_info.server)
for entry in entries_with_name:
if entry.type == _TYPE_AAAA:
self.addresses.append(inet_ntop(AF_INET6, entry.address) + '%' + self.adapter.nice_name)
self.addresses.append(inet_ntop(AF_INET6, entry.address) + '%' + self.ip.split('%')[1])

def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
pass
Expand All @@ -38,16 +38,16 @@ class BonjourQuery:
listener: RemotedListener


def query_bonjour(adapter: Adapter) -> BonjourQuery:
zc = Zeroconf(interfaces=[adapter.ips[0].ip[0]])
listener = RemotedListener(adapter)
def query_bonjour(ip: str) -> BonjourQuery:
zc = Zeroconf(interfaces=[ip])
listener = RemotedListener(ip)
service_browser = ServiceBrowser(zc, '_remoted._tcp.local.', listener)
return BonjourQuery(zc, service_browser, listener)


def get_remoted_addresses(timeout: int = DEFAULT_BONJOUR_TIMEOUT) -> List[str]:
adapters = [adapter for adapter in get_adapters() if adapter.ips[0].is_IPv6]
bonjour_queries = [query_bonjour(adapter) for adapter in adapters]
ips = [f'{adapter.ips[0].ip[0]}%{adapter.nice_name}' for adapter in get_adapters() if adapter.ips[0].is_IPv6]
bonjour_queries = [query_bonjour(adapter) for adapter in ips]
time.sleep(timeout)
addresses = []
for bonjour_query in bonjour_queries:
Expand Down
214 changes: 91 additions & 123 deletions pymobiledevice3/tunneld.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@
import logging
import os
import signal
import traceback
from contextlib import asynccontextmanager, suppress
from typing import Dict, Tuple
from typing import Dict, List, Optional, Tuple

import fastapi
import ifaddr.netifaces
import uvicorn
import zeroconf
from fastapi import FastAPI
from packaging import version
from zeroconf import IPVersion
from zeroconf.asyncio import AsyncZeroconf
from ifaddr import get_adapters

from pymobiledevice3.exceptions import InterfaceIndexNotFoundError
from pymobiledevice3.remote.bonjour import query_bonjour
from pymobiledevice3.remote.common import TunnelProtocol
from pymobiledevice3.remote.core_device_tunnel_service import TunnelResult
from pymobiledevice3.remote.module_imports import start_tunnel
from pymobiledevice3.remote.remote_service_discovery import RemoteServiceDiscoveryService
from pymobiledevice3.remote.remote_service_discovery import RSD_PORT, RemoteServiceDiscoveryService
from pymobiledevice3.remote.utils import stop_remoted

logger = logging.getLogger(__name__)
Expand All @@ -29,136 +27,107 @@


@dataclasses.dataclass
class Tunnel:
rsd: RemoteServiceDiscoveryService
task: asyncio.Task = None
address: Tuple[str, int] = UNINIT_ADDRESS
class TunnelTask:
task: asyncio.Task
udid: Optional[str] = None
tunnel: Optional[TunnelResult] = None


class TunneldCore:
def __init__(self, protocol: TunnelProtocol = TunnelProtocol.QUIC):
self.adapters: Dict[int, str] = {}
self.active_tunnels: Dict[int, Tunnel] = {}
self.protocol = protocol
self._type = '_remoted._tcp.local.'
self._name = 'ncm._remoted._tcp.local.'
self._interval = .5
self.tasks = []
self.tasks: List[asyncio.Task] = []
self.tunnel_tasks: Dict[str, TunnelTask] = {}

def start(self) -> None:
""" Register all tasks """
self.tasks = [
asyncio.create_task(self.update_adapters(), name='update_adapters'),
asyncio.create_task(self.remove_detached_devices(), name='remove_detached_devices'),
asyncio.create_task(self.discover_new_devices(), name='discover_new_devices'),
asyncio.create_task(self.monitor_adapters(), name='monitor_adapters'),
]

async def monitor_adapters(self):
previous_ips = []
while True:
current_ips = [f'{adapter.ips[0].ip[0]}%{adapter.nice_name}' for adapter in get_adapters() if
adapter.ips[0].is_IPv6]

added = [ip for ip in current_ips if ip not in previous_ips]
removed = [ip for ip in previous_ips if ip not in current_ips]

previous_ips = current_ips

logger.debug(f'added interfaces: {added}')
logger.debug(f'removed interfaces: {removed}')

for ip in removed:
if ip in self.tunnel_tasks:
self.tunnel_tasks[ip].task.cancel()
await self.tunnel_tasks[ip].task

for ip in added:
self.tunnel_tasks[ip] = TunnelTask(
task=asyncio.create_task(self.handle_new_ip(ip), name='handle_new_address'))

# wait before re-iterating
await asyncio.sleep(1)

async def handle_new_ip(self, ip: str):
tun = None
try:
# browse the adapter for CoreDevices
query = query_bonjour(ip)

# wait the response to arrive
await asyncio.sleep(1)

# validate a CoreDevice was indeed found
addresses = query.listener.addresses
if not addresses:
return
peer_address = addresses[0]

# establish an untrusted RSD handshake
rsd = RemoteServiceDiscoveryService((peer_address, RSD_PORT))
with stop_remoted():
try:
rsd.connect()
except ConnectionRefusedError:
return

# populate the udid from the untrusted RSD information
self.tunnel_tasks[ip].udid = rsd.udid

# establish a trusted tunnel
async with start_tunnel(rsd, protocol=self.protocol) as tun:
self.tunnel_tasks[ip].tunnel = tun
logger.info(f'Created tunnel --rsd {tun.address} {tun.port}')
await tun.client.wait_closed()

except asyncio.CancelledError:
pass
except Exception:
logger.error(traceback.format_exc())
finally:
if tun is not None:
logger.info(f'disconnected from tunnel --rsd {tun.address} {tun.port}')

if ip in self.tunnel_tasks:
# in case the tunnel was removed just now
self.tunnel_tasks.pop(ip)

async def close(self):
""" close all tasks """
for task in self.tasks:
for task in self.tasks + [tunnel_task.task for tunnel_task in self.tunnel_tasks.values()]:
task.cancel()
with suppress(asyncio.CancelledError):
await task

def clear(self) -> None:
""" Clear active tunnels """
for udid, tunnel in self.active_tunnels.items():
logger.info(f'Removing tunnel {tunnel.address}')
tunnel.rsd.close()
for udid, tunnel in self.tunnel_tasks.items():
logger.info(f'Removing tunnel {tunnel}')
tunnel.task.cancel()
self.active_tunnels = {}

async def handle_new_tunnel(self, tun: Tunnel) -> None:
""" Create new tunnel """
async with start_tunnel(tun.rsd, protocol=self.protocol) as tunnel_result:
tun.address = tunnel_result.address, tunnel_result.port
logger.info(f'Created tunnel --rsd {tun.address[0]} {tun.address[1]}')
await tunnel_result.client.wait_closed()

@staticmethod
async def connect_rsd(address: str, port: int) -> RemoteServiceDiscoveryService:
""" Connect to RSD """
with stop_remoted():
rsd = RemoteServiceDiscoveryService((address, port))
rsd.connect()
return rsd

async def update_adapters(self) -> None:
""" Constantly updates the 'adapters' dictionary with IPv6 addresses linked to network interfaces """
while True:
self.adapters = {iface.index: addr.ip[0] for iface in ifaddr.get_adapters() for addr in iface.ips if
addr.is_IPv6}
await asyncio.sleep(self._interval)

async def remove_detached_devices(self) -> None:
""" Continuously checks if adapters were removed and removes associated tunnels """
while True:
# Find active tunnels that are no longer associated with adapters
diff = list(set(self.active_tunnels.keys()) - set(self.adapters.keys()))
# For each detached tunnel, cancel its task, log the removal, and remove it from the active tunnels
for k in diff:
self.active_tunnels[k].task.cancel()
self.active_tunnels[k].rsd.close()
logger.info(f'Removing tunnel {self.active_tunnels[k].address}')
self.active_tunnels.pop(k)

await asyncio.sleep(self._interval)

def get_interface_index(self, address: str) -> int:
"""
To address the issue of an unknown IPv6 scope id for a device, we employ a workaround.
We maintain a mapping that associates the scope id with the adapter address.
To resolve this, we remove the last segment (quartet) from both the adapter address and the target address.
If there is a match, we retrieve the scope id associated with that adapter and use it.
Disclaimer: Matching addresses based on their segments may result in interface collision in specific network
configurations.
"""
address_segments = address.split(':')[:-1]
for k, v in self.adapters.items():
if address_segments != v.split(':')[:-1]:
continue
return k
raise InterfaceIndexNotFoundError(address=address)

async def discover_new_devices(self) -> None:
""" Continuously scans for devices advertising 'RSD' through IPv6 adapters """
while True:
# Search for devices advertising the specified service type and name
async with AsyncZeroconf(ip_version=IPVersion.V6Only) as aiozc:
try:
info = await aiozc.async_get_service_info(self._type, self._name, timeout=ZEROCONF_TIMEOUT)
except zeroconf.Error as e:
logger.warning(e)
continue
if info is None:
continue
# Extract device details
addr = info.parsed_addresses(IPVersion.V6Only)[0]
try:
interface_index = self.get_interface_index(addr)
except InterfaceIndexNotFoundError as e:
logger.warning(f'Failed to find interface index for {e.address}')
continue
if interface_index in self.active_tunnels:
continue
# Connect to the discovered device
addr = f'{addr}%{interface_index}'
try:
rsd = await self.connect_rsd(addr, info.port)
except (TimeoutError, ConnectionError, OSError):
logger.warning(f'Failed to connect rsd to {addr}')
continue
# Check unsupported devices with a product version below a minimum threshold
if version.parse(rsd.product_version) < version.parse(MIN_VERSION):
logger.warning(f'{rsd.udid} Unsupported device {rsd.product_version} < {MIN_VERSION}')
continue
logger.info(f'Creating tunnel for {addr}')
tunnel = Tunnel(rsd)
# Add the tunnel to the active tunnels and start a handling task
tunnel.task = asyncio.create_task(self.handle_new_tunnel(tunnel))
self.active_tunnels[interface_index] = tunnel
await asyncio.sleep(self._interval)
self.tunnel_tasks = {}


class TunneldRunner:
Expand All @@ -184,13 +153,12 @@ async def lifespan(app: FastAPI):
self._tunneld_core = TunneldCore(protocol)

@self._app.get('/')
async def list_tunnels() -> Dict[str, Tuple[str, int]]:
async def list_tunnels() -> Dict[str, Tuple]:
""" Retrieve the available tunnels and format them as {UUID: TUNNEL_ADDRESS} """
tunnels = {}
for k, v in self._tunneld_core.active_tunnels.items():
if v.address == UNINIT_ADDRESS:
continue
tunnels[v.rsd.udid] = v.address
for ip, active_tunnel in self._tunneld_core.tunnel_tasks.items():
if (active_tunnel.udid is not None) and (active_tunnel.tunnel is not None):
tunnels[active_tunnel.udid] = (active_tunnel.tunnel.address, active_tunnel.tunnel.port)
return tunnels

@self._app.get('/shutdown')
Expand Down

0 comments on commit 56d74e8

Please sign in to comment.