1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
import pytest
import torch
import ttach as tta
@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        tta.Resize(sizes=[(4, 5), (8, 10)], original_size=(4, 5), interpolation="nearest"),
    ],
)
def test_aug_deaug_mask(transform):
    """Augmenting an image and then de-augmenting it as a mask must restore the input.

    Every parameter value of each transform is exercised on a 1x1x4x5 ramp.
    The augmented tensor ``aug`` need not have the input's spatial size
    (Scale, Rotate90 and Resize may change H and W), so the round-trip result
    is compared against the original tensor ``a``, never against ``aug``.
    """
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        deaug = transform.apply_deaug_mask(aug, **{transform.pname: p})
        # Check the shape against the input: comparing with `aug` fails for
        # any parameter that resizes or transposes the image.
        assert a.shape == deaug.shape, f"Shape mismatch after augmentation-deaugmentation for {transform}"
        assert torch.allclose(a, deaug), f"Original tensor does not match deaugmented tensor for {transform}"
@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        tta.Add(values=[-1, 0, 1, 2]),
        tta.Multiply(factors=[-1, 0, 1, 2]),
        tta.FiveCrops(crop_height=3, crop_width=5),
        tta.Resize(sizes=[(4, 5), (8, 10), (2, 2)], interpolation="nearest"),
    ],
)
def test_label_is_same(transform):
    """De-augmenting a label must hand the label back unchanged.

    ``apply_deaug_label`` is expected to be the identity for every transform
    listed above. Its output is therefore compared with its own input
    ``aug``. Comparing with the un-augmented image ``a`` would be wrong,
    because most parameter values alter the image (flip, shift, scale, crop
    or resize it).
    """
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        deaug = transform.apply_deaug_label(aug, **{transform.pname: p})
        # Identity is expected, so demand exact equality (same shape, dtype, values).
        assert torch.equal(aug, deaug), f"Deaugmented label does not match its augmented input for {transform}"
@pytest.mark.parametrize("transform", [tta.HorizontalFlip(), tta.VerticalFlip()])
def test_flip_keypoints(transform):
    """Flipping keypoints twice with the same parameter restores them.

    A flip is its own inverse, so ``apply_deaug_keypoints`` is simply applied
    two times in a row with an identical parameter value.
    """
    points = torch.tensor(
        [[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]]
    )
    flip = transform.apply_deaug_keypoints
    for value in transform.params:
        kwargs = {transform.pname: value}
        # Hand over a detached copy in case the implementation edits its argument in place.
        once = flip(points.detach().clone(), **kwargs)
        twice = flip(once, **kwargs)
        assert torch.allclose(points, twice), f"Original keypoints do not match deaugmented keypoints for {transform}"
@pytest.mark.parametrize("transform", [tta.Rotate90(angles=[0, 90, 180, 270])])
def test_rotate90_keypoints(transform):
    """De-rotating keypoints by an angle and then by its negation is a no-op."""
    points = torch.tensor(
        [[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]]
    )
    key = transform.pname
    for angle in transform.params:
        # Work on a detached copy so `points` stays pristine for the comparison.
        rotated = transform.apply_deaug_keypoints(points.detach().clone(), **{key: angle})
        # Undo the first step by de-rotating with the opposite angle.
        restored = transform.apply_deaug_keypoints(rotated, **{key: -angle})
        assert torch.allclose(points, restored), f"Original keypoints do not match deaugmented keypoints for {transform}"
def test_add_transform():
    """``Add`` must shift every pixel of the image by the selected value."""
    transform = tta.Add(values=[-1, 0, 1])
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        # Only p == 0 leaves the image unchanged; compare with the expected shift instead.
        assert torch.allclose(aug, a + p), f"Augmented tensor does not equal input + {p} for {transform}"
def test_multiply_transform():
    """``Multiply`` must scale every pixel of the image by the selected factor."""
    transform = tta.Multiply(factors=[-1, 0, 1])
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        # Only p == 1 leaves the image unchanged; compare with the expected product instead.
        assert torch.allclose(aug, a * p), f"Augmented tensor does not equal input * {p} for {transform}"
def test_fivecrop_transform():
    """Each FiveCrops parameter must cut the expected 1x1 patch from a 5x5 ramp.

    On ``arange(25)`` reshaped to 5x5 a pixel's value is ``row * 5 + col``, so
    a 1x1 crop is identified by a single number. The expected sequence is:
    0 (top-left), 20 (bottom-left), 24 (bottom-right), 4 (top-right) and
    12 (center). This pins the order of ``transform.params``.
    """
    transform = tta.FiveCrops(crop_height=1, crop_width=1)
    a = torch.arange(25).reshape(1, 1, 5, 5).float()
    output = [0, 20, 24, 4, 12]
    # zip() would silently stop early on a length mismatch, so check it explicitly.
    assert len(transform.params) == len(output), f"Expected {len(output)} crops for {transform}"
    for i, (p, expected) in enumerate(zip(transform.params, output)):
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        assert aug.shape == (1, 1, 1, 1), f"Crop #{i} has shape {tuple(aug.shape)} for {transform}"
        assert aug.item() == expected, f"Crop #{i} yielded {aug.item()}, expected {expected} for {transform}"
def test_resize_transform():
    """Resize must produce an image whose spatial size equals the requested size."""
    transform = tta.Resize(sizes=[(10, 10), (5, 5)], original_size=(5, 5))
    image = torch.arange(25).reshape(1, 1, 5, 5).float()
    for size in transform.params:
        resized = transform.apply_aug_image(image, **{transform.pname: size})
        # torch.Size is a tuple subclass, so it compares directly with the (h, w) tuple.
        assert resized.shape[2:] == size, f"Augmented image shape does not match expected shape {size} for {transform}"
|