Write strategies to generate array shapes with total size less than a certain value
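I am trying to write a strategy that generates array shapes with 4 dimensions where the product of all dims is less than a given value (say 16728).
This means the search space has its root at (1, 1, 1, 1) and 4 leaves at (16728, 1, 1, 1), (1, 16728, 1, 1), (1, 1, 16728, 1), (1, 1, 1, 16728).
The code I use: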
# test_shapes.py
import numpy as np
from hypothesis import settings, HealthCheck, given
from hypothesis.extra.numpy import array_shapes


@settings(max_examples=10000, suppress_health_check=HealthCheck.all())
@given(shape=array_shapes(min_dims=4, max_dims=4, min_side=1, max_side=16728).filter(lambda x: np.prod(x) < 16728))
def test_shape(shape):
    print(f"testing shape: {shape}")
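The performance is not good enough. Filtering leads to too many rejected examples, and the randomization does not explore paths other than the one toward the leaf (16728, 1, 1, 1).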
pytest test_shapes.py --hypothesis-show-statistics
test_shapes.py::test_shape:
  - during generate phase (211.31 seconds):
    - Typical runtimes: 0-1 ms, ~ 84% in data generation
    - 51 passing examples, 0 failing examples, 99949 invalid examples
    - Events:
      * 99.95%, Retried draw from array_shapes(max_dims=4, max_side=16728, min_dims=4).filter(lambda x: np.prod(x) < 16728) to satisfy filter
      * 99.95%, Aborted test because unable to satisfy array_shapes(max_dims=4, max_side=16728, min_dims=4).filter(lambda x: np.prod(x) < 16728)
  - Stopped because settings.max_examples=10000, but < 10% of examples satisfied assumptions
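For a rough sense of why the filter rejects almost everything, the sketch below counts how many 4-tuples of sides in 1..16728 have a product below the limit. It is only an illustration: `count_shapes` is a hypothetical helper, and it assumes sides are drawn uniformly, which Hypothesis does not do, so the observed rejection rate will differ from this fraction.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def count_shapes(ndims, max_elems):
    """Count tuples of `ndims` positive ints whose product is <= max_elems."""
    if max_elems < 1:
        return 0
    if ndims == 0:
        return 1  # the empty product is 1
    if ndims == 1:
        return max_elems  # any side in 1..max_elems works
    # Fix the first side, then count the ways to fill the rest
    # within the remaining cap max_elems // side.
    return sum(
        count_shapes(ndims - 1, max_elems // side)
        for side in range(1, max_elems + 1)
    )


if __name__ == "__main__":
    limit = 16728
    valid = count_shapes(4, limit - 1)  # product strictly below the limit
    total = limit ** 4                  # all shapes with sides in 1..limit
    print(f"valid shapes: {valid}")
    print(f"fraction of the unfiltered space: {valid / total:.3e}")
```

Under that uniform assumption the valid shapes are a tiny fraction of the space, which is consistent with the filter giving up on most draws.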
Is there a better way to write a Hypothesis strategy that also explores the paths toward the other leaves well?
Good question! This is a very general trick: instead of using a filter, we ensure that every example is valid by construction:
import numpy as np
from hypothesis import given, strategies as st


@st.composite
def small_shapes(draw, *, ndims=4, max_elems=16728):
    # Instead of filtering, we calculate the "remaining cap" if the product
    # of our side lengths is to remain <= max_elems.  Ensuring this by
    # construction is much more efficient than filtering.
    shape = []
    for _ in range(ndims):
        side = draw(st.integers(1, max_elems))
        max_elems //= side
        shape.append(side)
    # However, it *does* bias towards having smaller sides for later
    # dimensions, which we correct by shuffling the list.
    shuffled = draw(st.permutations(shape))
    return tuple(shuffled)


@given(shape=small_shapes())
def test_shape(shape):
    print(f"testing shape: {shape}")
    assert 1 <= np.prod(shape) <= 16728
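The "shuffle to remove bias" step is also a reusable trick. Finally, though I didn't need it here, the best option is often to use a constructive approach that makes the data much more likely to be valid, and then apply a filter to deal with the remaining 5-10% of examples that it didn't manage.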
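To make the bias concrete, here is a small simulation of the same remaining-cap construction using only the standard library's `random` module in place of Hypothesis. It is a sketch under that assumption: `draw_shape` and `mean_side_per_axis` are hypothetical helpers, and uniform `randint` draws do not match Hypothesis's actual distribution, so only the qualitative effect carries over.

```python
import random


def draw_shape(rng, ndims=4, max_elems=16728, shuffle=True):
    """Draw one shape with product <= max_elems via the remaining-cap idea."""
    remaining = max_elems
    shape = []
    for _ in range(ndims):
        side = rng.randint(1, remaining)  # inclusive on both ends
        remaining //= side                # cap left for the later sides
        shape.append(side)
    if shuffle:
        rng.shuffle(shape)  # spread the large sides across all axes
    return tuple(shape)


def mean_side_per_axis(shuffle, n=20_000, seed=0):
    """Average side length on each axis over n simulated shapes."""
    rng = random.Random(seed)
    shapes = [draw_shape(rng, shuffle=shuffle) for _ in range(n)]
    return [sum(axis) / len(axis) for axis in zip(*shapes)]


if __name__ == "__main__":
    print("unshuffled:", mean_side_per_axis(shuffle=False))
    print("shuffled:  ", mean_side_per_axis(shuffle=True))
```

Without the shuffle, the first axis gets by far the largest sides on average because it is drawn while the full cap is still available; with the shuffle, the per-axis averages come out roughly equal.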