分类器决策边界的可视化



我已经构建了一个分类器,可以用三个标签正确地对R^2中的六个点进行分类。然而,我正在尝试将分类器使用的决策边界可视化。有什么简单的方法可以用plt命令做到这一点吗?

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


X = torch.from_numpy(np.array([[-1.3,0.2],[1.7,0.6],[2,3],[0.8,1.4],[0.5,-1],[0.4,-0.3]])).float()
Y = torch.from_numpy(np.array(([0,0,1,1,2,2])))

train_data = torch.utils.data.TensorDataset(X, Y)
test_data = torch.utils.data.TensorDataset(X, Y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=6, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=6, shuffle=True)

myModel = nn.Sequential(*[nn.Linear(2,2), nn.ReLU(), nn.Linear(2,3)])
myLoss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(myModel.parameters(), lr=0.01)
epoch_loss = []
step_loss = []

for epoch in range(1000):
running_loss = 0.0
miniBatch = 0
for x,y in train_loader:
optimizer.zero_grad()
score = myModel0(x)
loss = myLoss(score, y.type(torch.LongTensor))
loss.backward()
optimizer.step()
running_loss += loss.detach().numpy()
miniBatch += 1
step_loss.append(loss.detach().item())
epoch_loss.append(running_loss/len(train_loader))

简单答案:不。

很难回答:有点。

这完全取决于你准备放弃多少信息。显示多维决策边界从来都不是一种简单的方法。

一种选择可以是使用PCA减少数据的维度,然后运行模型。

这是我写了一段时间的小代码,它并不理想,但这是我能想到的最好的。

它不需要任何修改就可以工作。

def decision_plotted(model, X, y, feature_list):
def color_mapping(x):
colors = ['blue', 'red', 'purple', 'brown', 'yellow', 'green', 'darkblue', 'magenta']
return colors[x]
X_sample = X[feature_list].sample(300)
y_sample = y[X_sample.index]
model_plot = clone(model)
model_plot.fit(X_sample, y_sample)
y_color = y_sample.map(color_mapping)
x_grid, y_grid = np.arange(0, 1.01, 0.1), np.arange(0, 1.01, 0.1)
xx_mesh, yy_mesh = np.meshgrid(x_grid, y_grid)
xx, yy = xx_mesh.ravel(), yy_mesh.ravel()
X_grid = pd.DataFrame([xx, yy]).T
zz = model_plot.predict(X_grid)
zz = zz.reshape(xx_mesh.shape)
plt.figure(figsize=(10, 10))
plt.scatter(X_sample.iloc[:, 0], X_sample.iloc[:, 1], color=y_color)
plt.contourf(xx_mesh, yy_mesh, zz, alpha=0.3)
plt.xlabel(feature_list[0]), plt.ylabel(feature_list[1]), plt.title('Decision boundary')
plt.show()

最新更新