为什么此代码的输出不是二维形式,而是一维形式

  • 本文关键字:一维 二维 代码 输出 python mxnet
  • 更新时间 :
  • 英文 :

from mxnet import nd
n_train, n_test, true_w, true_b = 100, 100, [1.2, -3.4, 5.6], 5
features = nd.random.normal(shape=(n_train + n_test, 1))
poly_features = nd.concat(features, nd.power(features, 2),
                         nd.power(features, 3))
labels = (true_w[0] * poly_features[:, 0] + true_w[1] * poly_features[:, 1] + true_w[2] * poly_features[:, 2] + true_b)
labels += nd.random.normal(scale=0.01, shape=labels.shape)
print(labels[:2])

因为featurespoly_features的形状都是2D NDARRAY,所以我认为该代码的输出是以下形式:

NDArray 2x1 @cpu(0)

,但实际输出形式为

NDArray 2 @cpu(0)

为什么输出不是2D ndarray?

,而 featurespoly_features是2D ndarray,当您计算labels时,您仅使用poly_features的切片,即1d Ndarrays。这是无线电代码分布:

labels = true_w[0] * poly_features[:, 0] # true_w[0] is scalar, poly_features[:, 0] is 1D NDAarray
       + true_w[1] * poly_features[:, 1] # true_w[1] is scalar, poly_features[:, 1] is 1D NDAarray
       + true_w[2] * poly_features[:, 2] # true_w[2] is scalar, poly_features[:, 2] is 1D NDAarray
       + true_b # true_b is scalar

所以,您得到1D数组作为答案。

最新更新