123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- import os
- from .base_logger import BaseLogger
class WandbLogger(BaseLogger):
    """Logger that forwards metrics and model checkpoints to Weights & Biases.

    A wandb run is obtained as soon as the logger is constructed: an
    already-active ``wandb.run`` is reused, otherwise ``wandb.init`` is called.

    Args:
        project: Forwarded to ``wandb.init(project=...)``.
        name: Forwarded to ``wandb.init(name=...)``.
        id: Forwarded to ``wandb.init(id=...)`` (init is always called with
            ``resume="allow"``).
        entity: Forwarded to ``wandb.init(entity=...)``.
        save_dir: Forwarded to ``wandb.init(dir=...)``; ``log_model`` also
            reads ``<save_dir>/<prefix>.pdparams`` from it.
        config: Optional values merged into ``run.config`` (via
            ``run.config.update``) once the run is available.
        **kwargs: Extra keyword arguments for ``wandb.init``; they override the
            defaults built from the named arguments above.

    Raises:
        ModuleNotFoundError: If the ``wandb`` package cannot be imported.
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        # wandb is an optional dependency, so it is imported lazily and a
        # friendlier install hint is raised when it is missing.
        try:
            import wandb
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`"
            ) from err
        self.wandb = wandb

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None  # populated lazily by the `run` property
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow"
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        self._wandb_init.update(**kwargs)

        # Touch the property so the run is created/attached eagerly.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)

    @property
    def run(self):
        """Return the wandb run, creating or attaching to one on first access.

        If a wandb run is already active in this process it is reused and an
        info message is logged; otherwise ``wandb.init`` is called with the
        arguments collected in ``__init__``.
        """
        if self._run is None:
            if self.wandb.run is not None:
                # No module-level `logger` is defined in this file, so resolve
                # one here (referencing a bare `logger` raised NameError).
                # NOTE(review): swap for the project's own logger factory if
                # one is preferred.
                import logging
                logging.getLogger(__name__).info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()`"
                    " before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        """Log a mapping of metrics to the run under a lower-cased prefix.

        Each key ``k`` is logged as ``"<prefix.lower()>/<k>"``.

        Args:
            metrics: Mapping of metric name (str) to value.
            prefix: Optional namespace; falsy values are treated as ``""``.
                NOTE(review): with no prefix the keys still get a leading
                ``"/"``; kept as-is so existing dashboards keep matching.
            step: Forwarded to ``run.log(step=...)``.
        """
        if not prefix:
            prefix = ""
        updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()}

        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Upload ``<save_dir>/<prefix>.pdparams`` as a wandb model artifact.

        The artifact is named ``model-<run id>``, stores the file as
        ``model_ckpt.pdparams``, and is tagged with ``prefix`` (plus
        ``"best"`` when ``is_best`` is true).

        Args:
            is_best: Whether to add the ``"best"`` alias.
            prefix: Checkpoint file stem; also used as an alias.
            metadata: Optional metadata attached to the artifact.

        Raises:
            TypeError: From ``os.path.join`` if ``save_dir`` was left as None.
        """
        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            'model-{}'.format(self.run.id), type='model', metadata=metadata)
        artifact.add_file(model_path, name="model_ckpt.pdparams")
        aliases = [prefix]
        if is_best:
            aliases.append("best")
        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        """Finish the wandb run (creating/attaching one first if needed)."""
        self.run.finish()
|