Skip to content

Commit

Permalink
♻️ 拆分get_comb_id功能
Browse files Browse the repository at this point in the history
  • Loading branch information
AzideCupric committed Mar 29, 2024
1 parent 2d28ac2 commit 4423ddd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
11 changes: 5 additions & 6 deletions nonebot_bison/platform/ceobecanteen/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@


class CeobeDataSourceCache:
"""数据源缓存"""
"""数据源缓存, 以unique_id为key存储数据源"""

def __init__(self):
self._cache = ExpiringDict[str, CeobeTarget](capacity=100, default_age=timedelta(days=7))
self._cache = ExpiringDict[str, CeobeTarget](capacity=100, default_age=timedelta(days=1))
self.client = CeobeClient()
self.url = DATASOURCE_URL
self.init_requested = False

@property
def cache(self) -> MappingProxyType[str, CeobeTarget]:
Expand All @@ -46,10 +45,10 @@ async def refresh_data_sources(self):
self._cache[ds.unique_id] = ds
return self.cache

async def get_all(self):
if not self.init_requested:
async def get_all(self, force_refresh: bool = False):
"""获取所有数据源, 如果缓存为空则尝试刷新缓存"""
if not self.cache or force_refresh:
await self.refresh_data_sources()
self.init_requested = True
return self.cache

def select_one(self, cond_func: Callable[[CeobeTarget], bool]) -> CeobeTarget | None:
Expand Down
32 changes: 17 additions & 15 deletions nonebot_bison/platform/ceobecanteen/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,27 @@ class CeobeCanteen(NewMessage):
data_source_cache = CeobeDataSourceCache()
cache_store = SimpleCache()

async def get_comb_id(self, force_refresh: bool = False) -> str | None:
"""获取数据源的组合id
async def get_comb_id(self, target_uuids: list[str]):
"""获取数据源的组合id"""
payload = {"datasource_push": target_uuids}
logger.trace(payload)
resp = await self.client.post(
COMB_ID_URL,
json=payload,
)
comb_id = process_response(resp, CombIdResponse).data["datasource_comb_id"]
logger.trace(f"get comb_id: {comb_id}")
return comb_id

获取 "全部数据源" 的组合id,
获取到的comb_id会缓存12小时
"""
async def get_comb_id_for_all(self, force_refresh: bool = False):
"""获取 "全部数据源" 的组合id,获取到的comb_id会缓存12小时"""
if self.cache_store["comb_id"] is None or force_refresh:
logger.trace("no comb_id, request")
target_uuids = (await self.data_source_cache.get_all()).keys()
payload = {"datasource_push": list(target_uuids)}
logger.trace(payload)
resp = await self.client.post(
COMB_ID_URL,
json=payload,
)
comb_id = process_response(resp, CombIdResponse).data["datasource_comb_id"]
logger.trace(f"get comb_id: {comb_id}")
target_uuids = (await self.data_source_cache.get_all(force_refresh)).keys()
comb_id = await self.get_comb_id(list(target_uuids))
self.cache_store["comb_id", timedelta(hours=12)] = comb_id

logger.debug(f"use comb_id: {self.cache_store['comb_id']}")
return self.cache_store["comb_id"]

async def get_cookie_id(self, comb_id: str):
Expand Down Expand Up @@ -109,7 +111,7 @@ def update(cookie_id: str, cookies: list[CeobeCookie]):
return cast(list[CeobeCookie] | None, self.cache_store["cookies"])

async def fetch_ceobe_cookies(self) -> list[CeobeCookie]:
comb_id = await self.get_comb_id()
comb_id = await self.get_comb_id_for_all()
if not comb_id:
return []
cookie_id = await self.get_cookie_id(comb_id)
Expand Down

0 comments on commit 4423ddd

Please sign in to comment.