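I'm new to Dask and I'm having some trouble with the function map_blocks. I'm trying to apply a function to each element of a 2D array. Instead of creating 2 arrays for the indices i and j, I created 1 array of size i * j.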
import numpy as np
import dask.array as da

ij = da.arange(n_users*n_ratings)
diff = da.map_blocks(compute_error, ij, dtype=np.float_).compute()
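The flattening trick depends on recovering (i, j) from each flat index. Assuming x is stored row-major with n_users columns, integer division and modulo by n_users undo the flattening exactly as NumPy's np.unravel_index does (for non-negative indices, int(ij/n_users) gives the same value as ij // n_users). Here is a minimal sketch with illustrative sizes:

```python
import numpy as np

# Illustrative sizes (assumption): x has n_ratings rows and n_users columns
n_ratings, n_users = 4, 6

ij = np.arange(n_users * n_ratings)

# Division/modulo decomposition, vectorized over all flat indices
i = ij // n_users   # row index
j = ij % n_users    # column index

# NumPy's built-in inverse of C-order (row-major) flattening
i_ref, j_ref = np.unravel_index(ij, (n_ratings, n_users))

print(np.array_equal(i, i_ref), np.array_equal(j, j_ref))  # True True
```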
The function compute_error:
def compute_error(ij):
    i = int(ij/n_users)
    j = ij%n_users
    if not np.isnan(x[i,j]):
        return x[i, j] - np.dot(user_mat[j, :], ratings_mat[:, i])
    else:
        return 0.0
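map_blocks does not call the function once per element. It calls it once per chunk and passes that chunk as a NumPy array. Dask may also call it on a zero-size array to infer output metadata. As a result, compute_error receives arrays rather than single indices, and int(ij/n_users) raises a TypeError for any block with more than one element. The sketch below uses a hypothetical inspect_block helper that records the shapes it is given:

```python
import numpy as np
import dask.array as da

seen_shapes = []  # shapes of the arrays passed to inspect_block

def inspect_block(block):
    # Called once per chunk with a NumPy array, not once per element
    seen_shapes.append(block.shape)
    return block * 2

ij = da.arange(10, chunks=4)  # chunks of sizes 4, 4 and 2
doubled = da.map_blocks(inspect_block, ij, dtype=ij.dtype)

# The synchronous scheduler runs every task in this process,
# so seen_shapes is filled in here
result = doubled.compute(scheduler="synchronous")

# Ignore zero-size calls Dask may make while inferring metadata
block_shapes = [s for s in seen_shapes if s and s[0] > 0]
print(sorted(block_shapes))  # [(2,), (4,), (4,)]
```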
The matrix x looks like:
1    NaN  NaN  NaN  5    2
NaN  NaN  NaN  NaN  4    NaN
NaN  3    NaN  NaN  4    NaN
NaN  3    NaN  NaN  NaN  NaN
The matrices user_mat (n_users x num_latent_features) and ratings_mat look like:
float float float float float float
float float float float float float

float float
float float
float float
float float
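Write U for user_mat (n_users x k) and R for ratings_mat (k x n_ratings), where k is num_latent_features (the symbols U, R and k are notation introduced here). The prediction that compute_error subtracts from x[i, j] is then an entry of the product UR, read with its indices swapped:

```latex
\hat{x}_{ij} = \sum_{f=1}^{k} U_{jf}\, R_{fi} = (UR)_{ji} = \left((UR)^{\top}\right)_{ij}
```

This index swap is why a whole-array formulation has to transpose user_mat @ ratings_mat before comparing it with x.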
I have read the documentation and searched Stack Overflow, but I still can't solve the following problem:
KilledWorker Traceback (most recent call last)
<ipython-input-43-e670a6d660ce> in <module>
12 # For each user-offer pair
13 ij = da.arange(n_users*n_offers)
---> 14 diff = da.map_blocks(compute_error, ij, dtype=np.float_).compute()
c:\users\appdata\local\programs\python\python36\lib\site-packages\dask\base.py in compute(self, **kwargs)
281 dask.base.compute
282
--> 283 (result,) = compute(self, traverse=False, **kwargs)
284 return result
285
c:\users\appdata\local\programs\python\python36\lib\site-packages\dask\base.py in compute(*args, **kwargs)
563 postcomputes.append(x.__dask_postcompute__())
564
--> 565 results = schedule(dsk, keys, **kwargs)
566 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
567
c:\users\appdata\local\programs\python\python36\lib\site-packages\distributed\client.py in get(self, dsk, keys, workers, allow_other_workers, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
2652 should_rejoin = False
2653 try:
-> 2654 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
2655 finally:
2656 for f in futures.values():
c:\users\appdata\local\programs\python\python36\lib\site-packages\distributed\client.py in gather(self, futures, errors, direct, asynchronous)
1967 direct=direct,
1968 local_worker=local_worker,
-> 1969 asynchronous=asynchronous,
1970 )
1971
c:\users\appdata\local\programs\python\python36\lib\site-packages\distributed\client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
836 else:
837 return sync(
--> 838 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
839 )
840
c:\users\appdata\local\programs\python\python36\lib\site-packages\distributed\utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
349 if error[0]:
350 typ, exc, tb = error[0]
--> 351 raise exc.with_traceback(tb)
352 else:
353 return result[0]
c:\users\appdata\local\programs\python\python36\lib\site-packages\distributed\utils.py in f()
332 if callback_timeout is not None:
333 future = asyncio.wait_for(future, callback_timeout)
--> 334 result[0] = yield future
335 except Exception as exc:
336 error[0] = sys.exc_info()
c:\users\appdata\local\programs\python\python36\lib\site-packages\tornado\gen.py in run(self)
760
761 try:
--> 762 value = future.result()
763 except Exception:
764 exc_info = sys.exc_info()
c:\users\appdata\local\programs\python\python36\lib\site-packages\distributed\client.py in _gather(self, futures, errors, direct, local_worker)
1826 exc = CancelledError(key)
1827 else:
-> 1828 raise exception.with_traceback(traceback)
1829 raise exc
1830 if errors == "skip":
KilledWorker: ("('arange-compute_error-71748aa3c524bc2a5b920efa05deec65', 2)", <Worker 'tcp://127.0.0.1:50070', name: 0, memory: 0, processing: 4>)
If there is a more efficient way to do this computation, I'm open to suggestions as well.
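Instead of operating on the ij array and converting its values into indices into the source arrays, use Dask to operate on the actual source arrays. It will be much faster.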
I created the source arrays as follows:
- The source (NumPy) array for x:
  import numpy as np
  import dask.array as da
  arr = np.array([[1, np.nan, np.nan, np.nan, 5, 2],
                  [np.nan, np.nan, np.nan, np.nan, 4, np.nan],
                  [np.nan, 3, np.nan, np.nan, 4, np.nan],
                  [np.nan, 3, np.nan, np.nan, np.nan, np.nan]])
- The x array (from arr):
  x = da.from_array(arr, chunks=(2, 3))
  (I passed chunks to avoid creating x as a single-chunk array.)
- user_mat and ratings_mat:
  user_mat = np.arange(1, 13, dtype='float').reshape(6, 2)
  ratings_mat = np.arange(2, 10, dtype='float').reshape(2, 4)
I created these as NumPy arrays, but the da operations that follow convert them to Dask arrays.
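The objects above can be inspected directly. This shows what chunks=(2, 3) does, and confirms that da.dot accepts the plain NumPy operands and returns a lazy Dask array. The sketch below reuses the answer's data:

```python
import numpy as np
import dask.array as da

arr = np.array([[1, np.nan, np.nan, np.nan, 5, 2],
                [np.nan, np.nan, np.nan, np.nan, 4, np.nan],
                [np.nan, 3, np.nan, np.nan, 4, np.nan],
                [np.nan, 3, np.nan, np.nan, np.nan, np.nan]])
x = da.from_array(arr, chunks=(2, 3))

# A 4x6 array cut into 2x3 blocks gives a 2x2 grid of chunks
print(x.chunks)     # ((2, 2), (3, 3))
print(x.numblocks)  # (2, 2)

user_mat = np.arange(1, 13, dtype="float").reshape(6, 2)
ratings_mat = np.arange(2, 10, dtype="float").reshape(2, 4)

# The NumPy inputs are wrapped into Dask arrays; nothing is computed yet
pred = da.dot(user_mat, ratings_mat)
print(isinstance(pred, da.Array), pred.shape)  # True (6, 4)
```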
The actual operation is:
result = da.where(da.notnull(x), da.subtract(x, da.dot(user_mat, ratings_mat).T), 0).compute()
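Steps:
- da.notnull(x): the selection criterion (subtract, or use zero)
- da.subtract(...): the subtraction (first result)
- 0: the second result (for NaN elements of x)
- da.where(...): the recipe for what to compute
- compute(): the actual computation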
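Splitting the steps listed above into named intermediates shows that everything before .compute() only builds a task graph. The sketch also cross-checks the result against the same expression in plain NumPy; the intermediate names are illustrative.

```python
import numpy as np
import dask.array as da

arr = np.array([[1, np.nan, np.nan, np.nan, 5, 2],
                [np.nan, np.nan, np.nan, np.nan, 4, np.nan],
                [np.nan, 3, np.nan, np.nan, 4, np.nan],
                [np.nan, 3, np.nan, np.nan, np.nan, np.nan]])
x = da.from_array(arr, chunks=(2, 3))
user_mat = np.arange(1, 13, dtype="float").reshape(6, 2)
ratings_mat = np.arange(2, 10, dtype="float").reshape(2, 4)

mask = da.notnull(x)                                         # selection criterion
residual = da.subtract(x, da.dot(user_mat, ratings_mat).T)   # first result
recipe = da.where(mask, residual, 0)                         # 0 for NaN entries of x

# recipe is still a lazy Dask array; compute() runs the graph
result = recipe.compute()

# Same computation in plain NumPy, for comparison
expected = np.where(~np.isnan(arr), arr - (user_mat @ ratings_mat).T, 0)
print(np.allclose(result, expected))  # True
```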
The result for the data above is:
array([[ -13., 0., 0., 0., -73., -92.],
[ 0., 0., 0., 0., -93., 0.],
[ 0., -41., 0., 0., -112., 0.],
[ 0., -48., 0., 0., 0., 0.]])
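For comparison, the question's map_blocks idea can also work if the function handles a whole block of flat indices at once. The sketch below is a hypothetical vectorized rewrite of compute_error; the name compute_error_block does not come from either post. It assumes x, user_mat and ratings_mat are small NumPy arrays that every task can read, and it should reproduce the array above.

```python
import numpy as np
import dask.array as da

# The answer's data, kept as plain NumPy arrays (assumption: small enough to share)
x_np = np.array([[1, np.nan, np.nan, np.nan, 5, 2],
                 [np.nan, np.nan, np.nan, np.nan, 4, np.nan],
                 [np.nan, 3, np.nan, np.nan, 4, np.nan],
                 [np.nan, 3, np.nan, np.nan, np.nan, np.nan]])
user_mat = np.arange(1, 13, dtype="float").reshape(6, 2)
ratings_mat = np.arange(2, 10, dtype="float").reshape(2, 4)
n_ratings, n_users = x_np.shape  # 4 rows, 6 columns

def compute_error_block(ij):
    # ij is a whole NumPy block of flat indices, not a single integer
    i = ij // n_users
    j = ij % n_users
    observed = x_np[i, j]
    # Row-wise dot product of user_mat[j, :] with ratings_mat[:, i]
    predicted = np.einsum("kf,fk->k", user_mat[j, :], ratings_mat[:, i])
    return np.where(np.isnan(observed), 0.0, observed - predicted)

ij = da.arange(n_users * n_ratings, chunks=8)
diff = da.map_blocks(compute_error_block, ij, dtype=np.float64).compute()
print(diff.reshape(n_ratings, n_users))
```

This keeps the question's flat-index logic but replaces the per-element int(...) and np.isnan(...) calls with array operations, so each call processes a whole chunk. The single da.where expression in the answer is still the simpler option.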