# ttach

# 简介

  • 使用 PyTorch 进行图像测试时间增强
  • 与数据增强对训练集的作用类似,测试时增强的目的是对测试图像进行随机修改。因此,我们不会只向训练模型显示一次

1
2
3
4
5
6
7
8
    Input
| # input batch of images
/ / /|\ \ \ # apply augmentations (flips, rotation, scale, etc.)
| | | | | | | # pass augmented batches through model
| | | | | | | # reverse transformations for each batch of masks/labels
\ \ \ / / / # merge predictions (mean, max, gmean, etc.)
| # output batch of masks/labels
Output

# test_transforms.py 代码填充分析

# test_transforms.py 原码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import pytest
import torch
import ttach as tta


@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
tta.Rotate90(angles=[0, 90, 180, 270]),
tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
tta.Resize(sizes=[(4, 5), (8, 10)], original_size=(4, 5), interpolation="nearest")
],
)
def test_aug_deaug_mask(transform):
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
deaug = transform.apply_deaug_mask(aug, **{transform.pname: p})
assert torch.allclose(a, deaug)


@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
tta.Rotate90(angles=[0, 90, 180, 270]),
tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
tta.Add(values=[-1, 0, 1, 2]),
tta.Multiply(factors=[-1, 0, 1, 2]),
tta.FiveCrops(crop_height=3, crop_width=5),
tta.Resize(sizes=[(4, 5), (8, 10), (2, 2)], interpolation="nearest")
],
)
def test_label_is_same(transform):
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
deaug = transform.apply_deaug_label(aug, **{transform.pname: p})
assert torch.allclose(aug, deaug)


@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip()
],
)
def test_flip_keypoints(transform):
keypoints = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]])
for p in transform.params:
aug = transform.apply_deaug_keypoints(keypoints.detach().clone(), **{transform.pname: p})
deaug = transform.apply_deaug_keypoints(aug, **{transform.pname: p})
assert torch.allclose(keypoints, deaug)


@pytest.mark.parametrize(
"transform",
[
tta.Rotate90(angles=[0, 90, 180, 270])
],
)
def test_rotate90_keypoints(transform):
keypoints = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]])
for p in transform.params:
aug = transform.apply_deaug_keypoints(keypoints.detach().clone(), **{transform.pname: p})
deaug = transform.apply_deaug_keypoints(aug, **{transform.pname: -p})
assert torch.allclose(keypoints, deaug)


def test_add_transform():
transform = tta.Add(values=[-1, 0, 1])
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
assert torch.allclose(aug, a + p)


def test_multiply_transform():
transform = tta.Multiply(factors=[-1, 0, 1])
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
assert torch.allclose(aug, a * p)


def test_fivecrop_transform():
transform = tta.FiveCrops(crop_height=1, crop_width=1)
a = torch.arange(25).reshape(1, 1, 5, 5).float()
output = [0, 20, 24, 4, 12]
for i, p in enumerate(transform.params):
aug = transform.apply_aug_image(a, **{transform.pname: p})
assert aug.item() == output[i]


def test_resize_transform():
transform = tta.Resize(sizes=[(10, 10), (5, 5)], original_size=(5, 5))
a = torch.arange(25).reshape(1, 1, 5, 5).float()
for i, p in enumerate(transform.params):
aug = transform.apply_aug_image(a, **{transform.pname: p})
assert aug.item() == output[i]

# 删除 assert 断言给 LLM 分析的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pytest
import torch
import ttach as tta


@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
tta.Rotate90(angles=[0, 90, 180, 270]),
tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
tta.Resize(sizes=[(4, 5), (8, 10)], original_size=(4, 5), interpolation="nearest")
],
)
def test_aug_deaug_mask(transform):
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
deaug = transform.apply_deaug_mask(aug, **{transform.pname: p})


@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
tta.Rotate90(angles=[0, 90, 180, 270]),
tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
tta.Add(values=[-1, 0, 1, 2]),
tta.Multiply(factors=[-1, 0, 1, 2]),
tta.FiveCrops(crop_height=3, crop_width=5),
tta.Resize(sizes=[(4, 5), (8, 10), (2, 2)], interpolation="nearest")
],
)
def test_label_is_same(transform):
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
deaug = transform.apply_deaug_label(aug, **{transform.pname: p})


@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip()
],
)
def test_flip_keypoints(transform):
keypoints = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]])
for p in transform.params:
aug = transform.apply_deaug_keypoints(keypoints.detach().clone(), **{transform.pname: p})
deaug = transform.apply_deaug_keypoints(aug, **{transform.pname: p})


