瀏覽代碼

[feat] 新增文件随机抓取脚本

WangChao 6 月之前
父節點
當前提交
3f366d904e
共有 1 個文件被更改,包括 112 次插入0 次删除
  1. 112 0
      data_collection/files_fetching.py

+ 112 - 0
data_collection/files_fetching.py

@@ -0,0 +1,112 @@
+import argparse
+import os
+import shutil
+import concurrent.futures
+import random
+
+
+version = "1.0.0"
+
+
+def is_in_whitelist(file_name, whitelist):
+    # 获取文件的后缀名
+    _, ext = os.path.splitext(file_name)
+    # 判断后缀名是否在白名单中
+    return ext.lower() in whitelist
+
+
+def list_files(directory, whitelist, blacklist):
+    # 遍历目录中的文件和文件夹
+    res = []
+    for root, dirs, files in os.walk(directory):
+        # 只输出文件名,不包括文件夹
+        for file in files:
+            if not file.startswith(".") and is_in_whitelist(file, whitelist):
+                file_path = os.path.join(root, file)
+                res.append(file_path)
+
+    return res
+
+
+def check_path(path_to_check):
+    try:
+        if not os.path.exists(path_to_check):
+            os.makedirs(path_to_check)
+    except OSError as e:
+        print(f"发生错误: {e}")
+
+
+def copy_file(src_file, dst_file):
+    if os.path.exists(dst_file):
+        print(f"文件 {dst_file} 已存在")
+    else:
+        try:
+            shutil.copy(src_file, dst_file)
+            print(f"文件已成功从 {src_file} 拷贝到 {dst_file}")
+        except FileNotFoundError:
+            print(f"源文件 {src_file} 不存在")
+        except PermissionError:
+            print(f"没有权限拷贝文件到 {dst_file}")
+        except Exception as e:
+            print(f"发生错误: {e}")
+
+
+def copy_files_concurrently(src_files, dst_dir):
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = [executor.submit(copy_file, src, os.path.join(dst_dir, os.path.basename(src))) for src in src_files]
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                # 你可以在这里处理返回的结果(如果有的话),但在这个例子中,copy_file没有返回值
+                pass
+            except Exception as exc:
+                print(f'Generated an exception: {exc}')
+
+
+def main():
+    parse = argparse.ArgumentParser("可以随机抓取指定目录下的文件,可指定后缀名,指定要抓取的数量\n")
+    parse.add_argument("-v", "--version", action="version", version=version)
+    parse.add_argument("input", help="输入路径")
+    parse.add_argument("output", help="输出路径")
+    parse.add_argument("--count", type=int, help="随机抓取的文件数量,不设置则抓取所有文件")
+    parse.add_argument("--whitelist", nargs='+', help="文件后缀名白名单,如果设置白名单,则只会选取白名单中的文件格式")
+    parse.add_argument("--blacklist", nargs='+', help="[未实装]文件后缀名黑名单,如果设置黑名单,则会选取过滤黑名单中的文件格式")
+
+    args = parse.parse_args()
+
+    whitelist = []
+    blacklist = []
+    count = -1
+
+    if args.whitelist:
+        for suffix in args.whitelist:
+            if not suffix.startswith("."):
+                whitelist.append("." + suffix)
+            else:
+                whitelist.append(suffix)
+
+    if args.blacklist:
+        for suffix in args.blacklist:
+            if not suffix.startswith("."):
+                blacklist.append("." + suffix)
+            else:
+                blacklist.append(suffix)
+
+    if args.count:
+        count = args.count
+
+    files = list_files(args.input, whitelist, blacklist)
+    files_to_copy = []
+
+    check_path(args.output)
+
+    if count == -1 or len(files) <= count:
+        # 拷贝全部
+        files_to_copy = files
+    else:
+        files_to_copy = random.sample(files, count)
+
+    copy_files_concurrently(files_to_copy, args.output)
+
+
+if __name__ == "__main__":
+    main()