diff --git a/src/api.py b/src/api.py index 4fb342f..8fb9df4 100644 --- a/src/api.py +++ b/src/api.py @@ -12,7 +12,7 @@ from tenacity import retry, retry_if_exception_type, wait_random_exponential, st from src.config import Config from src.logger import GlobalLogger -from src.measurer import SpeedMeasurer +from src.measurer import Measurer from src.models import * @@ -90,7 +90,7 @@ class WebAPI: total = int(response.headers.get("Content-Length") if response.headers.get("Content-Length") else response.headers.get("X-Apple-MS-Content-Length")) async for chunk in response.aiter_bytes(): - it(SpeedMeasurer).record_download(len(chunk)) + it(Measurer).record_download(len(chunk)) result.write(chunk) if len(result.getvalue()) != total: raise httpx.HTTPError diff --git a/src/cmd.py b/src/cmd.py index c0952a0..f0b2ef5 100644 --- a/src/cmd.py +++ b/src/cmd.py @@ -4,18 +4,18 @@ import sys from creart import it from prompt_toolkit import PromptSession -from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.completion import NestedCompleter +from prompt_toolkit.patch_stdout import patch_stdout from src.api import WebAPI from src.config import Config from src.flags import Flags from src.grpc.manager import WrapperManager, WrapperManagerException from src.logger import GlobalLogger -from src.measurer import SpeedMeasurer +from src.measurer import Measurer from src.rip import on_decrypt_success, on_decrypt_failed, rip_song, rip_album, rip_artist, rip_playlist from src.url import AppleMusicURL, URLType -from src.utils import check_dep, run_sync, safely_create_task, get_tasks_num, config_outdated +from src.utils import check_dep, run_sync, safely_create_task, config_outdated class InteractiveShell: @@ -99,7 +99,7 @@ class InteractiveShell: return def bottom_toolbar(self): - return f"Download Speed: {it(SpeedMeasurer).download_speed()}, Decrypt Speed: {it(SpeedMeasurer).decrypt_speed()}, Tasks: {get_tasks_num()-2}" + return f"Download Speed: {it(Measurer).download_speed()}, Decrypt Speed: {it(Measurer).decrypt_speed()}, Tasks: {it(Measurer).tasks_count()}" def completer(self): mycompleter = { @@ -144,7 +144,8 @@ class InteractiveShell: await self.logout_flow() elif command.strip() == '': continue - else: await self.command_parser(command) + else: + await self.command_parser(command) except (EOFError, KeyboardInterrupt): return diff --git a/src/measurer.py b/src/measurer.py index af1d649..af15519 100644 --- a/src/measurer.py +++ b/src/measurer.py @@ -5,11 +5,12 @@ from typing import Type from creart import CreateTargetInfo, AbstractCreator, exists_module -class SpeedMeasurer: +class Measurer: def __init__(self, sample_window=1): self._sample_window = sample_window self._download_records = deque() # 存储 (时间戳, 字节数) self._decrypt_records = deque() # 存储 (时间戳, 字节数) + self._running_tasks = 0 def record_download(self, content_length: int): now = time.time() @@ -19,6 +20,12 @@ class SpeedMeasurer: now = time.time() self._decrypt_records.append((now, content_length)) + def record_task_start(self): + self._running_tasks += 1 + + def record_task_finish(self): + self._running_tasks -= 1 + def download_speed(self) -> str: now = time.time() self._evict_old(self._download_records, now) @@ -29,6 +36,9 @@ class SpeedMeasurer: self._evict_old(self._decrypt_records, now) return self._calc_speed(self._decrypt_records) + def tasks_count(self): + return self._running_tasks + def _evict_old(self, dq, now): """只保留采样窗口内的数据""" while dq and now - dq[0][0] > self._sample_window: @@ -47,7 +57,7 @@ class SpeedMeasurer: class MeasurerCreator(AbstractCreator): targets = ( - CreateTargetInfo("src.measurer", "SpeedMeasurer"), + CreateTargetInfo("src.measurer", "Measurer"), ) @staticmethod @@ -55,5 +65,5 @@ class MeasurerCreator(AbstractCreator): return exists_module("src.config") @staticmethod - def create(create_type: Type[SpeedMeasurer]) -> SpeedMeasurer: + def create(create_type: Type[Measurer]) -> Measurer: return create_type() diff --git a/src/rip.py b/src/rip.py index f2e337d..b1f3b26 100644 --- a/src/rip.py +++ b/src/rip.py @@ -10,7 +10,7 @@ from src.exceptions import CodecNotFoundException from src.flags import Flags from src.grpc.manager import WrapperManager, WrapperManagerException from src.logger import RipLogger -from src.measurer import SpeedMeasurer +from src.measurer import Measurer from src.metadata import SongMetadata from src.models import PlaylistInfo from src.mp4 import extract_media, extract_song, encapsulate, write_metadata, fix_encapsulate, fix_esds_box, \ @@ -37,10 +37,11 @@ async def task_done(task: Task, status: Status): if task.parentDone: await task.parentDone.try_done() del adam_id_task_mapping[task.adamId] + it(Measurer).record_task_finish() async def on_decrypt_success(adam_id: str, key: str, sample: bytes, sample_index: int): - it(SpeedMeasurer).record_decrypt(len(sample)) + it(Measurer).record_decrypt(len(sample)) safely_create_task(recv_decrypted_sample(adam_id, sample_index, sample)) @@ -89,6 +90,7 @@ async def rip_song(url: Song, codec: str, flags: Flags = Flags(), adam_id_task_mapping[url.id] = task task.init_logger() await task_lock.acquire() + it(Measurer).record_task_start() # Set Metadata raw_metadata = await it(WebAPI).get_song_info(task.adamId, url.storefront, flags.language)