wandb_logger.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. from .base_logger import BaseLogger
  3. class WandbLogger(BaseLogger):
  4. def __init__(self,
  5. project=None,
  6. name=None,
  7. id=None,
  8. entity=None,
  9. save_dir=None,
  10. config=None,
  11. **kwargs):
  12. try:
  13. import wandb
  14. self.wandb = wandb
  15. except ModuleNotFoundError:
  16. raise ModuleNotFoundError(
  17. "Please install wandb using `pip install wandb`"
  18. )
  19. self.project = project
  20. self.name = name
  21. self.id = id
  22. self.save_dir = save_dir
  23. self.config = config
  24. self.kwargs = kwargs
  25. self.entity = entity
  26. self._run = None
  27. self._wandb_init = dict(
  28. project=self.project,
  29. name=self.name,
  30. id=self.id,
  31. entity=self.entity,
  32. dir=self.save_dir,
  33. resume="allow"
  34. )
  35. self._wandb_init.update(**kwargs)
  36. _ = self.run
  37. if self.config:
  38. self.run.config.update(self.config)
  39. @property
  40. def run(self):
  41. if self._run is None:
  42. if self.wandb.run is not None:
  43. logger.info(
  44. "There is a wandb run already in progress "
  45. "and newly created instances of `WandbLogger` will reuse"
  46. " this run. If this is not desired, call `wandb.finish()`"
  47. "before instantiating `WandbLogger`."
  48. )
  49. self._run = self.wandb.run
  50. else:
  51. self._run = self.wandb.init(**self._wandb_init)
  52. return self._run
  53. def log_metrics(self, metrics, prefix=None, step=None):
  54. if not prefix:
  55. prefix = ""
  56. updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()}
  57. self.run.log(updated_metrics, step=step)
  58. def log_model(self, is_best, prefix, metadata=None):
  59. model_path = os.path.join(self.save_dir, prefix + '.pdparams')
  60. artifact = self.wandb.Artifact('model-{}'.format(self.run.id), type='model', metadata=metadata)
  61. artifact.add_file(model_path, name="model_ckpt.pdparams")
  62. aliases = [prefix]
  63. if is_best:
  64. aliases.append("best")
  65. self.run.log_artifact(artifact, aliases=aliases)
  66. def close(self):
  67. self.run.finish()