Skip to content

Commit

Permalink
refactoring: add typing hints
Browse files Browse the repository at this point in the history
  • Loading branch information
xyb committed Sep 20, 2024
1 parent 634dd56 commit 89ed4d7
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 88 deletions.
79 changes: 55 additions & 24 deletions task/baidupcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from pathlib import Path
from pathlib import PurePosixPath
from time import sleep
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional

from baidupcs_py.baidupcs import BaiduPCSApi
from baidupcs_py.baidupcs import BaiduPCSError
Expand All @@ -27,20 +32,30 @@ class CaptchaRequired(ValueError):
pass


def get_baidupcs_client():
def get_baidupcs_client() -> "BaiduPCSClient":
return BaiduPCSClient(
settings.PAN_BAIDU_BDUSS,
cookies2dict(settings.PAN_BAIDU_COOKIES),
)


class BaiduPCSClient:
def __init__(self, bduss, cookies, api=None):
def __init__(
self,
bduss: str,
cookies: Dict[str, str],
api: Optional[BaiduPCSApi] = None,
):
self.bduss = bduss
self.cookies = cookies
self.api = api if api else BaiduPCSApi(bduss=bduss, cookies=cookies)

def list_files(self, remote_dir, retry=3, fail_silent=False):
def list_files(
self,
remote_dir: str,
retry: int = 3,
fail_silent: bool = False,
) -> List[Dict[str, Any]]:
while True:
try:
files = self.api.list(remote_dir, recursive=True)
Expand Down Expand Up @@ -72,13 +87,13 @@ def list_files(self, remote_dir, retry=3, fail_silent=False):

def save_shared_link(
self,
remote_dir,
link,
password=None,
callback_save_captcha=None,
remote_dir: str,
link: str,
password: Optional[str] = None,
callback_save_captcha: Optional[Callable] = None,
captcha_id: str = "",
captcha_code: str = "",
):
) -> None:
save_shared(
self,
link,
Expand All @@ -89,7 +104,12 @@ def save_shared_link(
captcha_code=captcha_code,
)

def download_dir(self, remote_dir, local_dir, sample_size=0):
def download_dir(
self,
remote_dir: str,
local_dir: str,
sample_size: int = 0,
) -> None:
for file in self.list_files(remote_dir):
if not file["is_file"]:
continue
Expand All @@ -99,7 +119,13 @@ def download_dir(self, remote_dir, local_dir, sample_size=0):
file_size = file["size"]
self.download_file(remote_path, local_dir_, file_size, sample_size)

def download_file(self, remote_path, local_dir, file_size, sample_size=0):
def download_file(
self,
remote_path: str,
local_dir: str,
file_size: int,
sample_size: int = 0,
) -> Optional[int]:
local_path = Path(local_dir) / basename(remote_path)
logger.info(f" {remote_path} -> {local_path}")
if match_regex(str(remote_path), settings.IGNORE_PATH_RE):
Expand Down Expand Up @@ -130,17 +156,22 @@ def download_file(self, remote_path, local_dir, file_size, sample_size=0):
total = download_url(local_path, url, headers, limit=sample_size)
return total

def leech(self, remote_dir, local_dir, sample_size=0):
def leech(self, remote_dir: str, local_dir: Path, sample_size: int = 0) -> None:
if not local_dir.exists():
makedirs(local_dir, exist_ok=True)

self.download_dir(remote_dir, local_dir, sample_size=sample_size)

def delete(self, remote_dir):
def delete(self, remote_dir: str) -> None:
self.api.remove(remote_dir)


def remotepath_exists(api, name: str, rd: str, _cache={}) -> bool:
def remotepath_exists(
api: BaiduPCSApi,
name: str,
rd: str,
_cache: Dict[str, set] = {},
) -> bool:
names = _cache.get(rd)
if not names:
names = {PurePosixPath(sp.path).name for sp in api.list(rd)}
Expand All @@ -149,14 +180,14 @@ def remotepath_exists(api, name: str, rd: str, _cache={}) -> bool:


def save_shared(
client,
shared_url,
remotedir,
password=None,
callback_save_captcha=None,
client: BaiduPCSClient,
shared_url: str,
remotedir: str,
password: Optional[str] = None,
callback_save_captcha: Optional[Callable] = None,
captcha_id: str = "",
captcha_code: str = "",
):
) -> None:
assert remotedir.startswith("/"), "`remotedir` must be an absolute path"

