如何在 Tensorflow 中训练模型时使用哈弗正弦函数作为损失函数?



我想训练一个LSTM模型来预测海洋漂浮的位置(纬度,经度(。我尝试使用haversine损失函数,但我不知道如何实现它。

确切地说,我使用 Keras,模型输出的形状为 (batch_size,2(

@tf.function
def haversine(label_true,label_pred):
lon1 = label_true[1]
lat1 = label_true[0]
lon2 = label_pred[1]
lat2 = label_pred[0]
lon1,lat1,lon2,lat2 = map(radians, [lon1,lat1,lon2,lat2])
dlon = lon2 - lon1
dlat = lat2 - lat1
a = sin(dlat/2)**2 + cos(lat1)*cos(lat2)*sin(dlon/2)**2
c = 2 * asin(sqrt(a))
r = 6371   
return c*r*1000

@tf.function
def loss_haversine(y_true,y_pred):
loss =0
total_dis = np.array([haversine(label_true,label_pred) for label_true,label_pred in zip(y_true,y_pred)])
print(total_dis.shape)
return tf.convert_to_tensor(np.sum(total_dis), dtype=tf.float32)
def create_model(lstm_size=32,timesteps=3,features=6,lstm_dropout=0.1):
model = Sequential()
model.add(LSTM(lstm_size,input_shape=(timesteps,features),recurrent_dropout=lstm_dropout,return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(lstm_size,input_shape=(timesteps,features),return_sequences=False))
model.add(Dropout(0.2))
model.add(Dense(2)) 
#model.add(Activation('linear'))
model.compile(loss=loss_haversine,optimizer='adam', metrics=['accuracy']) 
print(model.summary())
return model

训练模型时,出现此错误: "OperatorNotAllowedInGraphError:不允许遍历tf.Tensor:AutoGraph 没有转换此函数。

多谢!

考虑到张量流的工作原理,我的解决方案是通过矩阵运算计算哈尔正弦距离。代码如下:

import tensorflow as tf
RADIUS_KM = 6378.1
def degrees_to_radians(deg):
pi_on_180 = 0.017453292519943295
return deg * pi_on_180
def loss_haversine(observation, prediction):    
obv_rad = tf.map_fn(degrees_to_radians, observation)
prev_rad = tf.map_fn(degrees_to_radians, prediction)
dlon_dlat = obv_rad - prev_rad 
v = dlon_dlat / 2
v = tf.sin(v)
v = v**2
a = v[:,1] + tf.cos(obv_rad[:,1]) * tf.cos(prev_rad[:,1]) * v[:,0] 
c = tf.sqrt(a)
c = 2* tf.math.asin(c)
c = c*RADIUS_KM
final = tf.reduce_sum(c)
#if you're interested in having MAE with the haversine distance in KM
#uncomment the following line
#final = final/tf.dtypes.cast(tf.shape(observation)[0], dtype= tf.float32)
return final

最新更新