I have a function, dice:
import tensorflow as tf

def dice(yPred, yTruth, thresh):
    smooth = tf.constant(1.0)
    threshold = tf.constant(thresh)
    # Binarize the predictions: 1.0 where yPred >= thresh, 0.0 elsewhere
    yPredThresh = tf.cast(tf.greater_equal(yPred, threshold), tf.float32)
    mul = tf.multiply(yPredThresh, yTruth)
    intersection = 2 * tf.reduce_sum(mul) + smooth
    union = tf.reduce_sum(yPredThresh) + tf.reduce_sum(yTruth) + smooth
    dice = intersection / union
    return dice, yPredThresh
It works. Here is an example:
with tf.Session() as sess:
    thresh = 0.5
    print("Dice example")
    yPred = tf.constant([0.1, 0.9, 0.7, 0.3, 0.1, 0.1, 0.9, 0.9, 0.1], shape=[3, 3])
    yTruth = tf.constant([0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], shape=[3, 3])
    diceScore, yPredThresh = dice(yPred=yPred, yTruth=yTruth, thresh=thresh)
    diceScore_, yPredThresh_, yPred_, yTruth_ = sess.run([diceScore, yPredThresh, yPred, yTruth])
    print("\nScore = {0}".format(diceScore_))
>>> Score = 0.899999976158
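For reference, the 0.9 above can be reproduced by hand from the smoothed Dice formula that dice computes (smoothing constant 1), writing p-hat for the thresholded prediction and y for the ground truth:

```latex
D(\tau) = \frac{2\sum_i \hat{p}_i\, y_i + 1}{\sum_i \hat{p}_i + \sum_i y_i + 1},
\qquad \hat{p}_i = \mathbf{1}[\,p_i \ge \tau\,]

% tau = 0.5: \hat{p} = (0,1,1,0,0,0,1,1,0), y = (0,1,1,0,0,0,1,1,1)
\sum_i \hat{p}_i y_i = 4,\quad \sum_i \hat{p}_i = 4,\quad \sum_i y_i = 5
\;\Rightarrow\; D(0.5) = \frac{2\cdot 4 + 1}{4 + 5 + 1} = \frac{9}{10} = 0.9
```

The printed 0.899999976158 is just 0.9 stored as a float32.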
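I would like to be able to loop over the third argument of dice (the threshold). I'm not sure of the best way to do this so that I can extract the results from the graph. Something along these lines...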
import numpy as np

def diceROC(yPred, yTruth, thresholds=np.linspace(0.1, 0.9, 20)):
    thresholds = thresholds.astype(np.float32)
    nThreshs = thresholds.size
    diceScores = tf.zeros(shape=nThreshs)
    for i in range(nThreshs):
        score, _ = dice(yPred, yTruth, thresholds[i])
        diceScores[i] = score  # fails: tensors do not support item assignment
    return diceScores
Evaluating diceROC produces the error 'Tensor' object does not support item assignment, because I can't index into a tf tensor and assign to its elements.
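If you do want to keep a loop, a minimal sketch of a fix is to collect the per-threshold scores in a plain Python list and combine them with tf.stack, rather than assigning into a preallocated tensor. This assumes TensorFlow 2.x with eager execution; dice and dice_roc below are illustrative rewrites of the question's functions, not the asker's code. In TF 1.x graph mode the tf.stack result is likewise a single tensor that can be passed to sess.run.

```python
import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled


def dice(y_pred, y_truth, thresh):
    """Smoothed Dice score of y_pred binarized at `thresh` against y_truth."""
    smooth = 1.0
    # 1.0 where the prediction reaches the threshold, 0.0 elsewhere
    y_pred_thresh = tf.cast(tf.greater_equal(y_pred, thresh), tf.float32)
    intersection = 2.0 * tf.reduce_sum(y_pred_thresh * y_truth) + smooth
    union = tf.reduce_sum(y_pred_thresh) + tf.reduce_sum(y_truth) + smooth
    return intersection / union, y_pred_thresh


def dice_roc(y_pred, y_truth, thresholds=np.linspace(0.1, 0.9, 20)):
    """Dice score for each threshold, returned as a single 1-D tensor."""
    # Tensors are immutable, so instead of assigning into a preallocated
    # tensor, collect one scalar score per threshold in a Python list...
    scores = []
    for t in thresholds.astype(np.float32):
        score, _ = dice(y_pred, y_truth, t)
        scores.append(score)
    # ...and stack the scalars into a tensor of shape [len(thresholds)].
    return tf.stack(scores)


if __name__ == "__main__":
    y_pred = tf.constant([0.1, 0.9, 0.7, 0.3, 0.1, 0.1, 0.9, 0.9, 0.1], shape=[3, 3])
    y_truth = tf.constant([0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], shape=[3, 3])
    print(dice_roc(y_pred, y_truth, np.array([0.1, 0.9, 0.5])).numpy())
```

On the question's 3x3 example with thresholds 0.1, 0.9 and 0.5 this should print roughly [0.733, 0.778, 0.9], the same values the broadcasting approach gives for those thresholds. The trade-off is that the graph grows by one dice subgraph per threshold.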
Instead of looping, I'd encourage you to use TensorFlow's broadcasting capabilities. If you redefine dice as:
def dice(yPred, yTruth, thresh):
    smooth = tf.constant(1.0)
    # thresh may now be 1-D; it broadcasts along the last axis of yPred
    yPredThresh = tf.cast(tf.greater_equal(yPred, thresh), tf.float32)
    mul = tf.multiply(yPredThresh, yTruth)
    # Sum only over the two spatial axes, leaving one value per threshold
    intersection = 2 * tf.reduce_sum(mul, [0, 1]) + smooth
    union = tf.reduce_sum(yPredThresh, [0, 1]) + tf.reduce_sum(yTruth, [0, 1]) + smooth
    dice = intersection / union
    return dice, yPredThresh
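you will be able to pass a 3-dimensional yPred and yTruth (assuming the tensors are simply repeated along the last dimension) together with a 1-dimensional thresh: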
with tf.Session() as sess:
    thresh = [0.1, 0.9, 20, 0.5]
    print("Dice example")
    yPred = tf.constant([0.1, 0.9, 0.7, 0.3, 0.1, 0.1, 0.9, 0.9, 0.1], shape=[3, 3, 1])
    # Repeat the 3x3 map once per threshold along the last axis -> shape [3, 3, 4]
    ypred_tiled = tf.tile(yPred, [1, 1, 4])
    yTruth = tf.constant([0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], shape=[3, 3, 1])
    ytruth_tiled = tf.tile(yTruth, [1, 1, 4])
    diceScore, yPredThresh = dice(yPred=ypred_tiled, yTruth=ytruth_tiled, thresh=thresh)
    diceScore_ = sess.run(diceScore)
    print("\nScore = {0}".format(diceScore_))
You get:
Score = [ 0.73333335 0.77777779 0.16666667 0.89999998]
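The same broadcasting idea can be checked outside TensorFlow. The sketch below swaps in NumPy purely for illustration (dice_broadcast is a hypothetical name, not part of the answer) and shows that the explicit tile is not strictly required: a trailing axis of size 1 broadcasts against the 1-D threshold vector. Predictions and thresholds are both float32, so a prediction of 0.9 compares equal to a threshold of 0.9, as it does in the TensorFlow version.

```python
import numpy as np


def dice_broadcast(y_pred, y_truth, thresholds):
    """Smoothed Dice score of y_pred against y_truth for every threshold at once."""
    smooth = 1.0
    # (H, W, 1) >= (T,) broadcasts to (H, W, T): one binarized map per threshold
    y_pred_thresh = (y_pred[..., np.newaxis] >= thresholds).astype(np.float32)
    y_truth = y_truth[..., np.newaxis]  # (H, W, 1) broadcasts against (H, W, T)
    # Summing over the two spatial axes leaves one value per threshold
    intersection = 2.0 * np.sum(y_pred_thresh * y_truth, axis=(0, 1)) + smooth
    union = np.sum(y_pred_thresh, axis=(0, 1)) + np.sum(y_truth, axis=(0, 1)) + smooth
    return intersection / union


if __name__ == "__main__":
    # float32 throughout so that a prediction of 0.9 meets a 0.9 threshold
    y_pred = np.array([0.1, 0.9, 0.7, 0.3, 0.1, 0.1, 0.9, 0.9, 0.1], dtype=np.float32).reshape(3, 3)
    y_truth = np.array([0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], dtype=np.float32).reshape(3, 3)
    thresholds = np.array([0.1, 0.9, 20, 0.5], dtype=np.float32)
    print(dice_broadcast(y_pred, y_truth, thresholds))
```

With the answer's thresholds this should give 11/15, 7/9, 1/6 and 9/10, i.e. the same Score array shown above.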