Astropy:将 FITS table 拆分为训练集和测试集

Astropy: split a FITS table into a training and testing set

我有一个 FITS table 我正在用 astropy 进行操作。我想将 table 随机分成训练和测试数据,以创建两个新的 FITS table。

我首先想到使用 scikit-learn 函数 test_train_split,但后来我不得不将我的数据来回转换为 numpy.array

到目前为止,我已经从 FITS 文件中读取了 astropy.table.Table data 并尝试了以下操作

training_fraction = 0.5
n = len(data)
indexes = random.sample(range(n), k=int(n*training_fraction))
testing_sample = data[indexes]
training_sample = ?

但是,我不知道如何获取索引不在 indexes 中的所有行。也许有更好的方法来做到这一点?如何获得 Table 的随机分区?


我的 table 中的样本恰好每个都有一个唯一的 ID,它是一个介于 1 和 len(data) 之间的整数。所以我想,我可以做到

indexes = random.sample(range(1, n+1), k=int(n*training_fraction))
testing_sample = data[data['ID'] in indexes]
training_sample = data[data['ID'] not in indexes]

但第一行加注 ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我是如何做到这一点的

training_indexes = sorted(random.sample(range(n), k=int(n*training_fraction)))
testing_indexes = [i for i in range(n) if i not in training_indexes]


testing_sample = data[testing_indexes]
training_sample = data[training_indexes]

但我不知道这是最高效的方式,还是最pythonic的方式。

您提到使用来自 scikit-learn 的现有 train_test_split 路由。如果这是您使用 scikit-learn 的 only 事情,那就太过分了。但是,如果您已经将它用于任务的其他部分,您也可以这样做。 Astropy Tables 已经由 Numpy 数组支持,所以你不需要 "convert your data back and forth".

由于 table 的 'ID' 列索引了 table 中的行,正式将其设置为 [=43] 的 index 会很有用=],以便 ID 值可用于索引 table 中的行(独立于它们的实际位置索引)。例如:

>>> from astropy.table import Table
>>> import numpy as np
>>> t = Table({
...     'ID': [1, 3, 5, 6, 7, 9],
...     'a': np.random.random(6),
...     'b': np.random.random(6)
... })
>>> t
<Table length=6>
  ID           a                   b         
int64       float64             float64      
----- ------------------- -------------------
    1  0.7285295918917892  0.6180944983953155
    3  0.9273855839237182 0.28085439237508925
    5  0.8677312765220222  0.5996267567496841
    6 0.06182255608446752  0.6604620336092745
    7 0.21450048405835265  0.5351066893214822
    9   0.928930682667869  0.8178640424254757

然后设置'ID'作为table的索引:

>>> t.add_index('ID')

使用train_test_split根据需要对 ID 进行分区:

>>> train_ids, test_ids = train_test_split(t['ID'], test_size=0.2)
>>> train_ids
<Column name='ID' dtype='int64' length=4>
7
9
5
1
>>> test_ids
<Column name='ID' dtype='int64' length=2>
6
3
>>> train_set = t.loc[train_ids]
>>> test_set = t.loc[test_ids]
>>> train_set
<Table length=4>
  ID           a                  b         
int64       float64            float64      
----- ------------------- ------------------
    7 0.21450048405835265 0.5351066893214822
    9   0.928930682667869 0.8178640424254757
    5  0.8677312765220222 0.5996267567496841
    1  0.7285295918917892 0.6180944983953155
>>> test_set
<Table length=2>
  ID           a                   b         
int64       float64             float64      
----- ------------------- -------------------
    6 0.06182255608446752  0.6604620336092745
    3  0.9273855839237182 0.28085439237508925

(注:

>>> isinstance(t['ID'], np.ndarray)
True
>>> type(t['ID']).__mro__
(astropy.table.column.Column,
 astropy.table.column.BaseColumn,
 astropy.table._column_mixins._ColumnGetitemShim,
 numpy.ndarray,
 object)

)

就其价值而言,因为它可能会帮助您在未来更轻松地找到此类问题的答案,所以它有助于更​​抽象地考虑您正在尝试做的事情(看来您已经 are 这样做,但你的问题的措辞表明并非如此):你的 table 中的列只是 Numpy 数组——一旦它处于这种形式,它们就与从 FITS 文件中读取无关。你正在做的事情在这一点上也与 Astropy 没有直接关系。问题就变成了如何随机划分一个 Numpy 数组。

您可以找到此问题的通用答案,例如 in this question。但是,如果您有 train_test_split 之类的现有专用实用程序,也很不错。