54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
from typing import Any, Dict
|
|
|
|
from lightning.pytorch.utilities import rank_zero_only
|
|
from omegaconf import OmegaConf
|
|
|
|
from matcha.utils import pylogger
|
|
|
|
# Module-level logger for this file, obtained from matcha's `pylogger.get_pylogger`
# helper and named after this module.
log = pylogger.get_pylogger(__name__)
|
|
|
|
|
|
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Choose the config sections Lightning loggers record, then send them.

    Besides selected config sections, the model's parameter counts
    (total, trainable, non-trainable) are included.

    :param object_dict: Mapping that must contain:
        - ``"cfg"``: the main config (a DictConfig). It is converted to plain
          Python containers with ``OmegaConf.to_container``. Its ``model``,
          ``data`` and ``trainer`` sections are indexed directly (a missing one
          raises ``KeyError``). ``callbacks``, ``extras``, ``task_name``,
          ``tags``, ``ckpt_path`` and ``seed`` are optional and become ``None``
          when absent.
        - ``"model"``: the Lightning model whose parameters are counted.
        - ``"trainer"``: the Lightning trainer whose ``loggers`` receive the
          collected hyperparameters.
    """
    config = OmegaConf.to_container(object_dict["cfg"])
    lightning_model = object_dict["model"]
    lightning_trainer = object_dict["trainer"]

    # Without any attached logger there is nowhere to send the hyperparameters.
    if not lightning_trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    collected = {"model": config["model"]}

    # Single pass over the parameters; whatever is not trainable is non-trainable.
    total_count = 0
    trainable_count = 0
    for tensor in lightning_model.parameters():
        size = tensor.numel()
        total_count += size
        if tensor.requires_grad:
            trainable_count += size

    collected["model/params/total"] = total_count
    collected["model/params/trainable"] = trainable_count
    collected["model/params/non_trainable"] = total_count - trainable_count

    collected["data"] = config["data"]
    collected["trainer"] = config["trainer"]

    # Optional sections: `.get` yields None when the key is missing.
    optional_sections = (
        "callbacks",
        "extras",
        "task_name",
        "tags",
        "ckpt_path",
        "seed",
    )
    for section in optional_sections:
        collected[section] = config.get(section)

    # Every attached logger receives the same hyperparameter dictionary.
    for logger in lightning_trainer.loggers:
        logger.log_hyperparams(collected)
|