# coco-benchmark / test.py
# Author: andreysher
# Commit 513aed0: "Add README and mac measuring"
# (Hugging Face file-page residue: raw | history | blame | contribute | delete | "No virus" | 5.73 kB)
import sys
from functools import partial
from typing import Callable
from typing import Dict
from typing import Tuple
from typing import Union
from argparse import Namespace
sys.path.append("vision/references/segmentation")
import presets
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn
from common import flops_calculation_function
from common import NanSafeConfusionMatrix as ConfusionMatrix
from common import get_coco
def get_dataset(
    args: Namespace,
    is_train: bool,
    transform: Union[Callable, None] = None,
    dataset: str = "coco_orig",
) -> Tuple[torch.utils.data.Dataset, int]:
    """Build the segmentation dataset to evaluate on.

    Args:
        args: parsed options; reads ``data_path``, ``backend`` and ``use_v2``.
        is_train: selects the ``"train"`` split when true, ``"val"`` otherwise.
        transform: per-sample transform pipeline. When ``None``, a
            ``presets.SegmentationPresetEval(base_size=520, ...)`` is built
            from ``args.backend`` / ``args.use_v2``.
        dataset: key into the dataset table below. Defaults to
            ``"coco_orig"`` (81 classes), which is what this script has always
            used; ``"voc"``, ``"voc_aug"`` and ``"coco"`` (21 classes) are
            also available.

    Returns:
        ``(dataset, num_classes)``.

    Raises:
        ValueError: if ``dataset`` is not one of the known keys.
    """

    def sbd(*ds_args, **kwargs):
        # torchvision's SBDataset takes no ``use_v2`` argument; drop it.
        kwargs.pop("use_v2")
        return torchvision.datasets.SBDataset(*ds_args, mode="segmentation", **kwargs)

    def voc(*ds_args, **kwargs):
        # torchvision's VOCSegmentation takes no ``use_v2`` argument; drop it.
        kwargs.pop("use_v2")
        return torchvision.datasets.VOCSegmentation(*ds_args, **kwargs)

    # name -> (root path, dataset factory, number of classes)
    paths = {
        "voc": (args.data_path, voc, 21),
        "voc_aug": (args.data_path, sbd, 21),
        "coco": (args.data_path, get_coco, 21),
        "coco_orig": (args.data_path, partial(get_coco, use_orig=True), 81),
    }
    if dataset not in paths:
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of {sorted(paths)}")
    p, ds_fn, num_classes = paths[dataset]

    if transform is None:
        transform = presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)

    image_set = "train" if is_train else "val"
    ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=args.use_v2)
    return ds, num_classes
def criterion(inputs: Dict[str, torch.Tensor], target: torch.Tensor) -> torch.Tensor:
    """Cross-entropy loss for a segmentation model's output dict.

    Every entry of ``inputs`` (per-head logits) is scored against ``target``
    with ``ignore_index=255``, so pixels labelled 255 contribute nothing.

    Args:
        inputs: model output keyed by head name; must contain ``"out"``.
        target: class-index tensor passed straight to ``cross_entropy``.

    Returns:
        The ``"out"`` loss when there is a single head, otherwise
        ``loss["out"] + 0.5 * loss["aux"]`` (other heads are computed but unused).

    Raises:
        KeyError: if ``"out"`` is missing, or if there are several heads and
            ``"aux"`` is missing.
    """
    losses = {}
    for name, x in inputs.items():
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

    if len(losses) == 1:
        return losses["out"]

    # The auxiliary head is down-weighted relative to the main head.
    return losses["out"] + 0.5 * losses["aux"]
