extract_batchsize.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import paddle
  2. import numpy as np
  3. import copy
  4. def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
  5. """
  6. """
  7. pos_lists_, pos_masks_, label_lists_ = [], [], []
  8. img_bs = batch_size
  9. ngpu = int(batch_size / img_bs)
  10. img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
  11. pos_lists_split, pos_masks_split, label_lists_split = [], [], []
  12. for i in range(ngpu):
  13. pos_lists_split.append([])
  14. pos_masks_split.append([])
  15. label_lists_split.append([])
  16. for i in range(img_ids.shape[0]):
  17. img_id = img_ids[i]
  18. gpu_id = int(img_id / img_bs)
  19. img_id = img_id % img_bs
  20. pos_list = pos_lists[i].copy()
  21. pos_list[:, 0] = img_id
  22. pos_lists_split[gpu_id].append(pos_list)
  23. pos_masks_split[gpu_id].append(pos_masks[i].copy())
  24. label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
  25. # repeat or delete
  26. for i in range(ngpu):
  27. vp_len = len(pos_lists_split[i])
  28. if vp_len <= tcl_bs:
  29. for j in range(0, tcl_bs - vp_len):
  30. pos_list = pos_lists_split[i][j].copy()
  31. pos_lists_split[i].append(pos_list)
  32. pos_mask = pos_masks_split[i][j].copy()
  33. pos_masks_split[i].append(pos_mask)
  34. label_list = copy.deepcopy(label_lists_split[i][j])
  35. label_lists_split[i].append(label_list)
  36. else:
  37. for j in range(0, vp_len - tcl_bs):
  38. c_len = len(pos_lists_split[i])
  39. pop_id = np.random.permutation(c_len)[0]
  40. pos_lists_split[i].pop(pop_id)
  41. pos_masks_split[i].pop(pop_id)
  42. label_lists_split[i].pop(pop_id)
  43. # merge
  44. for i in range(ngpu):
  45. pos_lists_.extend(pos_lists_split[i])
  46. pos_masks_.extend(pos_masks_split[i])
  47. label_lists_.extend(label_lists_split[i])
  48. return pos_lists_, pos_masks_, label_lists_
  49. def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
  50. pad_num, tcl_bs):
  51. label_list = label_list.numpy()
  52. batch, _, _, _ = label_list.shape
  53. pos_list = pos_list.numpy()
  54. pos_mask = pos_mask.numpy()
  55. pos_list_t = []
  56. pos_mask_t = []
  57. label_list_t = []
  58. for i in range(batch):
  59. for j in range(max_text_nums):
  60. if pos_mask[i, j].any():
  61. pos_list_t.append(pos_list[i][j])
  62. pos_mask_t.append(pos_mask[i][j])
  63. label_list_t.append(label_list[i][j])
  64. pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
  65. label_list_t, tcl_bs)
  66. label = []
  67. tt = [l.tolist() for l in label_list]
  68. for i in range(tcl_bs):
  69. k = 0
  70. for j in range(max_text_length):
  71. if tt[i][j][0] != pad_num:
  72. k += 1
  73. else:
  74. break
  75. label.append(k)
  76. label = paddle.to_tensor(label)
  77. label = paddle.cast(label, dtype='int64')
  78. pos_list = paddle.to_tensor(pos_list)
  79. pos_mask = paddle.to_tensor(pos_mask)
  80. label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
  81. label_list = paddle.cast(label_list, dtype='int32')
  82. return pos_list, pos_mask, label_list, label