From ec6da8a44a02fda3fe20a2c236c31cbd05c3d35d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=AA=E5=8D=9C=E3=80=8E=E7=AC=A6=E7=8E=84=E3=80=8F?= Date: Sat, 11 Jan 2025 11:53:53 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0SRAUpdater.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 优化下载逻辑 --- SRAUpdater.py | 149 ++++++++++++++++++++++++++++---------------------- 1 file changed, 85 insertions(+), 64 deletions(-) diff --git a/SRAUpdater.py b/SRAUpdater.py index 022c6a7..e834b97 100644 --- a/SRAUpdater.py +++ b/SRAUpdater.py @@ -25,6 +25,7 @@ import json import os +import sys from dataclasses import dataclass from time import sleep import requests @@ -32,7 +33,33 @@ from requests import RequestException from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn, DownloadColumn, TransferSpeedColumn from functools import lru_cache -from StarRailAssistant.utils import WindowsProcess +import psutil + +FROZEN = getattr(sys, "frozen", False) +""" 是否被打包成了可执行文件 """ + +if FROZEN: + from sys import exit + +try: + from StarRailAssistant.utils import WindowsProcess # type: ignore +except ImportError: + class WindowsProcess: + + @staticmethod + def task_kill(pid): + psutil.Process(pid).kill() + + @staticmethod + def is_process_running(process_name: str): + for p in psutil.process_iter(): + if p.name() == process_name: + return True + return False + + @staticmethod + def Popen(*args, **kwargs) -> psutil.Process: + return psutil.Popen(*args, **kwargs) # 下载进度条 download_progress_bar = Progress( @@ -61,7 +88,9 @@ class Updater: HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36", } - APP_PATH = Path(__file__).parent.parent.absolute() + FROZEN = FROZEN + APP_PATH = Path(sys.executable).parent.absolute() if FROZEN else Path(__file__).parent.absolute() + print(f"当前路径:{APP_PATH}") VERSION_INFO_URL = ( "https://github.com/Shasnow/StarRailAssistant/blob/main/version.json" ) @@ -136,6 +165,10 @@ def version_check(self, v: VersionInfo) -> str: remote_version = version_info["version"] remote_resource_version = version_info["resource_version"] # 比较当前版本和远程版本 + print(f"当前版本:{v.version}") + print(f"当前资源版本:{v.resource_version}") + print(f"远程版本:{remote_version}") + print(f"远程资源版本:{remote_resource_version}") if remote_version > v.version: print(f"发现新版本:{remote_version}") print(f"更新说明:\n{version_info['announcement']}") @@ -147,80 +180,71 @@ def version_check(self, v: VersionInfo) -> str: print("已经是最新版本") return "" - def can_direct(self, url: str) -> bool: - """ - 判断是否可以直连 - :param url: 链接 - :return: 是否可以直连 - """ - try: - response = requests.head(url, headers=self.HEADERS, timeout=5) - if response.status_code == 200: - print("代理直连成功") - return True - except Exception: - pass - return False - - def proxy_avaliable(self, proxy_url: str) -> bool: - """ 判断代理是否可用 """ - result = self.can_direct(f"{proxy_url}/https://github.com/Shasnow/StarRailAssistant/releases/download/v{self.get_current_version().version}/StarRailAssistant_v{self.get_current_version().version}.zip") - if result: - print(f"代理: {proxy_url}, 确认可用, 将使用此代理") - else: - print(f"代理{proxy_url}, 不可用, 跳过.") - return result + def get_download_session(self,url: str, proxy_url: str = "") -> tuple[requests.Session, str]: + _url = f"{proxy_url}{url}" + session = requests.session() + session.headers.update(self.HEADERS) + # head confirm + resp = session.head(_url, allow_redirects=True) + if resp.status_code != 200: + raise RequestException(f"请求{_url}失败,状态码:{resp.status_code}") + return session, _url def _download(self, url: str, filepath: Path, proxy_url: str = "") -> None: """ 下载文件 :param url: 下载链接 - :param filename: 保存文件名 + :param filepath: 保存文件路径 :param proxy_url: 代理链接前缀 """ - filename = filepath.stem - if not self.can_direct(url): - url = f"{proxy_url}{url}" - - start_byte = 0 - resume_header = {} + try: + session, download_url = self.get_download_session(url, proxy_url) + + # 获取文件总大小 + resp = session.head(download_url, allow_redirects=True) + total_size = int(resp.headers.get("Content-Length", 0)) + start_byte = 0 - # 检查文件是否存在 - if filepath.exists(): - start_byte = filepath.stat().st_size + # 设置断点续传头 + resume_header = {} if start_byte > 0: resume_header = {"Range": f"bytes={start_byte}-"} self.HEADERS.update(resume_header) print("服务器支持断点续传,开始继续下载...") - with download_progress_bar as progress: - resp = requests.get(url, headers=self.HEADERS, stream=True) - total_size = int(resp.headers.get("Content-Length", 0)) + # 发起请求 + resp = session.get(download_url, headers=self.HEADERS, stream=True) + # 检查服务器是否支持断点续传 if start_byte > 0 and resp.status_code != 206: print("服务器不支持断点续传,重新下载整个文件") start_byte = 0 - resume_header = {} - self.HEADERS.update(resume_header) - resp = requests.get(url, headers=self.HEADERS, stream=True) - total_size = int(resp.headers.get("Content-Length", 0)) - - task = progress.add_task( - f"[bold blue]下载 {filename}", - filename=filename, - start=False, - total=total_size + start_byte, - completed=start_byte - ) + self.HEADERS.pop("Range", None) # 删除断点续传的header + resp = session.get(download_url, headers=self.HEADERS, stream=True) - mode = 'ab' if start_byte > 0 else 'wb' - with open(filepath, mode) as file: - for data in resp.iter_content(chunk_size=8192): - file.write(data) - progress.update(task, advance=len(data)) - progress.refresh() + # 初始化进度条 + with download_progress_bar as progress: + task = progress.add_task( + "[bold blue]下载中...", + filename=filepath.name.strip(".downloaded"), + start=True, + total=total_size, + completed=start_byte + ) - progress.remove_task(task) + # 打开文件,追加或写入模式 + mode = 'ab' if start_byte > 0 else 'wb' + with open(filepath, mode) as file: + for data in resp.iter_content(chunk_size=8192): + file.write(data) + progress.update(task, advance=len(data)) + progress.refresh() + + progress.remove_task(task) + except RequestException as e: + raise e + finally: + session.close() # 关闭会话 def download(self, download_url: str) -> None: try: @@ -228,18 +252,15 @@ def download(self, download_url: str) -> None: i = 0 while i < self.PROXY_LIST_LEN: proxy_url = self.PROXY[i] - if self.proxy_avaliable(proxy_url): - self._download(download_url, self.DOWNLOAD_FAT, proxy_url) - break - else: - i += 1 - continue + self._download(download_url, self.DOWNLOAD_FAT, proxy_url) + break except Exception as e: print(f"下载更新时出错: {e}") os.system("pause") except KeyboardInterrupt: print("下载更新已取消") - os.remove(self.DOWNLOAD_FAT) + need_remove = input("是否删除下载的部分? (删除后,需要重新下载) (y/n)").strip().lower() + os.remove(self.DOWNLOAD_FAT) if need_remove == "y" else None os.system("pause") exit(1)