How to get val loss in 3.x? #9904

Open
devin-ry opened this issue Mar 8, 2023 · 24 comments

@devin-ry

devin-ry commented Mar 8, 2023

I have seen Validation Loss During Training #7971.
but there is no workflow in base/default_runtime.py.
my mmdetection version is 3.x

@RangiLyu
Member

RangiLyu commented Mar 9, 2023

Sorry, the val loss calculation in version 3.x is not yet supported. We will support it in the next few releases.

@abdksyed

abdksyed commented Apr 12, 2023

Any updates on this @RangiLyu ?

@thanujan96

Why doesn't mmdetection 3.x include validation loss? Has it been removed for a specific reason? This is a critical feature because, without validation loss, we cannot assess whether a model is overfitting or generalizing. Does mmdetection suggest any alternative methods for addressing this? I'm feeling concerned and confused because I couldn't find anything related to this issue in the documentation.

@abdksyed

No, validation loss is not supported yet. The only way to check overfitting is by looking at mAP scores over the training and validation data.

There is a way to get validation loss, but it's more of a hack by creating hooks in the pipeline.
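The rough shape of such a hook (just a bare sketch; what you compute inside it is up to you) is:

from mmengine.hooks import Hook
from mmdet.registry import HOOKS


@HOOKS.register_module()
class MyValHook(Hook):  # hypothetical name
    """Bare skeleton: fill in whatever loss/metric logic you need."""

    def after_val_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        # runs after every validation iteration; `outputs` holds the model predictions
        pass

    def after_val_epoch(self, runner, metrics=None):
        # runs once validation is finished; aggregate and log here
        pass

You then enable it in your config with custom_hooks = [dict(type='MyValHook')].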

@tomhruby1

Is there anything more on the hack with the hooks that you can point me to?

@Vaaaaaalllll

no update yet?

@willcray

> No, validation loss is not supported yet. The only way to check overfitting is by looking at mAP scores over the training and validation data.
>
> There is a way to get validation loss, but it's more of a hack by creating hooks in the pipeline.

@abdksyed could you share how to log mAP on the training dataset while training?

@abdksyed

abdksyed commented Dec 16, 2023

@tomhruby1 @willcray

# imports assumed by this snippet (sources are my best guess; adjust to your environment)
import os
import shutil

import cv2
import numpy as np
import torch
import wandb
from moviepy.editor import VideoFileClip
from pycocotools.coco import COCO
from skimage import io
from torchmetrics.classification import BinaryJaccardIndex

from mmengine.hooks import Hook
from mmdet.apis import inference_detector, init_detector
from mmdet.registry import HOOKS


@HOOKS.register_module()
class FindIoU(Hook):
    def __init__(self, name):
        os.makedirs("bestepochs", exist_ok=True)
        # Some Necessary Variables for me
        self.bestIoU = 0
        self.bestepoch = None
        self.name = name
        self.metric = BinaryJaccardIndex()
        
        # RGB format
        self.CLS2COLOR = {
            1: (228,0,120), # Red
            2: (42, 82, 190), # Blue
            3: (3, 192, 60) # Green
        }
        
        # define our custom x axis metric
        wandb.define_metric("coco/epoch")
        # define which metrics will be plotted against it
        # My OWN Custom Metrics, YOU CAN HAVE YOUR LOSS METRIC HERE
        wandb.define_metric(
          "coco/pGen1IoU", step_metric="coco/epoch", step_sync=False)
        wandb.define_metric(
          "coco/pGen2IoU", step_metric="coco/epoch", step_sync=False)
        wandb.define_metric(
          "coco/meanIoU", step_metric="coco/epoch", step_sync=False)
        
        self.artifact = wandb.Artifact(self.name, type='model')
        
    def after_val(self, runner, **kwargs):
        IoUs = []
        # TO LOAD THE MODEL FROM THE RECENT WEIGHT FILE
        checkpoint_file = runner.work_dir + f"/epoch_{runner.epoch}.pth"
        model = init_detector(runner.cfg, checkpoint_file, device='cuda:0')
        meanIoU = []
        val_file = runner.cfg.val_dataloader.dataset.ann_file
        test_file = runner.cfg.test_dataloader.dataset.ann_file
        for f_type, json_path in zip(['pGen1', 'pGen2'], [val_file, test_file]):
            
            # json_path = f"{data_type}.json"
            coco = COCO(json_path)
            img_dir = f"combined_data"
            cat_ids = coco.getCatIds()
            frames = {}
            for idx, img_data in coco.imgs.items():
                anns_ids = coco.getAnnIds(imgIds=img_data['id'], catIds=cat_ids, iscrowd=None)
                anns = coco.loadAnns(anns_ids)

                truth_mask = coco.annToMask(anns[0])
                for i in range(1,len(anns)):
                    truth_mask = np.maximum(truth_mask,coco.annToMask(anns[i])*1)

                img = f'combined_data/{img_data["file_name"]}'  # or img = mmcv.imread(img), which will only load it once
                # PERFORMING INFERENCE
                result = inference_detector(model, img)
                # outputs = predictor(im)

                pred_mask = np.zeros_like(truth_mask)
                for i in result.pred_instances.masks.type(torch.int8):
                    pred_mask = np.maximum(pred_mask, i.to('cpu').numpy().astype(np.uint8))
                    
                # frame = label2rgb(pred_mask, cv2.imread(img), alpha=0.3, bg_label=0)*255
    
                target = torch.tensor(truth_mask)
                preds = torch.tensor(pred_mask)
            
                intersection_mask = np.logical_and(pred_mask == 1, truth_mask == 1)
                pred_mask[truth_mask == 1] = 2
                pred_mask[intersection_mask] = 3
                # Repeating Channels to make it three channels
                pred_mask = np.tile(pred_mask[..., np.newaxis], (1,1,3))
                
                # red -> Wrong Predicted, blue -> Ground Truth, green -> Correct Predicted
                frame = io.imread(img)
                for color_id in range(1,4):
                    mask = np.where(pred_mask == (color_id,)*3, self.CLS2COLOR[color_id], 0).astype('uint8')
                    frame = cv2.addWeighted(frame, 1.0, mask, 0.5, 0)
                
                frames[img_data["file_name"]] = frame

                IoUs.append(self.metric(preds, target).item())
                
            
            size1,size2,_ = frame.shape
            out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 1, (size2, size1), True)
            # Sorting the frames according to frame number eg: p3_frame_000530..PNG
            for _,i in sorted(frames.items(), key=lambda x: x[0]):
                out_img = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)
                out.write(out_img)
            out.release()
            
            # Convert MPV4 codec to libx264 codec
            input_file = 'output.mp4'
            output_file = f_type+'.mp4'
            clip = VideoFileClip(input_file)
            clip.write_videofile(output_file, codec='libx264')

            # Collect all meanIoUs for all Generalization Patients
            meanIoU.append(sum(IoUs)/len(IoUs))
            print(f"IoU: {sum(IoUs)/len(IoUs)}")
            
            # axes are (time, channel, height, width)
            wandb.log({f"{self.name}_{f_type}_epoch_{runner.epoch}": wandb.Video(output_file)})
            
        for IoU, log in zip(meanIoU, ['pGen1', 'pGen2']):
            wandb.log({f'coco/{log}':IoU, 'coco/epoch':runner.epoch})
            
        meanIoU = sum(meanIoU)/len(meanIoU)
        if meanIoU > self.bestIoU:
            self.bestIoU = meanIoU
            self.bestepoch = checkpoint_file

        print(f"meanIoU: {meanIoU}")
        wandb.log({'coco/iou':meanIoU, 'coco/epoch':runner.epoch})
        
        print(f"Saving checkpoint of epoch {runner.epoch} to wandb")
        self.artifact.add_file(checkpoint_file, name=f'epoch_{runner.epoch}.pth')
        # wandb.log_artifact(self.artifact)
    def after_run(self,runner, **kwargs):
        shutil.copy(self.bestepoch, f"bestepochs/{self.name}.pth")
        print(f"Saving best checkpoint to wandb")
        self.artifact.add_file(self.bestepoch, name=f"best.pth")
        wandb.log_artifact(self.artifact)

This was a hook I implemented to compute IoU values after each epoch. Here, the after_val function runs after validation has been performed, and after_run is called after the entire run. There are similar functions like before_val and so on, which are mentioned in the docs (https://mmdetection.readthedocs.io/en/latest/user_guides/useful_hooks.html#how-to-implement-a-custom-hook)

I was running inference and extracting the predicted masks to compute IoU against the ground-truth masks, and also creating videos of the frames and saving them to Weights & Biases. You can change the logic of the code, but the function names will stay the same for you.

There is some inefficiency: I perform inference again on the validation/test data to get the IoU, whereas during training, inference is already run on the validation data by default to compute the mAP values and so on. I couldn't find how to reuse the results of the validation that had already been performed, so I had to run inference again.
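For completeness, enabling a hook like this is done through the config; here is a sketch (the module path and the name argument are placeholders from my setup):

# in your config
custom_imports = dict(imports=['path.to.find_iou_hook'], allow_failed_imports=False)
custom_hooks = [dict(type='FindIoU', name='my_run')]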

@willcray

willcray commented Dec 18, 2023

@abdksyed thanks for sharing this code. This appears to be a way to compute mask IoU over the validation set. You mentioned that there's a way to get the mAP on the train set as well:

> No, validation loss is not supported yet. The only way to check overfitting is by looking at mAP scores over the training and validation data.
>
> There is a way to get validation loss, but it's more of a hack by creating hooks in the pipeline.

It appears that the same custom-hook approach from the docs link you provided above could be used. Perhaps with after_train_epoch or something similar?

@abdksyed

@willcray

Yes, for train loss you can use after_train_epoch, and something similar for val loss.
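A rough sketch of the train-loss side (illustrative, not tested; it relies on the 3.x train loop passing the parsed log_vars dict to after_train_iter as outputs):

import torch
from mmengine.hooks import Hook
from mmdet.registry import HOOKS


@HOOKS.register_module()
class TrainLossLogger(Hook):
    """Illustrative sketch: average the per-iteration train losses over each epoch."""

    def __init__(self):
        self._losses = {}

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        # `outputs` is the log_vars dict returned by model.train_step (scalar tensors)
        for name, value in (outputs or {}).items():
            self._losses.setdefault(name, []).append(value.detach())

    def after_train_epoch(self, runner):
        for name, values in self._losses.items():
            runner.message_hub.update_scalar(f'train_epoch/{name}', torch.mean(torch.stack(values)))
        self._losses.clear()  # reset for the next epoch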

@tmargary

tmargary commented Jan 9, 2024

Does anyone have an example script that gets the validation loss using the hook approach?

@Cindy0725

Same question, can anyone share how to get the validation loss using the hook approach?

@EmmaMeeus

This would be a very useful feature; I would appreciate an update on this, @RangiLyu.

@Roger-F

Roger-F commented Apr 4, 2024

Any update on the feature?

@g824718114

> Does anyone have an example script that gets the validation loss using the hook approach?

This may be useful:
#11331 (comment)

@Ileal16

Ileal16 commented Aug 20, 2024

The loss computation seems to be more on the mmengine side.

open-mmlab/mmengine#1486

@willcray

willcray commented Oct 8, 2024

@Ileal16 this is now included in v0.10.5 of MMEngine.

However, as explained here, MMDetection will need to update the forward functions of its models to also return the loss when using the predict mode.

You can work around this by creating a custom model (likely by subclassing whichever MMDet model you're using) and appending the loss to the output of the predict mode in its forward function. This is illustrated in the MMEngine PR here.

Upgrading MMEngine to the latest version and adding this workaround should enable validation loss without a second forward pass. Ideally, MMDetection would add this to all of its default models so the feature is available out of the box.

@devin-ry
Author

devin-ry commented Oct 9, 2024

Thank you all for your attention to this issue. Recently, I found that mmcls 0.25.0 supports validation loss. In short, the idea is to run the loss calculation from forward_train while the model is in eval mode (model.eval()).
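In mmdet 3.x terms, that idea looks roughly like this (a rough sketch, not mmcls's actual code; `model` stands for an already-built detector and `data_batch` for one batch from the val dataloader):

import torch

# run the training-style loss computation while the model stays in eval mode,
# so BatchNorm/Dropout behave as they do at inference time
model.eval()
with torch.no_grad():
    data = model.data_preprocessor(data_batch, training=False)
    losses = model.loss(data['inputs'], data['data_samples'])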

@devin-ry
Author

devin-ry commented Oct 11, 2024

Hey guys, I did it. My implementation may be quite simple and immature, but I hope it can serve as a reference for everyone. If you encounter any bugs while using my method, I would appreciate your feedback. (I don't know how to insert images in the text, and the images I have contain Chinese, so I'll just describe the modifications in text form.)

The versions of mmdetection and other related libraries I’m using are as follows:
• mmdetection version: 3.3 [any version of mmdetection 3.x should work]
• mmengine version: 0.8.3 [I've only tested this version; I’m unsure about other versions]
• mmcv version: 2.1.0 [I have tested mmdetection 3.1 with mmcv 2.0.1, and mmdetection 3.3 with mmcv 2.1.0].

The following instructions assume you are building the project from source code.

Core Idea:

  1. Add loss calculation in the run_iter part of ValLoop.
  2. Create a new hook that overrides after_val_iter and after_val_epoch to calculate validation loss.

1. Adding loss calculation in the run_iter part of ValLoop:

   a) Calculating the mini-batch validation loss:
      Path: mmengine/mmengine/model/base_model/base_model.py
      Line: 132-133 (def val_step in class BaseModel)
      [1] Code after modification:

          data = self.data_preprocessor(data, False)
          # return self._run_forward(data, mode='predict')  # type: ignore
          outputs = {'predict_output': self._run_forward(data, mode='predict'),
                     'loss': self.parse_losses(self._run_forward(data, mode='loss'))}
          return outputs  # type: ignore

      [2] Explanation of the code:
          - self.data_preprocessor: preprocesses the data for prediction.
          - self._run_forward: runs the forward pass based on the mode. Originally the mode is set to
            predict, so the loss is not calculated. To obtain the loss without affecting the original
            return values, we additionally run the forward pass with mode='loss' and parse the result.

   b) Post-processing the mini-batch validation loss:
      Path: mmengine/mmengine/runner/loops.py
      Line: 382-384 (def run_iter in class ValLoop)
      [1] Code after modification:

          with autocast(enabled=self.fp16):
              outputs = self.runner.model.val_step(data_batch)
          # self.evaluator.process(data_samples=outputs, data_batch=data_batch)
          self.evaluator.process(data_samples=outputs['predict_output'], data_batch=data_batch)

      [2] Explanation of the code:
          - self.runner.model.val_step: now returns both the mini-batch predictions and the validation loss.
          - self.evaluator.process: handles the post-processing of mini-batch predictions. Since we modified
            val_step in step a) to return both the prediction and the loss as a dictionary, we pass
            outputs['predict_output'] here in place of the original output.

2. Creating a new hook that overrides after_val_iter and after_val_epoch for validation loss calculation:
   I will analyze this part later. It is quite simple; for after_val_epoch, I'll reference the loss calculation from after_train_iter. (The content above was translated into English by GPT.)
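For reference, the complete val_step after the edit in 1a) would look roughly like this (a sketch based on mmengine 0.8.x's BaseModel; note that parse_losses returns a (total_loss, log_vars) tuple, which the hook in the next comment relies on):

# replaces BaseModel.val_step in mmengine/mmengine/model/base_model/base_model.py
def val_step(self, data):
    """Modified: return the predictions together with the parsed validation loss."""
    data = self.data_preprocessor(data, False)
    outputs = {
        # original behaviour: predictions for the evaluator
        'predict_output': self._run_forward(data, mode='predict'),
        # parse_losses returns a (total_loss, log_vars) tuple; the hook uses index [1]
        'loss': self.parse_losses(self._run_forward(data, mode='loss')),
    }
    return outputs  # type: ignore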

@devin-ry
Author

devin-ry commented Oct 11, 2024

  1. Creating a New Hook to Override after_val_iter and after_val_epoch for Validation Loss Calculation:
    Path Directory: mmdet/engine/hooks
    You need to create a new hook script under this folder. You can name the script val_loss.py (I'll directly show the contents I’ve added):
    ===============================================================================
import torch
from typing import Dict, Optional

from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS


@HOOKS.register_module()
class ValLoss(Hook):
    """Collect and log the validation loss."""

    def __init__(self,
                 loss_list=[]) -> None:
        self.loss_list = loss_list

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """
            Figure every loss base self.loss_list and add the output information in logs.
        """
        if len(self.loss_list) > 0:
            loss_log = {}
            for lossInfo in self.loss_list:
                for tmp_loss_name, tmp_loss_value in lossInfo.items():
                    loss_log.setdefault(tmp_loss_name, []).append(tmp_loss_value)
            for loss_name, loss_values in loss_log.items():
                runner.message_hub.update_scalar(f'val/{loss_name}_val', torch.mean(torch.stack(loss_values)))
            # reset so the losses do not accumulate across validation epochs
            self.loss_list.clear()

    def after_val_iter(self,
                       runner: Runner,
                       batch_idx: int,
                       data_batch: Optional[dict] = None,
                       outputs: Optional[dict] = None) -> None:
        """ Save all loss in self.loss_list.

        Args:
            outputs (Dict): the outputs include the outputs of prediction and loss, but we only need loss.
                            All loss will save in self.loss_list.
        """
        if 'loss' in outputs:
            self.loss_list.append(outputs['loss'][1])

================================================================================
When you finish the above steps, don't forget to add the relevant content in mmdet/engine/hooks/__init__.py.
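Concretely, the registration could look like this (a sketch; keep whatever entries your __init__.py already contains):

# mmdet/engine/hooks/__init__.py -- add alongside the existing imports
from .val_loss import ValLoss
# and append 'ValLoss' to the existing __all__ list

# then enable the hook in your training config:
custom_hooks = [dict(type='ValLoss')]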

@JohannesTheo

JohannesTheo commented Oct 31, 2024

Hey everyone, I just came across this thread and want to share what I do :)

TLDR: I create an extra dataloader for val loss calculations and call a modified model.train_step (without gradient updates) from a 'before_train_iter' hook (because the train loss is also computed before the gradient update). Of course you can run the hook at different hook points as well; this depends on your setting and needs.

Create an extra dataloader in Runner

import copy
from itertools import cycle

from mmengine.runner import Runner
from mmengine.registry import RUNNERS

@RUNNERS.register_module()
class CustomRunner(Runner):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.val_loss_dl = CustomRunner.build_val_loss_dl(self._train_dataloader, self._val_dataloader)

    @staticmethod
    def build_val_loss_dl(train_dataloader, val_dataloader):
        # both must still be uninitialized config dicts at this point
        assert isinstance(train_dataloader, dict)
        assert isinstance(val_dataloader, dict)
        
        # ensure val dataloader for loss calculation uses same sampler/pipeline as train dl (just switch anns and imgs)
        dl = copy.deepcopy(train_dataloader)
        dl['dataset']['ann_file'] = copy.deepcopy(val_dataloader['dataset']['ann_file'])
        dl['dataset']['data_prefix'] = copy.deepcopy(val_dataloader['dataset']['data_prefix'])
        dl = CustomRunner.build_dataloader(dl)
        return cycle(dl)

In a custom logger hook, run train_step without gradient updates

Add it to the config with custom_hooks = [dict(type=CustomLoggerHook)]

Note that the all_reduce_dict call is optional and that mmdet's default logging only logs the loss on rank 0. In my implementation I updated that for the train loss as well, but you don't have to. Either way, it makes sense to stay consistent so that train and val losses are comparable.

from contextlib import nullcontext

import torch
from mmengine.dist import all_reduce_dict
from mmengine.hooks import Hook
from mmengine.hooks.hook import DATA_BATCH
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS
from mmengine.runner.amp import autocast

@HOOKS.register_module()
class CustomLoggerHook(Hook):
    priority = 'BELOW_NORMAL'
    def __init__(self, interval: int = 100):
        self.interval = interval
   
    def before_train_iter(self, 
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:

        # see base class Hook for different ways to check intervals and change to your needs
        if self.every_n_train_iters(runner=runner, n=self.interval):
            outputs = self._get_loss_on_val_batch(runner)  # losses on val data
            all_reduce_dict(outputs, op='mean')
            val_loss = {}
            for k, v in outputs.items():
                val_loss[k] = v.item()  # cuda to cpu
            # do whatever you want with the loss from here, e.g. log it :)
            # or integrate this into a different hook...

    def _get_loss_on_val_batch(self, runner):        
        # we basically run model.train_step but with:
        #
        #     - model.eval() so we don't update batch_norm stats with val data!!!
        #     - torch.no_grad() so we don't produce gradients
        #     - no optim_wrapper.update_params but we have no gradients anyway
        #     - manual amp context instead of optim_wrapper.optim_context() so we don't change anything inside
        #       the optim_wrapper but still get correct amp loss. Not sure if this is actually necessary but
        #       let's not fiddle with optim_wrapper. Maybe we could simply use the context just without calling
        #       optim_wrapper.update_params() but I'm not sure so we just avoid it.
        #
        # see: https://github.com/open-mmlab/mmengine/blob/main/mmengine/model/base_model/base_model.py#L84
        model = runner.model
        if is_model_wrapper(model):  # unwrap DDP
            model = model.module
        data = next(runner.val_loss_dl)
        amp = hasattr(runner.optim_wrapper, 'cast_dtype')
        cast_dtype =  getattr(runner.optim_wrapper, 'cast_dtype', None)
        model.eval()
        with torch.no_grad():
            with autocast(dtype=cast_dtype) if amp else nullcontext():
                data = model.data_preprocessor(data, True)
                losses = model._run_forward(data, mode='loss')  # type: ignore
                parsed_losses, log_vars = model.parse_losses(losses)  # type: ignore
        model.train()
        return log_vars

@danielsagmeister-cw

@JohannesTheo: I would like to use the custom runner in the RTMDet-Ins config, but where do I set runner_type='CustomRunner'?

@JohannesTheo

JohannesTheo commented Nov 4, 2024

Hey @danielsagmeister-cw, sorry, I forgot to mention that I'm using mmdet 3.3.0. In that case, tools/train.py will instantiate the runner here. You can put the following somewhere in your config (assuming you have a file custom_runner.py right next to your config file):

from .custom_runner import CustomRunner
runner_type = CustomRunner

or if you are using the text based configs:

# I'm not 100% sure about the correct path etc. but,
# you need the custom_imports to trigger the registry mechanism
custom_imports = dict(imports=['.custom_runner'], allow_failed_imports=False)
runner_type = 'CustomRunner'

If you are using mmdet 2.x, however, the mechanism is different. The runner is instantiated here and defined as runner=dict(type=...), for instance here. Note that during the transition from mmdet 2 to 3, a lot of things were moved from mmcv to mmengine, including the registry mechanism for Runners. Things will be different there and you have to follow the call chain to understand what's going on. My example from above will most likely not work and needs some adjustments.

EDIT: I just checked, and in the case of mmdet 2.x you might be able to do almost the same thing, but you have to extend/inherit from https://github.com/open-mmlab/mmcv/blob/1.x/mmcv/runner/epoch_based_runner.py . I didn't check what's different in terms of Hooks, but since the EpochBasedRunner implements the train loop directly, it seems this part can be customized even more easily and directly, even without a hook.

@devin-ry
Author

devin-ry commented Dec 4, 2024

Perhaps it could be easier?
Just like this:

import torch
from typing import Dict, Optional, Union
from mmengine.hooks import Hook
from mmengine.runner import Runner, autocast
from mmdet.registry import HOOKS


@HOOKS.register_module()
class ValLoss(Hook):
    "Save and print valid loss info"

    def __init__(self,
                 loss_list=[]) -> None:
        self.loss_list = loss_list

    def before_val(self, runner) -> None:
        # build the model
        self.model = runner.model

    def after_val_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """
            Figure every loss base self.loss_list and add the output information in logs.
        """
        if len(self.loss_list) > 0:
            loss_log = {}
            for lossInfo in self.loss_list:
                if 'loss' in lossInfo:
                    for tmp_loss_name, tmp_loss_value in lossInfo.items():
                        loss_log.setdefault(tmp_loss_name, []).append(tmp_loss_value)
            for loss_name, loss_values in loss_log.items():
                runner.message_hub.update_scalar(f'val/{loss_name}_val', torch.mean(torch.stack(loss_values)))
            # reset so the losses do not accumulate across validation epochs
            self.loss_list.clear()
        else:
            print('the model does not support validation loss!')

    def after_val_iter(self,
                        runner: Runner,
                        batch_idx: int,
                        data_batch: Union[dict, tuple, list] = None,
                        outputs: Optional[dict] = None) -> None:
        """
        Figure the loss again
        Save all loss in self.loss_list.
        """
        with torch.no_grad():
            with autocast(enabled=runner.val_loop.fp16):
                data = self.model.data_preprocessor(data_batch, True)
                losses = self.model._run_forward(data, mode='loss')  # type: ignore
                self.loss_list.append(losses)

and in my config, I can add:

custom_hooks = [dict(type='ValLoss'),
                dict(type='TestLoss')]

Although useful, it is inefficient (the loss is computed with an extra forward pass during validation).
