Extract attributes from the original dataframe used to create a tensorflow dataset

I have the following dataframe df:

             sales
2015-10-05  -0.462626
2015-10-06  -0.540147
2015-10-07  -0.450222
2015-10-08  -0.448672
2015-10-09  -0.451773
... ...
2019-10-16  -0.594413
2019-10-17  -0.620770
2019-10-18  -0.586660
2019-10-19  -0.586660
2019-10-20  -0.671934
11340 rows × 1 columns

I turned it into a tf.data.Dataset like this:

import numpy as np
import tensorflow as tf

data = np.array(df)
ds = tf.keras.utils.timeseries_dataset_from_array(
    data=data,
    targets=None,
    sequence_length=4,
    sequence_stride=1,
    shuffle=False,
    batch_size=1,)

The dataset gives me records like this:

print(next(iter(ds)))
tf.Tensor(
[[[-0.4626256 ]
  [-0.54014736]
  [-0.4502221 ]
  [-0.44867167]]], shape=(1, 4, 1), dtype=float32)

I use it to train my ML model. However, I need a way to find the dates corresponding to the values I get back from the dataset. Using the sample record above, I want to find the dates matching those consecutive values, which from the dataframe we can see are [2015-10-05, 2015-10-06, 2015-10-07, 2015-10-08]. Ideally, if the dataframe had more columns, I would like to retrieve those other attributes as well. Is there a way to do this?

You can try using a second dataset as a lookup table. That way you can attach as many additional attributes as you need:

import pandas as pd
import numpy as np
import tensorflow as tf

df = pd.DataFrame(data={'date': ['2015-10-05', '2015-10-06', '2015-10-07', '2015-10-08', '2015-10-09', '2019-10-16', '2019-10-17', '2019-10-18', '2019-10-19', '2019-10-20'],
                        'sales': [-0.462626, -0.540147, -0.450222, -0.448672, -0.451773, -0.594413, -0.620770, -0.586660, -0.586660, -0.671934]})


data = np.array(df['sales'])
ds = tf.keras.utils.timeseries_dataset_from_array(
    data=data,
    targets=None,
    sequence_length=4,
    sequence_stride=1,
    shuffle=False,
    batch_size=1,)

# Build a parallel dataset of date windows, using the same windowing
# parameters as the sales dataset (sequence_length=4, stride 1)
dates = (tf.data.Dataset.from_tensor_slices(df['date'].to_numpy())
         .window(4, shift=1, stride=1)
         .flat_map(lambda w: w.batch(4))
         .batch(1))
d = tf.data.Dataset.zip((dates, ds))

def lookup(tensor, dataset):
  # Keep only the (dates, values) pair whose values match the query tensor,
  # then decode the matching window of date strings
  dataset = dataset.filter(lambda x, y: tf.reduce_all(tf.equal(y, tensor)))
  return [x.numpy().decode('utf-8') for x in list(dataset.map(lambda x, y: tf.squeeze(x, axis=0)))[0]]

result = lookup(next(iter(ds)), d)
print(result)
['2015-10-05', '2015-10-06', '2015-10-07', '2015-10-08']
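If your dataframe has more attribute columns, the same idea extends: window each column with the same parameters and zip them all together, so every element carries the attributes alongside the sales window. A minimal sketch, assuming a made-up extra column `store` (not in your data):

```python
import numpy as np
import pandas as pd
import tensorflow as tf

df = pd.DataFrame({
    'date':  ['2015-10-05', '2015-10-06', '2015-10-07', '2015-10-08', '2015-10-09'],
    'store': ['A', 'B', 'A', 'B', 'A'],  # hypothetical extra attribute
    'sales': [-0.462626, -0.540147, -0.450222, -0.448672, -0.451773],
})

ds = tf.keras.utils.timeseries_dataset_from_array(
    data=np.array(df['sales']), targets=None,
    sequence_length=4, sequence_stride=1, shuffle=False, batch_size=1)

def windows(col):
    # Window a single column with the same parameters as the sales dataset
    return (tf.data.Dataset.from_tensor_slices(df[col].to_numpy())
            .window(4, shift=1, stride=1)
            .flat_map(lambda w: w.batch(4))
            .batch(1))

# Each element is ((date_window, store_window), sales_window)
d = tf.data.Dataset.zip(((windows('date'), windows('store')), ds))

(date_w, store_w), sales_w = next(iter(d))
print([b.decode('utf-8') for b in date_w.numpy()[0]])
print([b.decode('utf-8') for b in store_w.numpy()[0]])
```

Because all the attribute windows are aligned with the sales windows by construction, no filtering pass is needed here; you can also keep the `lookup`-style filter if you still need to search by value.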