fix: correct task count

This commit is contained in:
世界观察日志
2025-08-30 23:13:29 +08:00
parent 1974c84a60
commit 365fa2f8d3
4 changed files with 25 additions and 12 deletions

View File

@@ -12,7 +12,7 @@ from tenacity import retry, retry_if_exception_type, wait_random_exponential, st
from src.config import Config
from src.logger import GlobalLogger
from src.measurer import SpeedMeasurer
from src.measurer import Measurer
from src.models import *
@@ -90,7 +90,7 @@ class WebAPI:
total = int(response.headers.get("Content-Length") if response.headers.get("Content-Length")
else response.headers.get("X-Apple-MS-Content-Length"))
async for chunk in response.aiter_bytes():
it(SpeedMeasurer).record_download(len(chunk))
it(Measurer).record_download(len(chunk))
result.write(chunk)
if len(result.getvalue()) != total:
raise httpx.HTTPError

View File

@@ -4,18 +4,18 @@ import sys
from creart import it
from prompt_toolkit import PromptSession
from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.completion import NestedCompleter
from prompt_toolkit.patch_stdout import patch_stdout
from src.api import WebAPI
from src.config import Config
from src.flags import Flags
from src.grpc.manager import WrapperManager, WrapperManagerException
from src.logger import GlobalLogger
from src.measurer import SpeedMeasurer
from src.measurer import Measurer
from src.rip import on_decrypt_success, on_decrypt_failed, rip_song, rip_album, rip_artist, rip_playlist
from src.url import AppleMusicURL, URLType
from src.utils import check_dep, run_sync, safely_create_task, get_tasks_num, config_outdated
from src.utils import check_dep, run_sync, safely_create_task, config_outdated
class InteractiveShell:
@@ -99,7 +99,7 @@ class InteractiveShell:
return
def bottom_toolbar(self):
return f"Download Speed: {it(SpeedMeasurer).download_speed()}, Decrypt Speed: {it(SpeedMeasurer).decrypt_speed()}, Tasks: {get_tasks_num()-2}"
return f"Download Speed: {it(Measurer).download_speed()}, Decrypt Speed: {it(Measurer).decrypt_speed()}, Tasks: {it(Measurer).tasks_count()}"
def completer(self):
mycompleter = {
@@ -144,7 +144,8 @@ class InteractiveShell:
await self.logout_flow()
elif command.strip() == '':
continue
else: await self.command_parser(command)
else:
await self.command_parser(command)
except (EOFError, KeyboardInterrupt):
return

View File

@@ -5,11 +5,12 @@ from typing import Type
from creart import CreateTargetInfo, AbstractCreator, exists_module
class SpeedMeasurer:
class Measurer:
def __init__(self, sample_window=1):
self._sample_window = sample_window
self._download_records = deque() # 存储 (时间戳, 字节数)
self._decrypt_records = deque() # 存储 (时间戳, 字节数)
self._running_tasks = 0
def record_download(self, content_length: int):
now = time.time()
@@ -19,6 +20,12 @@ class SpeedMeasurer:
now = time.time()
self._decrypt_records.append((now, content_length))
def record_task_start(self):
self._running_tasks += 1
def record_task_finish(self):
self._running_tasks -= 1
def download_speed(self) -> str:
now = time.time()
self._evict_old(self._download_records, now)
@@ -29,6 +36,9 @@ class SpeedMeasurer:
self._evict_old(self._decrypt_records, now)
return self._calc_speed(self._decrypt_records)
def tasks_count(self):
return self._running_tasks
def _evict_old(self, dq, now):
"""只保留采样窗口内的数据"""
while dq and now - dq[0][0] > self._sample_window:
@@ -47,7 +57,7 @@ class SpeedMeasurer:
class MeasurerCreator(AbstractCreator):
targets = (
CreateTargetInfo("src.measurer", "SpeedMeasurer"),
CreateTargetInfo("src.measurer", "Measurer"),
)
@staticmethod
@@ -55,5 +65,5 @@ class MeasurerCreator(AbstractCreator):
return exists_module("src.config")
@staticmethod
def create(create_type: Type[SpeedMeasurer]) -> SpeedMeasurer:
def create(create_type: Type[Measurer]) -> Measurer:
return create_type()

View File

@@ -10,7 +10,7 @@ from src.exceptions import CodecNotFoundException
from src.flags import Flags
from src.grpc.manager import WrapperManager, WrapperManagerException
from src.logger import RipLogger
from src.measurer import SpeedMeasurer
from src.measurer import Measurer
from src.metadata import SongMetadata
from src.models import PlaylistInfo
from src.mp4 import extract_media, extract_song, encapsulate, write_metadata, fix_encapsulate, fix_esds_box, \
@@ -37,10 +37,11 @@ async def task_done(task: Task, status: Status):
if task.parentDone:
await task.parentDone.try_done()
del adam_id_task_mapping[task.adamId]
it(Measurer).record_task_finish()
async def on_decrypt_success(adam_id: str, key: str, sample: bytes, sample_index: int):
it(SpeedMeasurer).record_decrypt(len(sample))
it(Measurer).record_decrypt(len(sample))
safely_create_task(recv_decrypted_sample(adam_id, sample_index, sample))
@@ -89,6 +90,7 @@ async def rip_song(url: Song, codec: str, flags: Flags = Flags(),
adam_id_task_mapping[url.id] = task
task.init_logger()
await task_lock.acquire()
it(Measurer).record_task_start()
# Set Metadata
raw_metadata = await it(WebAPI).get_song_info(task.adamId, url.storefront, flags.language)