shared_url = unify_shared_link(shared_url)
Expand Down Expand Up @@ -267,12 +298,12 @@ def save_shared(


def list_all_sub_paths(
api,
api: BaiduPCSApi,
sharedpath: str,
uk: int,
share_id: int,
bdstoken: str,
):
) -> List[Any]:
sub_paths = []
page = 1
size = 100
Expand All @@ -293,13 +324,13 @@ def list_all_sub_paths(


def access_shared(
client,
client: BaiduPCSClient,
shared_url: str,
password: str,
callback_save_captcha=None,
callback_save_captcha: Optional[Callable] = None,
captcha_id: str = "",
captcha_code: str = "",
):
) -> None:
try:
client.api._baidupcs.access_shared(
shared_url,
Expand Down
5 changes: 3 additions & 2 deletions task/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import requests

from .models import Task
from .serializers import TaskSerializer
from .utils import handle_exception

logger = logging.getLogger(__name__)


def callback(task, action):
def callback(task: Task, action: str) -> None:
url = task.callback
if not url:
return
Expand All @@ -21,4 +22,4 @@ def callback(task, action):
return resp
except Exception as exc:
logger.error(f"Error posting data to callback URL: {url}")
handle_exception(task, exc)
handle_exception(exc)
31 changes: 16 additions & 15 deletions task/leecher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.conf import settings
from django.utils import timezone

from .baidupcs import BaiduPCSClient
from .baidupcs import CaptchaRequired
from .callback import callback
from .models import Task
Expand All @@ -11,13 +12,13 @@
logger = logging.getLogger(__name__)


def start_task(task):
def start_task(task: Task) -> None:
task.status = Task.Status.STARTED
task.started_at = timezone.now()
task.save()


def save_link(client, task):
def save_link(client: "BaiduPCSClient", task: Task) -> None:
def save_captcha(captcha_id, captcha_img_url, content):
task.captcha_required = True
task.captcha = content
Expand Down Expand Up @@ -47,15 +48,15 @@ def save_captcha(captcha_id, captcha_img_url, content):
callback(task, "link_saved")


def set_files(client, task):
def set_files(client: "BaiduPCSClient", task: Task) -> None:
task.set_files(list(client.list_files(task.remote_path)))
task.file_listed_at = timezone.now()
task.save()
logger.info(f"list {task} files succeeded.")
callback(task, "files_ready")


def download_samples(client, task):
def download_samples(client: "BaiduPCSClient", task: Task) -> None:
logger.info("downloading samples...")
client.leech(
remote_dir=task.remote_path,
Expand All @@ -68,7 +69,7 @@ def download_samples(client, task):
callback(task, "sampling_downloaded")


def download(client, task):
def download(client: "BaiduPCSClient", task: Task) -> None:
logger.info("downloading...")
client.leech(
remote_dir=task.remote_path,
Expand All @@ -80,20 +81,20 @@ def download(client, task):
logger.info(f"leech {task} succeeded.")


def task_failed(task, message):
def task_failed(task: Task, message: str) -> None:
task.status = Task.Status.FINISHED
task.finished_at = timezone.now()
task.failed = True
task.message = message[: Task._meta.get_field("message").max_length]
task.save()


def finish_transfer(task):
def finish_transfer(task: Task) -> None:
task.status = Task.Status.TRANSFERRED
task.save()


def transfer(client, task):
def transfer(client: "BaiduPCSClient", task: Task) -> None:
logger.info(f"start transfer {task} ...")
start_task(task)

Expand All @@ -106,41 +107,41 @@ def transfer(client, task):
logging.info(f"captcha required: {task}")
except Exception as e:
logging.error(f"transfer {task} failed.")
task_failed(task, handle_exception(task, e))
task_failed(task, handle_exception(e))


def finish_sampling(task):
def finish_sampling(task: Task) -> None:
task.status = Task.Status.SAMPLING_DOWNLOADED
task.save()


def sampling(client, task):
def sampling(client: "BaiduPCSClient", task: Task) -> None:
logger.info(f"start download sampling of {task}")

try:
download_samples(client, task)
except Exception as e:
logging.error(f"download sampling of {task} failed.")
task_failed(task, handle_exception(task, e))
task_failed(task, handle_exception(e))

finish_sampling(task)
logger.info(f"download sampling of {task} succeed.")


def finish_task(task):
def finish_task(task: Task) -> None:
task.status = Task.Status.FINISHED
task.finished_at = timezone.now()
task.save()


def leech(client, task):
def leech(client: "BaiduPCSClient", task: Task) -> None:
logger.info(f"start leech {task} to {task.data_path}")

try:
download(client, task)
except Exception as e:
logging.error(f"download all files of {task} failed.")
task_failed(task, handle_exception(task, e))
task_failed(task, handle_exception(e))
return

finish_task(task)
Expand Down
Loading

0 comments on commit 89ed4d7

Please sign in to comment.