# split_fight_train_test_dataset.py (1.9 KB)

  1. import os
  2. import glob
  3. import random
  4. import fnmatch
  5. import re
  6. import sys
  7. class_id = {"nofight": 0, "fight": 1}
  8. def get_list(path, key_func=lambda x: x[-11:], rgb_prefix='img_', level=1):
  9. if level == 1:
  10. frame_folders = glob.glob(os.path.join(path, '*'))
  11. elif level == 2:
  12. frame_folders = glob.glob(os.path.join(path, '*', '*'))
  13. else:
  14. raise ValueError('level can be only 1 or 2')
  15. def count_files(directory):
  16. lst = os.listdir(directory)
  17. cnt = len(fnmatch.filter(lst, rgb_prefix + '*'))
  18. return cnt
  19. # check RGB
  20. video_dict = {}
  21. for f in frame_folders:
  22. cnt = count_files(f)
  23. k = key_func(f)
  24. if level == 2:
  25. k = k.split("/")[0]
  26. video_dict[f] = str(cnt) + " " + str(class_id[k])
  27. return video_dict
  28. def fight_splits(video_dict, train_percent=0.8):
  29. videos = list(video_dict.keys())
  30. train_num = int(len(videos) * train_percent)
  31. train_list = []
  32. val_list = []
  33. random.shuffle(videos)
  34. for i in range(train_num):
  35. train_list.append(videos[i] + " " + str(video_dict[videos[i]]))
  36. for i in range(train_num, len(videos)):
  37. val_list.append(videos[i] + " " + str(video_dict[videos[i]]))
  38. print("train:", len(train_list), ",val:", len(val_list))
  39. with open("fight_train_list.txt", "w") as f:
  40. for item in train_list:
  41. f.write(item + "\n")
  42. with open("fight_val_list.txt", "w") as f:
  43. for item in val_list:
  44. f.write(item + "\n")
  45. if __name__ == "__main__":
  46. frame_dir = sys.argv[1] # "rawframes"
  47. level = sys.argv[2] # 2
  48. train_percent = sys.argv[3] # 0.8
  49. if level == 2:
  50. def key_func(x):
  51. return '/'.join(x.split('/')[-2:])
  52. else:
  53. def key_func(x):
  54. return x.split('/')[-1]
  55. video_dict = get_list(frame_dir, key_func=key_func, level=level)
  56. print("number:", len(video_dict))
  57. fight_splits(video_dict, train_percent)