给定一个多维数据数组,我想找出满足条件的列:即该列在每一行上都满足该行对应的条件。我已有一个可以正常工作的算法,但想进一步优化它。我的方法可以接受多种条件,而 Code Review 帖子中建议的方法却不能。我想修改该建议方法,使其也能支持多种条件。
例如,请考虑以下示例数据:
import numpy as np
def get_sample_data(nsample):
    """Return one of two built-in three-row sample arrays.

    Parameters
    ----------
    nsample : int
        ``0`` selects a small hand-written integer array of shape (3, 8);
        ``1`` selects a float array of shape (3, 10) whose rows are
        ``x``, ``10 * x`` and ``x + 20`` for ``x = np.linspace(1, 10, 10)``.

    Returns
    -------
    numpy.ndarray
        The stacked rows.

    Raises
    ------
    ValueError
        If ``nsample`` is neither 0 nor 1 (previously this surfaced as an
        ``UnboundLocalError`` because no rows were defined).
    """
    if nsample == 0:
        row_a = np.array([1, 4, 7, 3, 10, 3, 5, 1])
        row_b = np.array([2, 5, 30, 30, 10, 5, 5, 1])
        row_c = np.array([23, 21, 22, 23, 23, 25, 21, 23])
    elif nsample == 1:
        row_a = np.linspace(1, 10, 10)
        row_b = row_a * 10
        row_c = row_a + 20
    else:
        raise ValueError("nsample must be 0 or 1, got {!r}".format(nsample))
    return np.array([row_a, row_b, row_c])
# Use the first (integer) sample; pass nsample=1 for the linspace-based one.
data = get_sample_data(nsample=0)
我编写了下面的函数,用来简化为每一行指定某种搜索条件的过程。
def search(condition, value, relation):
    """Return the indices at which ``condition`` satisfies ``relation`` w.r.t. ``value``.

    Parameters
    ----------
    condition : array_like
        Values to search or, for ``relation='custom'``, a precomputed
        boolean condition.
    value : scalar
        Value to compare against (ignored for ``'custom'``).
    relation : str
        One of ``'equality'``/``'equal'``/``'exact match'``, ``'not equal'``,
        ``'greater than'``, ``'greater than or equal'``,
        ``'less than'``/``'lesser than'``,
        ``'less than or equal'``/``'lesser than or equal'``,
        ``'nearest'`` (smallest absolute difference),
        ``'nearest forward'`` (closest entry >= value),
        ``'nearest backward'`` (closest entry <= value) or ``'custom'``.

    Returns
    -------
    tuple of numpy.ndarray
        The ``np.where`` result: one index array per dimension.

    Raises
    ------
    ValueError
        If ``relation`` is unknown, or a forward/backward search has no
        candidate on the requested side of ``value``.
    """
    import operator

    comparisons = {
        'equality': operator.eq,
        'exact match': operator.eq,
        'equal': operator.eq,
        'not equal': operator.ne,
        'greater than': operator.gt,
        'less than': operator.lt,
        'lesser than': operator.lt,
        'greater than or equal': operator.ge,
        'less than or equal': operator.le,
        'lesser than or equal': operator.le,
    }
    condition = np.asarray(condition)
    if relation in comparisons:
        return np.where(comparisons[relation](condition, value))
    if relation == 'custom':
        return np.where(condition)
    if relation == 'nearest':
        delta = np.abs(condition - value)
        return np.where(delta == delta.min())
    if relation == 'nearest forward':
        delta = condition - value
        direction = 'forward'
    elif relation == 'nearest backward':
        delta = value - condition
        direction = 'backward'
    else:
        raise ValueError("the input search relation is invalid")
    # Only non-negative deltas lie on the requested side of ``value``. An
    # explicit emptiness check replaces a bare ``except``, which would also
    # have masked unrelated errors.
    candidates = delta[delta >= 0]
    if candidates.size == 0:
        raise ValueError("no %s-nearest match exists" % direction)
    return np.where(delta == candidates.min())
下面是我的实现,它可以正常工作。
def get_base(shape, value, dtype=int):
    """Return a "basemask": an array of ``shape`` filled with ``value``.

    Numeric values (Python ``int``/``float`` or NumPy scalars) produce
    ``np.ones(shape, dtype=dtype) * value``, so the result dtype follows
    NumPy promotion (an ``int`` base times ``0.25`` yields a float array).
    String values produce a unicode array of ``shape``; ``dtype`` is not
    used in that case.

    Raises
    ------
    TypeError
        If ``value`` is neither a number nor a string (previously this
        surfaced as an ``UnboundLocalError``).
    """
    if isinstance(value, (int, float, np.number)):
        # Multiplying (rather than np.full with ``dtype``) preserves the
        # upcast of an integer base by a float value.
        return np.ones(shape, dtype=dtype) * value
    if isinstance(value, str):
        return np.full(shape, value)
    raise TypeError(
        "value must be a number or a string, got {}".format(type(value).__name__)
    )
