vqa_token_relation.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. class VQAReTokenRelation(object):
  15. def __init__(self, **kwargs):
  16. pass
  17. def __call__(self, data):
  18. """
  19. build relations
  20. """
  21. entities = data['entities']
  22. relations = data['relations']
  23. id2label = data.pop('id2label')
  24. empty_entity = data.pop('empty_entity')
  25. entity_id_to_index_map = data.pop('entity_id_to_index_map')
  26. relations = list(set(relations))
  27. relations = [
  28. rel for rel in relations
  29. if rel[0] not in empty_entity and rel[1] not in empty_entity
  30. ]
  31. kv_relations = []
  32. for rel in relations:
  33. pair = [id2label[rel[0]], id2label[rel[1]]]
  34. if pair == ["question", "answer"]:
  35. kv_relations.append({
  36. "head": entity_id_to_index_map[rel[0]],
  37. "tail": entity_id_to_index_map[rel[1]]
  38. })
  39. elif pair == ["answer", "question"]:
  40. kv_relations.append({
  41. "head": entity_id_to_index_map[rel[1]],
  42. "tail": entity_id_to_index_map[rel[0]]
  43. })
  44. else:
  45. continue
  46. relations = sorted(
  47. [{
  48. "head": rel["head"],
  49. "tail": rel["tail"],
  50. "start_index": self.get_relation_span(rel, entities)[0],
  51. "end_index": self.get_relation_span(rel, entities)[1],
  52. } for rel in kv_relations],
  53. key=lambda x: x["head"], )
  54. data['relations'] = relations
  55. return data
  56. def get_relation_span(self, rel, entities):
  57. bound = []
  58. for entity_index in [rel["head"], rel["tail"]]:
  59. bound.append(entities[entity_index]["start"])
  60. bound.append(entities[entity_index]["end"])
  61. return min(bound), max(bound)