Tensorflow 2.6.0:如何将一个元素映射到多个元素

Tensorflow 2.6.0: How do I Map One Element into Multiple Elements

我正在尝试制作一个 CNN 来对医学图像进行分类。这些图像很大(~50k x ~30k)。作为我的管道的一部分,我想将图像分成 256 x 256 的补丁。 我想使用 Dataset.map 运算符来执行此操作,以便稍后缓存数据以便于训练。

我发现 解决了 tensorflow 1 中的问题,但我无法将其转换为 tensorflow 2。

对于提出这个问题,我深表歉意,但我能否获得一些转换代码的帮助,以便我可以在 tensorflow 2 中运行它?我有点新手,所以感谢帮助

欢迎在tf.data.Dataset.maptf.data.Dataset.unbatch和官方中使用tf.stack documentation

import tensorflow as tf

some_image_dataset = tf.random.normal(shape=[10, 1024, 768]) 
dataset = tf.data.Dataset.from_tensor_slices(some_image_dataset)

def some_patches_map_func(image):
    return tf.stack([
        image[10 : 10 + 256, 20 : 20 + 256], 
        image[100 : 100 + 256, 100 : 100 + 256], 
        image[500 : 500 + 256, 200 : 200 + 256],
    ]) 

dataset = dataset.map(some_patches_map_func)
dataset = dataset.unbatch().shuffle(10)
dataset = dataset.batch(2) 
    
iterator = iter(dataset)
        
print(next(iterator).shape) # (2, 256, 256)
print(next(iterator).shape) # (2, 256, 256)
print(next(iterator).shape) # (2, 256, 256)