reader.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import traceback
import six
import sys

import numpy as np
import paddle
import paddle.nn.functional as F
from copy import deepcopy
from paddle.io import DataLoader, DistributedBatchSampler

from .utils import default_collate_fn
from ppdet.core.workspace import register
from . import transform
from .shm_utils import _get_shared_memory_size_in_M
from ppdet.utils.logger import setup_logger

logger = setup_logger('reader')

MAIN_PID = os.getpid()


class Compose(object):
    def __init__(self, transforms, num_classes=80):
        self.transforms = transforms
        self.transforms_cls = []
        for t in self.transforms:
            for k, v in t.items():
                op_cls = getattr(transform, k)
                f = op_cls(**v)
                if hasattr(f, 'num_classes'):
                    f.num_classes = num_classes
                self.transforms_cls.append(f)

    def __call__(self, data):
        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("failed to map sample transform [{}] "
                               "with error: {} and stack:\n{}".format(
                                   f, e, str(stack_info)))
                raise e
        return data
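
# Usage sketch (illustrative, not part of the original module): each entry in
# `transforms` is a single-key dict mapping the name of an op registered in
# `ppdet.data.transform` to its keyword arguments. The op names and arguments
# below are assumptions for illustration; check the transform module for the
# exact set available in your version.
#
#   sample_transforms = [
#       {'Decode': {}},
#       {'Resize': {'target_size': [640, 640], 'keep_ratio': False}},
#       {'Permute': {}},
#   ]
#   compose = Compose(sample_transforms, num_classes=80)
#   sample = compose({'im_file': 'demo.jpg'})  # dict in, transformed dict out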


class BatchCompose(Compose):
    def __init__(self, transforms, num_classes=80, collate_batch=True):
        super(BatchCompose, self).__init__(transforms, num_classes)
        self.collate_batch = collate_batch

    def __call__(self, data):
        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("failed to map batch transform [{}] "
                               "with error: {} and stack:\n{}".format(
                                   f, e, str(stack_info)))
                raise e

        # remove keys which are not needed by the model
        extra_key = ['h', 'w', 'flipped']
        for k in extra_key:
            for sample in data:
                if k in sample:
                    sample.pop(k)

        # collate the batch: use the default collate function unless the
        # caller asked to keep per-sample ground-truth fields as lists
        if self.collate_batch:
            batch_data = default_collate_fn(data)
        else:
            batch_data = {}
            for k in data[0].keys():
                tmp_data = []
                for i in range(len(data)):
                    tmp_data.append(data[i][k])
                if 'gt_' not in k and 'is_crowd' not in k and 'difficult' not in k:
                    tmp_data = np.stack(tmp_data, axis=0)
                batch_data[k] = tmp_data
        return batch_data
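
# Collation sketch (illustrative): with collate_batch=False, image tensors are
# stacked but variable-length ground-truth fields stay as per-sample lists,
# which suits detection heads where the number of boxes varies per image.
# Field names below are assumptions for illustration.
#
#   batch = BatchCompose([], collate_batch=False)([
#       {'image': np.zeros((3, 640, 640)), 'gt_bbox': np.zeros((2, 4))},
#       {'image': np.zeros((3, 640, 640)), 'gt_bbox': np.zeros((5, 4))},
#   ])
#   # batch['image'].shape == (2, 3, 640, 640); batch['gt_bbox'] is a 2-item list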


class BaseDataLoader(object):
    """
    Base DataLoader implementation for detection models

    Args:
        sample_transforms (list): a list of transforms to perform
            on each sample
        batch_transforms (list): a list of transforms to perform
            on each batch
        batch_size (int): batch size for batch collating, default 1.
        shuffle (bool): whether to shuffle samples
        drop_last (bool): whether to drop the last incomplete batch,
            default False
        num_classes (int): class number of dataset, default 80
        collate_batch (bool): whether to collate batch in dataloader.
            If set to True, the samples will be collated into a batch
            according to the batch size. Otherwise, the ground-truth fields
            will not be collated, which is used when the number of
            ground-truths differs across samples.
        use_shared_memory (bool): whether to use shared memory to
            accelerate data loading, enable this only if you
            are sure that the shared memory size of your OS
            is larger than the memory cost of the model's input data.
            Note that shared memory will be automatically
            disabled if the shared memory of the OS is less than
            1G, which is not enough for detection models.
            Default False.
    """

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 num_classes=80,
                 collate_batch=True,
                 use_shared_memory=False,
                 **kwargs):
        # sample transforms
        self._sample_transforms = Compose(
            sample_transforms, num_classes=num_classes)

        # batch transforms
        self._batch_transforms = BatchCompose(batch_transforms, num_classes,
                                              collate_batch)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.use_shared_memory = use_shared_memory
        self.kwargs = kwargs

    def __call__(self,
                 dataset,
                 worker_num,
                 batch_sampler=None,
                 return_list=False):
        self.dataset = dataset
        self.dataset.check_or_download_dataset()
        self.dataset.parse_dataset()
        # get data
        self.dataset.set_transform(self._sample_transforms)
        # set kwargs
        self.dataset.set_kwargs(**self.kwargs)
        # batch sampler
        if batch_sampler is None:
            self._batch_sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler = batch_sampler

        # DataLoader does not start sub-processes on Windows and macOS,
        # so shared memory is not needed there
        use_shared_memory = self.use_shared_memory and \
            sys.platform not in ['win32', 'darwin']
        # check whether shared memory size is bigger than 1G (1024M)
        if use_shared_memory:
            shm_size = _get_shared_memory_size_in_M()
            if shm_size is not None and shm_size < 1024.:
                logger.warning("Shared memory size is less than 1G, "
                               "disable shared_memory in DataLoader")
                use_shared_memory = False

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            return_list=return_list,
            use_shared_memory=use_shared_memory)
        self.loader = iter(self.dataloader)

        return self

    def __len__(self):
        return len(self._batch_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.loader)
        except StopIteration:
            # restart the underlying loader so the reader can be iterated
            # again, then re-raise StopIteration to end the current epoch
            self.loader = iter(self.dataloader)
            six.reraise(*sys.exc_info())

    def next(self):
        # python2 compatibility
        return self.__next__()


@register
class TrainReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 num_classes=80,
                 collate_batch=True,
                 **kwargs):
        super(TrainReader, self).__init__(sample_transforms, batch_transforms,
                                          batch_size, shuffle, drop_last,
                                          num_classes, collate_batch, **kwargs)
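
# Reader usage sketch (illustrative): readers are configured first, then bound
# to a dataset, which builds the underlying paddle.io.DataLoader. The dataset
# object, op names, and worker count below are hypothetical.
#
#   loader = TrainReader(
#       sample_transforms=[{'Decode': {}}, {'Permute': {}}],
#       batch_size=2)
#   loader = loader(dataset=my_dataset, worker_num=2)
#   for batch in loader:
#       ...  # each batch is a dict of collated tensors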


@register
class EvalReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 num_classes=80,
                 **kwargs):
        super(EvalReader, self).__init__(sample_transforms, batch_transforms,
                                         batch_size, shuffle, drop_last,
                                         num_classes, **kwargs)


@register
class TestReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 num_classes=80,
                 **kwargs):
        super(TestReader, self).__init__(sample_transforms, batch_transforms,
                                         batch_size, shuffle, drop_last,
                                         num_classes, **kwargs)


@register
class EvalMOTReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 num_classes=1,
                 **kwargs):
        super(EvalMOTReader, self).__init__(sample_transforms,
                                            batch_transforms, batch_size,
                                            shuffle, drop_last, num_classes,
                                            **kwargs)


@register
class TestMOTReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 num_classes=1,
                 **kwargs):
        super(TestMOTReader, self).__init__(sample_transforms,
                                            batch_transforms, batch_size,
                                            shuffle, drop_last, num_classes,
                                            **kwargs)


# For Semi-Supervised Object Detection (SSOD)
class Compose_SSOD(object):
    def __init__(self, base_transforms, weak_aug, strong_aug, num_classes=80):
        self.base_transforms = base_transforms
        self.base_transforms_cls = []
        for t in self.base_transforms:
            for k, v in t.items():
                op_cls = getattr(transform, k)
                f = op_cls(**v)
                if hasattr(f, 'num_classes'):
                    f.num_classes = num_classes
                self.base_transforms_cls.append(f)

        self.weak_augs = weak_aug
        self.weak_augs_cls = []
        for t in self.weak_augs:
            for k, v in t.items():
                op_cls = getattr(transform, k)
                f = op_cls(**v)
                if hasattr(f, 'num_classes'):
                    f.num_classes = num_classes
                self.weak_augs_cls.append(f)

        self.strong_augs = strong_aug
        self.strong_augs_cls = []
        for t in self.strong_augs:
            for k, v in t.items():
                op_cls = getattr(transform, k)
                f = op_cls(**v)
                if hasattr(f, 'num_classes'):
                    f.num_classes = num_classes
                self.strong_augs_cls.append(f)

    def __call__(self, data):
        # shared base transforms first, then fork into weak/strong branches
        for f in self.base_transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("failed to map sample transform [{}] "
                               "with error: {} and stack:\n{}".format(
                                   f, e, str(stack_info)))
                raise e

        weak_data = deepcopy(data)
        strong_data = deepcopy(data)
        for f in self.weak_augs_cls:
            try:
                weak_data = f(weak_data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("failed to map weak aug [{}] "
                               "with error: {} and stack:\n{}".format(
                                   f, e, str(stack_info)))
                raise e

        for f in self.strong_augs_cls:
            try:
                strong_data = f(strong_data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("failed to map strong aug [{}] "
                               "with error: {} and stack:\n{}".format(
                                   f, e, str(stack_info)))
                raise e

        # the strongly augmented view rides along inside the weak sample
        weak_data['strong_aug'] = strong_data
        return weak_data
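
# Output layout (illustrative): Compose_SSOD returns the weakly augmented
# sample with the strongly augmented counterpart nested under 'strong_aug'.
# BatchCompose_SSOD below relies on exactly this layout. The names `base`,
# `weak`, `strong`, and `sample` are placeholders.
#
#   out = Compose_SSOD(base, weak, strong)(sample)
#   weak_view = out                  # weakly augmented fields
#   strong_view = out['strong_aug']  # strongly augmented copy of the same image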


class BatchCompose_SSOD(Compose):
    def __init__(self, transforms, num_classes=80, collate_batch=True):
        super(BatchCompose_SSOD, self).__init__(transforms, num_classes)
        self.collate_batch = collate_batch

    def __call__(self, data):
        # split the strongly augmented samples out of the (weak) samples
        strong_data = []
        for sample in data:
            strong_data.append(sample['strong_aug'])
            sample.pop('strong_aug')

        for f in self.transforms_cls:
            try:
                data = f(data)
                strong_data = f(strong_data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("failed to map batch transform [{}] "
                               "with error: {} and stack:\n{}".format(
                                   f, e, str(stack_info)))
                raise e

        # remove keys which are not needed by the model
        extra_key = ['h', 'w', 'flipped']
        for k in extra_key:
            for sample in data:
                if k in sample:
                    sample.pop(k)
            for sample in strong_data:
                if k in sample:
                    sample.pop(k)

        # collate both views: use the default collate function unless the
        # caller asked to keep per-sample ground-truth fields as lists
        if self.collate_batch:
            batch_data = default_collate_fn(data)
            strong_batch_data = default_collate_fn(strong_data)
            return batch_data, strong_batch_data
        else:
            batch_data = {}
            for k in data[0].keys():
                tmp_data = []
                for i in range(len(data)):
                    tmp_data.append(data[i][k])
                if 'gt_' not in k and 'is_crowd' not in k and 'difficult' not in k:
                    tmp_data = np.stack(tmp_data, axis=0)
                batch_data[k] = tmp_data

            strong_batch_data = {}
            for k in strong_data[0].keys():
                tmp_data = []
                for i in range(len(strong_data)):
                    tmp_data.append(strong_data[i][k])
                if 'gt_' not in k and 'is_crowd' not in k and 'difficult' not in k:
                    tmp_data = np.stack(tmp_data, axis=0)
                strong_batch_data[k] = tmp_data
            return batch_data, strong_batch_data


class CombineSSODLoader(object):
    def __init__(self, label_loader, unlabel_loader):
        self.label_loader = label_loader
        self.unlabel_loader = unlabel_loader
        # create the iterators eagerly so the except branches below only
        # ever handle exhaustion, not a missing attribute
        self.label_loader_iter = iter(self.label_loader)
        self.unlabel_loader_iter = iter(self.unlabel_loader)

    def __iter__(self):
        # cycle both loaders forever; the training loop decides when to stop
        while True:
            try:
                label_samples = next(self.label_loader_iter)
            except StopIteration:
                self.label_loader_iter = iter(self.label_loader)
                label_samples = next(self.label_loader_iter)

            try:
                unlabel_samples = next(self.unlabel_loader_iter)
            except StopIteration:
                self.unlabel_loader_iter = iter(self.unlabel_loader)
                unlabel_samples = next(self.unlabel_loader_iter)

            yield (
                label_samples[0],    # sup weak
                label_samples[1],    # sup strong
                unlabel_samples[0],  # unsup weak
                unlabel_samples[1]   # unsup strong
            )

    def __call__(self):
        return self.__iter__()
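
# Consumption sketch (illustrative): each step of an SSOD training loop
# unpacks the four views yielded above. The teacher/student split is an
# assumption about how these views are typically used, not part of this file.
#
#   for sup_weak, sup_strong, unsup_weak, unsup_strong in loader:
#       ...  # e.g. teacher sees the weak views, student sees the strong views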


class BaseSemiDataLoader(object):
    def __init__(self,
                 sample_transforms=[],
                 weak_aug=[],
                 strong_aug=[],
                 sup_batch_transforms=[],
                 unsup_batch_transforms=[],
                 sup_batch_size=1,
                 unsup_batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 num_classes=80,
                 collate_batch=True,
                 use_shared_memory=False,
                 **kwargs):
        # sup transforms
        self._sample_transforms_label = Compose_SSOD(
            sample_transforms, weak_aug, strong_aug, num_classes=num_classes)
        self._batch_transforms_label = BatchCompose_SSOD(
            sup_batch_transforms, num_classes, collate_batch)
        self.batch_size_label = sup_batch_size

        # unsup transforms
        self._sample_transforms_unlabel = Compose_SSOD(
            sample_transforms, weak_aug, strong_aug, num_classes=num_classes)
        self._batch_transforms_unlabel = BatchCompose_SSOD(
            unsup_batch_transforms, num_classes, collate_batch)
        self.batch_size_unlabel = unsup_batch_size

        # common
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.use_shared_memory = use_shared_memory
        self.kwargs = kwargs

    def __call__(self,
                 dataset_label,
                 dataset_unlabel,
                 worker_num,
                 batch_sampler_label=None,
                 batch_sampler_unlabel=None,
                 return_list=False):
        # sup dataset
        self.dataset_label = dataset_label
        self.dataset_label.check_or_download_dataset()
        self.dataset_label.parse_dataset()
        self.dataset_label.set_transform(self._sample_transforms_label)
        self.dataset_label.set_kwargs(**self.kwargs)
        if batch_sampler_label is None:
            self._batch_sampler_label = DistributedBatchSampler(
                self.dataset_label,
                batch_size=self.batch_size_label,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler_label = batch_sampler_label

        # unsup dataset, aligned in length with the sup dataset so the
        # two loaders stay in step
        self.dataset_unlabel = dataset_unlabel
        self.dataset_unlabel.length = len(self.dataset_label)
        self.dataset_unlabel.check_or_download_dataset()
        self.dataset_unlabel.parse_dataset()
        self.dataset_unlabel.set_transform(self._sample_transforms_unlabel)
        self.dataset_unlabel.set_kwargs(**self.kwargs)
        if batch_sampler_unlabel is None:
            self._batch_sampler_unlabel = DistributedBatchSampler(
                self.dataset_unlabel,
                batch_size=self.batch_size_unlabel,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler_unlabel = batch_sampler_unlabel

        # DataLoader does not start sub-processes on Windows and macOS,
        # so shared memory is not needed there
        use_shared_memory = self.use_shared_memory and \
            sys.platform not in ['win32', 'darwin']
        # check whether shared memory size is bigger than 1G (1024M)
        if use_shared_memory:
            shm_size = _get_shared_memory_size_in_M()
            if shm_size is not None and shm_size < 1024.:
                logger.warning("Shared memory size is less than 1G, "
                               "disable shared_memory in DataLoader")
                use_shared_memory = False

        self.dataloader_label = DataLoader(
            dataset=self.dataset_label,
            batch_sampler=self._batch_sampler_label,
            collate_fn=self._batch_transforms_label,
            num_workers=worker_num,
            return_list=return_list,
            use_shared_memory=use_shared_memory)

        self.dataloader_unlabel = DataLoader(
            dataset=self.dataset_unlabel,
            batch_sampler=self._batch_sampler_unlabel,
            collate_fn=self._batch_transforms_unlabel,
            num_workers=worker_num,
            return_list=return_list,
            use_shared_memory=use_shared_memory)

        self.dataloader = CombineSSODLoader(self.dataloader_label,
                                            self.dataloader_unlabel)
        self.loader = iter(self.dataloader)
        return self

    def __len__(self):
        return len(self._batch_sampler_label)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.loader)

    def next(self):
        # python2 compatibility
        return self.__next__()


@register
class SemiTrainReader(BaseSemiDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 weak_aug=[],
                 strong_aug=[],
                 sup_batch_transforms=[],
                 unsup_batch_transforms=[],
                 sup_batch_size=1,
                 unsup_batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 num_classes=80,
                 collate_batch=True,
                 **kwargs):
        super(SemiTrainReader, self).__init__(
            sample_transforms, weak_aug, strong_aug, sup_batch_transforms,
            unsup_batch_transforms, sup_batch_size, unsup_batch_size, shuffle,
            drop_last, num_classes, collate_batch, **kwargs)
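
# Semi-supervised usage sketch (illustrative): as with BaseDataLoader, the
# reader is constructed first and then called with the two datasets. The
# dataset objects, worker count, and aug op names below are hypothetical
# placeholders; real configs would use the project's weak/strong aug ops.
#
#   reader = SemiTrainReader(
#       sample_transforms=[{'Decode': {}}],
#       weak_aug=[{'RandomFlip': {}}],
#       strong_aug=[{'RandomFlip': {}}],
#       sup_batch_size=2,
#       unsup_batch_size=2)
#   reader = reader(labeled_dataset, unlabeled_dataset, worker_num=2)
#   sup_w, sup_s, unsup_w, unsup_s = next(reader)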