"""Print the cost reported by ``common.flops_calculation_function`` for a checkpointed model.

Usage::

    python <this script> --model-path path/to/checkpoint.pth [--input-size 480]

The checkpoint is loaded on CPU and must support ``checkpoint["model"]``.
That object is passed, together with an all-ones dummy tensor of shape
``(1, 3, input_size, input_size)``, to ``flops_calculation_function``.
The returned value is printed.
"""
import argparse

import torch

from common import flops_calculation_function


def main(argv=None):
    """Parse command-line arguments, load the checkpoint and print its cost.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``.
            ``None`` (the default) parses the real command line.
    """
    parser = argparse.ArgumentParser(
        description="Compute the cost of a model stored in a .pth checkpoint.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        # Required: without it args.model_path is None and torch.load(None)
        # fails with an unhelpful error instead of an argparse usage message.
        required=True,
        help="Path to models checkpoint (.pth file).",
    )
    parser.add_argument(
        "--input-size",
        type=int,
        default=480,
        help="Height and width of the dummy input tensor (default: 480).",
    )
    args = parser.parse_args(argv)
    if args.input_size <= 0:
        parser.error("--input-size must be a positive integer")

    # NOTE(review): torch.load unpickles arbitrary Python objects, so only load
    # checkpoints from trusted sources. On recent PyTorch releases the default
    # weights_only=True may refuse checkpoints that store a whole model object
    # under "model" -- TODO confirm whether weights_only=False is needed here.
    checkpoint = torch.load(args.model_path, map_location="cpu")
    model = checkpoint["model"]

    # Dummy batch with shape (1, 3, S, S). Presumably only its shape matters
    # to the cost counter -- verify against common.flops_calculation_function.
    dummy_input = torch.ones(1, 3, args.input_size, args.input_size)
    flops = flops_calculation_function(model, dummy_input)
    print(f"MMACs = {flops}")


if __name__ == '__main__':
    main()