diff --git a/qa_models/qa_model.py b/qa_models/qa_model.py index ac2b861..b96ea1d 100644 --- a/qa_models/qa_model.py +++ b/qa_models/qa_model.py @@ -9,7 +9,7 @@ import logging from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers.logger import Logger from qa_models.qa_utils import find_offset_index, get_token_offsets, \ find_text_start_end_indices @@ -369,5 +369,5 @@ def log_object(self, data: Dict[str, Any], step): self.logger.info(data) @classmethod - def from_logger(cls, logger: LightningLoggerBase): + def from_logger(cls, logger: Logger): return cls(logger.log_dir)