refactor(cli/hfdownload): 重构下载工具,改用SETENV和modelscope命令
1. 移除本地setenvs函数,改用封装好的SETENV任务 2. 替换hf下载命令为modelscope下载命令 3. 优化参数命名和默认下载目录逻辑 4. 简化任务编排代码
This commit is contained in:
@@ -1,82 +1,50 @@
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.tasks.system import SETENV
|
||||
|
||||
HFDownloadType = Literal["model", "dataset", "space"]
|
||||
|
||||
|
||||
def setenvs():
|
||||
"""设置 HuggingFace mirror 环境变量."""
|
||||
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Download a model from HuggingFace.")
|
||||
parser.add_argument("dataset_name", type=str, help="HuggingFace dataset name.")
|
||||
parser.add_argument("name", help="Target name.")
|
||||
parser.add_argument(
|
||||
"--type",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="dataset",
|
||||
choices=get_args(HFDownloadType),
|
||||
help="HuggingFace dataset type.",
|
||||
"--type", "-t", nargs="?", default="model", choices=get_args(HFDownloadType), help="Target type."
|
||||
)
|
||||
parser.add_argument("--use-hfd", action="store_true", help="Use HFD tool to download dataset.")
|
||||
parser.add_argument("--dir", default=None, help="Download directory.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.dataset_name:
|
||||
parser.error("dataset_name is required")
|
||||
if not args.name:
|
||||
parser.error("name is required")
|
||||
|
||||
dataset_name = args.dataset_name
|
||||
target_name = args.name
|
||||
|
||||
# 创建下载目录
|
||||
download_dir = Path.cwd() / dataset_name
|
||||
if args.dir:
|
||||
download_dir = Path(args.dir)
|
||||
else:
|
||||
download_dir = Path.home() / ".models" / target_name.split("/")[-1]
|
||||
download_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.use_hfd:
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(name="setenvs", fn=setenvs, verbose=True),
|
||||
px.TaskSpec(
|
||||
name="download_hfd",
|
||||
cmd=["wget", "https://hf-mirror.com/hfd/hfd.sh"],
|
||||
depends_on=("setenvs",),
|
||||
verbose=True,
|
||||
),
|
||||
px.TaskSpec(
|
||||
name="chmod_hfd",
|
||||
cmd=["chmod", "a+x", "hfd.sh"],
|
||||
depends_on=("download_hfd",),
|
||||
verbose=True,
|
||||
),
|
||||
px.TaskSpec(
|
||||
name="run_hfd",
|
||||
cmd=["./hfd.sh", dataset_name, args.type],
|
||||
depends_on=("chmod_hfd",),
|
||||
verbose=True,
|
||||
),
|
||||
])
|
||||
else:
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(name="setenvs", fn=setenvs, verbose=True),
|
||||
px.TaskSpec(
|
||||
name="download",
|
||||
cmd=[
|
||||
"uvx",
|
||||
"hf",
|
||||
"download",
|
||||
"--repo-type",
|
||||
args.type,
|
||||
"--force-download",
|
||||
dataset_name,
|
||||
"--local-dir",
|
||||
str(Path.cwd() / dataset_name),
|
||||
],
|
||||
depends_on=("setenvs",),
|
||||
verbose=True,
|
||||
),
|
||||
])
|
||||
graph = px.Graph.from_specs([
|
||||
SETENV("HF_ENDPOINT", "https://hf-mirror.com"),
|
||||
px.TaskSpec(
|
||||
name="download",
|
||||
cmd=[
|
||||
"uvx",
|
||||
"modelscope",
|
||||
"download",
|
||||
f"--{args.type}",
|
||||
target_name,
|
||||
"--local_dir",
|
||||
str(download_dir),
|
||||
],
|
||||
depends_on=("setenv_hf_endpoint",),
|
||||
verbose=True,
|
||||
),
|
||||
])
|
||||
|
||||
px.run(graph, strategy="thread", verbose=True)
|
||||
|
||||
@@ -28,7 +28,7 @@ def SETENV(name: str, value: str, default: bool = False):
|
||||
else:
|
||||
os.environ[name] = value
|
||||
|
||||
return px.TaskSpec(f"set_env_{name}", fn=set_env)
|
||||
return px.TaskSpec(f"setenv_{name.lower()}", fn=set_env, verbose=True)
|
||||
|
||||
|
||||
def WHICH(cmd: str):
|
||||
|
||||
Reference in New Issue
Block a user