def evaluate(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: Union[str, torch.device],
    num_classes: int,
    criterion: Callable,
) -> Tuple[ConfusionMatrix, float]:
    """Run ``model`` over ``data_loader`` and accumulate segmentation metrics.

    Args:
        model: segmentation network. Its full output is passed to
            ``criterion``, and its ``"out"`` entry is used for predictions.
        data_loader: yields ``(image, target)`` batches.
        device: device each batch is moved to before the forward pass.
        num_classes: size of the confusion matrix.
        criterion: called as ``criterion(model_output, target)``; must return
            a scalar tensor (``.item()`` is called on it).

    Returns:
        ``(confmat, loss)``. ``confmat`` is the confusion matrix after
        ``reduce_from_all_processes()``. ``loss`` is
        ``metric_logger.loss.global_avg``, presumably the mean of the
        per-batch losses.
    """
    model.eval()
    confmat = ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            output = output["out"]
            # argmax over dim 1: the class channel, assuming NCHW logits.
            confmat.update(target.flatten(), output.argmax(1).flatten())
            # NOTE: in a distributed setup the dataset may have been padded,
            # so duplicated samples would be counted twice; not handled here.
            metric_logger.update(loss=loss.item())
    confmat.reduce_from_all_processes()
    return confmat, metric_logger.loss.global_avg
def main(args):
    """Evaluate a checkpointed segmentation model and print its confusion matrix.

    The checkpoint at ``args.model_path`` must be a mapping holding the whole
    model object under ``"model"``. Besides the options read by
    ``get_dataset``, this uses ``backend``, ``use_v2``, ``device``,
    ``use_deterministic_algorithms``, ``workers`` and ``model_path``.

    Raises:
        ValueError: for a non-PIL backend without ``--use-v2``, for
            ``--use-v2`` (not supported by this script), or when
            ``--model-path`` is not given.
    """
    if args.backend.lower() != "pil" and not args.use_v2:
        # TODO: Support tensor backend in V1?
        raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
    if args.use_v2:
        raise ValueError("v2 is only supported for coco dataset for now.")
    if args.model_path is None:
        # Fail fast instead of letting torch.load(None) fail after the dataset is built.
        raise ValueError("--model-path is required: pass the checkpoint to evaluate.")

    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset_test, num_classes = get_dataset(args, is_train=False)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    # The checkpoint stores a whole pickled nn.Module, which weights_only=True
    # (the default since torch 2.6) refuses to load. SECURITY: unpickling can
    # execute arbitrary code, so only evaluate checkpoints from a trusted source.
    # map_location lets a checkpoint saved on GPU be evaluated with --device cpu.
    checkpoint = torch.load(args.model_path, map_location=device, weights_only=False)
    model = checkpoint["model"]
    model.to(device)
    # Put the model in eval mode before the FLOPs measurement. If that
    # measurement runs a forward pass, BatchNorm/Dropout then behave as at
    # inference and running statistics are not updated. evaluate() calls
    # eval() again anyway.
    model.eval()

    # Same first batch the sequential, batch-size-1 loader would yield, built
    # without starting `args.workers` worker processes for a single sample.
    sample_image = utils.collate_fn([dataset_test[0]])[0].to(device)
    model_flops = flops_calculation_function(model=model, input_sample=sample_image)
    print(f"Model Flops: {model_flops}M")

    # We disable the cudnn benchmarking because it can noticeably affect the accuracy
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    confmat, _ = evaluate(
        model=model,
        data_loader=data_loader_test,
        device=device,
        num_classes=num_classes,
        criterion=criterion,
    )
    print(confmat)
def get_args_parser(add_help: bool = True):
    """Build the command-line parser for this evaluation script.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; pass ``False``
            when embedding these options in a parent parser.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Evaluation", add_help=add_help)

    # Data and device.
    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    # Kept for command-line compatibility: neither main() nor get_dataset()
    # reads --batch-size (evaluation uses batch size 1) or --epochs.
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    # Transform pipeline (str.lower also lowercases the "PIL" default).
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

    # Model to evaluate.
    parser.add_argument("--model-path", default=None, type=str, help="Path to model checkpoint.")
    return parser
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the options to main().
    main(get_args_parser().parse_args())