def alternate_base(shape, key):
    """Build a float basemask of ``shape`` whose fill value depends on ``key``.

    An even ``key`` selects 0.25, anything else selects 0.5; the array is
    produced by ``get_base(shape, <fill>, dtype=float)``.
    """
    fill = 0.25 if key % 2 == 0 else 0.5
    return get_base(shape, fill, dtype=float)
def my_method(ndata, search_value, search_relation):
    """Return the column indices of ``ndata`` that satisfy every row's condition.

    Row ``i`` is tested with ``search(ndata[i], search_value[i],
    search_relation[i])``; a column is kept only when ``search`` selects it
    in every row. Inputs and results are printed for inspection.

    Parameters
    ----------
    ndata : array_like, 2-D
        Data to search, one condition per row.
    search_value : sequence
        One search value per row.
    search_relation : str or sequence of str
        A single relation applied to every row, or one relation per row
        (any relation accepted by ``search``).

    Returns
    -------
    numpy.ndarray
        Ascending integer column indices matching all conditions.

    Raises
    ------
    ValueError
        If ``search_relation`` does not have one entry per row, if ``search``
        rejects a relation, or if no column satisfies every condition.
    """
    ndata = np.asarray(ndata)
    nrows = len(ndata)
    if isinstance(search_relation, str):
        search_relation = (search_relation,) * nrows
    elif len(search_relation) != nrows:
        raise ValueError(
            "search_relation should be a string or a collection of "
            "{} relations".format(nrows)
        )
    print("\nDATA SAMPLE:\n{}\n".format(ndata))
    print("SEARCH VALUE: {}\nSEARCH RELATION: {}\n".format(search_value, search_relation))
    # Mark, row by row, the columns that ``search`` selects. Writing into a
    # boolean mask avoids stacking the per-row index arrays with np.array,
    # which fails in recent NumPy when rows match different numbers of columns.
    mask = np.zeros(ndata.shape, dtype=bool)
    for row, relation in enumerate(search_relation):
        hits = search(condition=ndata[row], value=search_value[row], relation=relation)[0]
        mask[row, hits] = True
    # A column matches only if it was selected in every row.
    idx_res = np.where(mask.all(axis=0))[0]
    val_res = ndata[:, idx_res]
    print("RESULTANT INDICES:\n{}\n".format(idx_res))
    print("RESULTANT VALUES:\n{}\n".format(val_res))
    if len(idx_res) == 0:
        raise ValueError("match not found for multiple conditions")
    return idx_res
上面的方法是在该 Code Review 帖子的基础上稍作修改得到的。帖子中建议的方法如下,但它只涵盖严格相等的条件(`==`)。能否将其改造为支持多种条件?
def martin_fabre_method(ndata, search_value):
    """Return a boolean column mask: True where every row equals its search value.

    Row ``i`` of ``ndata`` is compared with ``search_value[i]`` (the values
    are broadcast as a column), and a column is selected only if all rows
    match. The input, the search values and the results are printed.

    Parameters
    ----------
    ndata : array_like, 2-D
        Data to search.
    search_value : sequence
        One target value per row of ``ndata``.

    Returns
    -------
    numpy.ndarray of bool
        One entry per column; True where every row matched.

    Raises
    ------
    ValueError
        If no column matches every search value.
    """
    ndata = np.asarray(ndata)
    print("\nNDATA:\n{}\n".format(ndata))
    print("SEARCH VALUE: {}\n".format(search_value))
    # One single-element row per value turns the targets into a column.
    mask = ndata == [[value] for value in search_value]
    idx_res = mask.all(axis=0)
    if not np.any(idx_res):
        raise ValueError("match not found for multiple conditions")
    val_res = ndata[:, idx_res]
    print("RESULTANT INDICES:\n{}\n".format(idx_res))
    print("RESULTANT VALUES:\n{}\n".format(val_res))
    return idx_res
