Astropy:将FITS表格拆分为训练和测试集



我有一个FITS表,正在用astropy操作。我想将该表随机拆分为训练和测试数据,以创建两个新的FITS表。

我首先想到使用scikit-learn函数test_train_split,但之后我必须将数据来回转换为numpy.array

到目前为止,我已经从FITS文件中读取了astropy.table.tabledata,并尝试了以下

training_fraction = 0.5
n = len(data)
indexes = random.sample(range(n), k=int(n*training_fraction))
testing_sample = data[indexes]
training_sample = ?

但是,我不知道如何获得所有索引不在indexes中的行。也许有更好的方法可以做到这一点?如何获得表的随机分区?


我的表中的每个样本碰巧都有一个唯一的ID,它是一个介于1和len(数据(之间的整数。所以我想,我可以做

indexes = random.sample(range(1, n+1), k=int(n*training_fraction))
testing_sample = data[data['ID'] in indexes]
training_sample = data[data['ID'] not in indexes]

但是第一行提升ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我是如何做到这一点的

training_indexes = sorted(random.sample(range(n), k=int(n*training_fraction)))
testing_indexes = [i for i in range(n) if i not in training_indexes]

testing_sample = data[testing_indexes]
training_sample = data[training_indexes]

但我不知道这是最有效的方式,还是最蟒蛇的方式。

您提到使用scikit-learn中现有的train_test_split路由。如果这是你使用scikit学习的唯一的东西,那就太过分了。但如果你已经在任务的其他部分使用它,你也可以。Astropy表一开始就有Numpy数组支持,所以您不需要"来回转换数据"。

由于表的'ID'列对表中的行进行索引,因此将其正式设置为表的索引会很有用,这样ID值就可以用于为表中的列进行索引(与它们的实际位置索引无关(。例如:

>>> from astropy.table import Table
>>> import numpy as np
>>> t = Table({
...     'ID': [1, 3, 5, 6, 7, 9],
...     'a': np.random.random(6),
...     'b': np.random.random(6)
... })
>>> t
<Table length=6>
ID           a                   b         
int64       float64             float64      
----- ------------------- -------------------
1  0.7285295918917892  0.6180944983953155
3  0.9273855839237182 0.28085439237508925
5  0.8677312765220222  0.5996267567496841
6 0.06182255608446752  0.6604620336092745
7 0.21450048405835265  0.5351066893214822
9   0.928930682667869  0.8178640424254757

然后将'ID'设置为表的索引:

>>> t.add_index('ID')

使用train_test_split对ID进行任意分区:

>>> train_ids, test_ids = train_test_split(t['ID'], test_size=0.2)
>>> train_ids
<Column name='ID' dtype='int64' length=4>
7
9
5
1
>>> test_ids
<Column name='ID' dtype='int64' length=2>
6
3
>>> train_set = t.loc[train_ids]
>>> test_set = t.loc[test_ids]
>>> train_set
<Table length=4>
ID           a                  b         
int64       float64            float64      
----- ------------------- ------------------
7 0.21450048405835265 0.5351066893214822
9   0.928930682667869 0.8178640424254757
5  0.8677312765220222 0.5996267567496841
1  0.7285295918917892 0.6180944983953155
>>> test_set
<Table length=2>
ID           a                   b         
int64       float64             float64      
----- ------------------- -------------------
6 0.06182255608446752  0.6604620336092745
3  0.9273855839237182 0.28085439237508925

(注:

>>> isinstance(t['ID'], np.ndarray)
True
>>> type(t['ID']).__mro__
(astropy.table.column.Column,
astropy.table.column.BaseColumn,
astropy.table._column_mixins._ColumnGetitemShim,
numpy.ndarray,
object)

)

值得一提的是,这可能会帮助你在未来更容易地找到此类问题的答案,因此,更抽象地考虑你正在尝试做什么会有所帮助(似乎你已经这样做了,但问题的措辞表明情况并非如此(:表中的列只是Numpy数组——一旦是这种形式,它们从FITS文件中读取就无关紧要了。在这一点上,你所做的与Astropy也没有直接关系。问题就变成了如何对Numpy数组进行随机分区。

你可以找到这个问题的一般答案,例如,在这个问题中。但是,如果您有train_test_split这样的现有专用实用程序,那么使用它也是很好的

最新更新