@pytest.mark.parametrize(
"transform",
[
tta.Rotate90(angles=[0, 90, 180, 270])
],
)
def test_rotate90_keypoints(transform):
keypoints = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]])
for p in transform.params:
aug = transform.apply_deaug_keypoints(keypoints.detach().clone(), **{transform.pname: p})
deaug = transform.apply_deaug_keypoints(aug, **{transform.pname: -p})


def test_add_transform():
transform = tta.Add(values=[-1, 0, 1])
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})


def test_multiply_transform():
transform = tta.Multiply(factors=[-1, 0, 1])
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})


def test_fivecrop_transform():
transform = tta.FiveCrops(crop_height=1, crop_width=1)
a = torch.arange(25).reshape(1, 1, 5, 5).float()
output = [0, 20, 24, 4, 12]
for i, p in enumerate(transform.params):
aug = transform.apply_aug_image(a, **{transform.pname: p})


def test_resize_transform():
transform = tta.Resize(sizes=[(10, 10), (5, 5)], original_size=(5, 5))
a = torch.arange(25).reshape(1, 1, 5, 5).float()
for i, p in enumerate(transform.params):
aug = transform.apply_aug_image(a, **{transform.pname: p})

# LLM 自动填充 assert 断言后的代码

**base prompt 为:** 该 tests case 缺少 assert 断言 请你自动填充它

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import pytest
import torch
import ttach as tta

@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
tta.Rotate90(angles=[0, 90, 180, 270]),
tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
tta.Resize(sizes=[(4, 5), (8, 10)], original_size=(4, 5), interpolation="nearest")
],
)
def test_aug_deaug_mask(transform):
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
deaug = transform.apply_deaug_mask(aug, **{transform.pname: p})
assert aug.shape == deaug.shape, f"Shape mismatch after augmentation-deaugmentation for {transform}"
assert torch.allclose(a, deaug), f"Original tensor does not match deaugmented tensor for {transform}"

@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
tta.Rotate90(angles=[0, 90, 180, 270]),
tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
tta.Add(values=[-1, 0, 1, 2]),
tta.Multiply(factors=[-1, 0, 1, 2]),
tta.FiveCrops(crop_height=3, crop_width=5),
tta.Resize(sizes=[(4, 5), (8, 10), (2, 2)], interpolation="nearest")
],
)
def test_label_is_same(transform):
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
deaug = transform.apply_deaug_label(aug, **{transform.pname: p})
assert torch.equal(a, deaug), f"Original tensor does not match deaugmented tensor for {transform}"

@pytest.mark.parametrize(
"transform",
[
tta.HorizontalFlip(),
tta.VerticalFlip()
],
)
def test_flip_keypoints(transform):
keypoints = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]])
for p in transform.params:
aug = transform.apply_deaug_keypoints(keypoints.detach().clone(), **{transform.pname: p})
deaug = transform.apply_deaug_keypoints(aug, **{transform.pname: p})
assert torch.allclose(keypoints, deaug), f"Original keypoints do not match deaugmented keypoints for {transform}"

@pytest.mark.parametrize(
"transform",
[
tta.Rotate90(angles=[0, 90, 180, 270])
],
)
def test_rotate90_keypoints(transform):
keypoints = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]])
for p in transform.params:
aug = transform.apply_deaug_keypoints(keypoints.detach().clone(), **{transform.pname: p})
deaug = transform.apply_deaug_keypoints(aug, **{transform.pname: -p})
assert torch.allclose(keypoints, deaug), f"Original keypoints do not match deaugmented keypoints for {transform}"

def test_add_transform():
transform = tta.Add(values=[-1, 0, 1])
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
assert torch.allclose(a, aug), f"Original tensor does not match augmented tensor for {transform}"

def test_multiply_transform():
transform = tta.Multiply(factors=[-1, 0, 1])
a = torch.arange(20).reshape(1, 1, 4, 5).float()
for p in transform.params:
aug = transform.apply_aug_image(a, **{transform.pname: p})
assert torch.allclose(a, aug), f"Original tensor does not match augmented tensor for {transform}"

def test_fivecrop_transform():
transform = tta.FiveCrops(crop_height=1, crop_width=1)
a = torch.arange(25).reshape(1, 1, 5, 5).float()
output = [0, 20, 24, 4, 12]
for i, p in enumerate(transform.params):
aug = transform.apply_aug_image(a, **{transform.pname: p})
assert torch.equal(a[0, 0], aug[0, 0]), f"Original tensor does not match augmented tensor for {transform}"

def test_resize_transform():
transform = tta.Resize(sizes=[(10, 10), (5, 5)], original_size=(5, 5))
a = torch.arange(25).reshape(1, 1, 5, 5).float()
for i, p in enumerate(transform.params):
aug = transform.apply_aug_image(a, **{transform.pname: p})
assert aug.shape[2:] == p, f"Augmented image shape does not match expected shape {p} for {transform}"