要运行该算法,可以复制上述代码并执行以下语句:
# Alternative runs of my_method (uncomment to try), including mixed relations:
# my_method(data, search_value=(7, 30, 22), search_relation='equality')
# my_method(data, search_value=(7, 5, 22), search_relation=('less than', 'equality', 'less than'))
martin_fabre_method(ndata=data, search_value=(7, 30, 22))
您可以将我在 Code Review 回答中的第一行(即计算 `mask` 的那一行)替换为以下内容:
def get_mask(data, search_value, comparison):
    """Return a boolean mask of the entries of ``data`` that satisfy ``comparison``.

    ``data`` is treated as a stack of rows with one search value per row:
    when ``data`` is 2-D, a 1-D ``search_value`` is reshaped into a column so
    that value ``i`` is compared with row ``i``. Scalars and already
    column-shaped values broadcast as usual.

    Supported comparisons
    ---------------------
    * Element-wise relations: 'equal', 'equality', 'exact match',
      'not equal', 'greater than', 'greater than or equal', 'less than',
      'lesser than', 'less than or equal', 'lesser than or equal'.
    * 'custom': ``data`` already is the condition; it is returned as a
      boolean array and ``search_value`` is ignored.
    * 'nearest': per row, the entries closest to the search value.
    * 'nearest forward' / 'nearest backward': per row, the closest entries
      that are >= / <= the search value.

    Raises
    ------
    ValueError
        If ``comparison`` is unknown, or a forward/backward search finds no
        candidate in some row.
    """
    import operator

    comparisons = {
        'equal': operator.eq,
        'equality': operator.eq,
        'exact match': operator.eq,
        'greater than': operator.gt,
        'greater than or equal': operator.ge,
        'less than': operator.lt,
        'less than or equal': operator.le,
        'lesser than': operator.lt,
        'lesser than or equal': operator.le,
        'not equal': operator.ne,
    }
    data = np.asarray(data)
    if comparison == 'custom':
        # The caller supplied the mask itself; it must stay a mask so that
        # callers can reduce it with ``.all(axis=0)``.
        return data.astype(bool)
    values = np.asarray(search_value)
    if data.ndim == 2 and values.ndim == 1:
        # One value per row: compare row i with values[i].
        values = values.reshape(-1, 1)
    compare = comparisons.get(comparison)
    if compare is not None:
        return compare(data, values)
    delta = data - values
    if comparison == 'nearest':
        delta = np.abs(delta)
    elif comparison == 'nearest forward':
        # Entries below the search value can never be forward matches.
        delta = np.where(delta >= 0, delta, np.inf)
    elif comparison == 'nearest backward':
        # Distance back to the value; entries above it are excluded.
        delta = np.where(delta <= 0, -delta, np.inf)
    else:
        raise ValueError("the input search relation is invalid")
    # A row whose every candidate was excluded has no match at all.
    if (delta == np.inf).all(axis=-1).any():
        raise ValueError("no %s match exists for searchvalue %s" % (comparison, repr(search_value)))
    return delta == delta.min(axis=-1, keepdims=True)
def martin_fabre_method(ndata, search_value, comparison='equality'):
    """Return a boolean column mask: True where every row satisfies ``comparison``.

    The per-entry mask comes from ``get_mask(ndata, search_value,
    comparison)``; a column is selected only if the mask is True in every
    row. The input, the search values and the results are printed.

    Parameters
    ----------
    ndata : array_like, 2-D
        Data to search.
    search_value : sequence
        Search values, forwarded unchanged to ``get_mask``.
    comparison : str, default 'equality'
        Relation name forwarded to ``get_mask``. The default makes this a
        drop-in replacement for the equality-only version.

    Returns
    -------
    numpy.ndarray of bool
        One entry per column; True where every row matched.

    Raises
    ------
    ValueError
        If no column satisfies every condition (``get_mask`` may raise its
        own errors as well).
    """
    ndata = np.asarray(ndata)
    print("\nNDATA:\n{}\n".format(ndata))
    print("SEARCH VALUE: {}\n".format(search_value))
    mask = get_mask(ndata, search_value, comparison)
    idx_res = mask.all(axis=0)
    if not np.any(idx_res):
        raise ValueError("match not found for multiple conditions")
    val_res = ndata[:, idx_res]
    print("RESULTANT INDICES:\n{}\n".format(idx_res))
    print("RESULTANT VALUES:\n{}\n".format(val_res))
    return idx_res
`operator` 的替代方案
使用 `operator` 模块可以把第一部分写得更清晰:
def get_mask(data, search_value, comparison):
import operator
comparisons = {
'equal': operator.eq,
'equality': operator.eq,
'exact match': operator.eq,
'greater than': operator.gt,
'greater than or equal': operator.ge,
'less than': operator.lt,
'less than or equal': operator.le,
'lesser than': operator.lt,
'lesser than or equal': operator.le,
'not equal': operator.ne,
}
try:
return comparisons[comparison](data, search_value)
....