我正在对无人机影像运行一个深度学习分割模型。无人机正射影像（orthomosaic）被切割成 256×256 像素的瓦片，我对每个瓦片分别进行预测。为此，我读取原始的 RGB GeoTIFF 瓦片，提取其地理参考信息（地理变换和投影），并用这些相同的信息创建一个新的 GeoTIFF，只是其中写入的是预测结果而不是 RGB。使用 for 循环时它运行正常，但耗时很长。我想用多进程（multiprocessing）来加快速度。请问应该如何修改下面的代码来实现这一点？
非常感谢
from osgeo import gdal, osr
# Export predicted data: run the segmentation model over every tile in
# x_pred_dir and write each prediction to y_pred_dir as a GeoTIFF that carries
# the geotransform and projection of the matching raw tile in
# DATA_pred_DIR/images_raw.
pred_dataset = Dataset_pred(x_pred_dir,)
# List the directory ONCE. Calling os.listdir several times per tile makes the
# loop O(n^2) in filesystem work for n tiles. The listing is deliberately left
# unsorted so that tile i of pred_dataset is still paired with entry i of this
# listing, exactly as before (presumably Dataset_pred lists the same directory
# in the same order -- verify).
image_names = os.listdir(x_pred_dir)
driver = gdal.GetDriverByName('GTiff')
# Predicting several tiles per model.predict call amortises the per-call
# overhead, which is usually what makes this loop slow; a process pool does not
# address that and cannot easily share one model. Tune to available memory.
batch_size = 16
for start in range(0, len(image_names), batch_size):
    batch_names = image_names[start:start + batch_size]
    batch = np.stack([pred_dataset[start + k] for k in range(len(batch_names))])
    pr_masks = model.predict(batch)
    for image_name, pr_mask in zip(batch_names, pr_masks):
        if pr_mask.ndim == 2:
            # Output without a channel axis: treat it as a single band.
            pr_mask = pr_mask[..., np.newaxis]
        rows, cols, no_bands = pr_mask.shape
        x_geotif_dir = os.path.join(DATA_pred_DIR, 'images_raw', image_name)
        raster_ds = gdal.Open(x_geotif_dir)
        if raster_ds is None:
            raise FileNotFoundError(f"cannot open source GeoTIFF: {x_geotif_dir}")
        geo_transform = raster_ds.GetGeoTransform()
        projection = raster_ds.GetProjection()
        raster_ds = None  # release the source tile; only its metadata is needed
        images_fps = os.path.join(y_pred_dir, image_name)
        # NOTE(review): if model.predict returns float probabilities, GDT_Byte
        # stores them as integers and the fractional values are lost --
        # confirm that is intended (or threshold/argmax/scale explicitly).
        DataSet = driver.Create(images_fps, cols, rows, no_bands, gdal.GDT_Byte)
        if DataSet is None:
            raise RuntimeError(f"cannot create output GeoTIFF: {images_fps}")
        DataSet.SetGeoTransform(geo_transform)
        DataSet.SetProjection(projection)
        # GDAL bands are 1-based and take (rows, cols) arrays, so move the
        # channel axis to the front and write one band per channel.
        for j, band in enumerate(np.moveaxis(pr_mask, -1, 0), 1):
            DataSet.GetRasterBand(j).WriteArray(band)
        DataSet = None  # dereferencing flushes and closes the GeoTIFF
我尝试的多进程版本大致如下：
import os
import multiprocessing as mp
from osgeo import gdal, osr
def process_image(dataset_image, image_filename):
    """Predict one tile and write the mask as a georeferenced GeoTIFF.

    Args:
        dataset_image: one input tile as returned by ``Dataset_pred[i]``,
            without a batch axis (one is added here before predicting).
        image_filename: file name of the tile. It is used both to locate the
            raw GeoTIFF under ``DATA_pred_DIR/images_raw`` and as the output
            file name under ``y_pred_dir``.

    Relies on the module-level names ``model``, ``DATA_pred_DIR`` and
    ``y_pred_dir``. When this runs inside a multiprocessing pool, those names
    must exist in every worker process. With the "spawn" start method, that
    means they must be created at module import time, not only under
    ``if __name__ == '__main__'``.

    Raises:
        FileNotFoundError: if the source GeoTIFF cannot be opened (GDAL raises
            its own RuntimeError instead when gdal.UseExceptions() is active).
        RuntimeError: if the output GeoTIFF cannot be created.
    """
    image = np.expand_dims(dataset_image, axis=0)
    # Drop only the batch axis. A bare squeeze() would also drop a size-1
    # channel axis for single-class models and break the unpacking below.
    pr_mask = np.squeeze(model.predict(image), axis=0)
    if pr_mask.ndim == 2:
        pr_mask = pr_mask[..., np.newaxis]
    rows, cols, no_bands = pr_mask.shape

    x_geotif_dir = os.path.join(DATA_pred_DIR, 'images_raw', image_filename)
    raster_ds = gdal.Open(x_geotif_dir)
    if raster_ds is None:
        raise FileNotFoundError(f"cannot open source GeoTIFF: {x_geotif_dir}")
    geo_transform = raster_ds.GetGeoTransform()
    projection = raster_ds.GetProjection()
    raster_ds = None  # only the metadata is needed; release the source file

    driver = gdal.GetDriverByName('GTiff')
    images_fps = os.path.join(y_pred_dir, image_filename)
    # NOTE(review): float probabilities written as GDT_Byte lose their
    # fractional part -- confirm this is intended.
    DataSet = driver.Create(images_fps, cols, rows, no_bands, gdal.GDT_Byte)
    if DataSet is None:
        raise RuntimeError(f"cannot create output GeoTIFF: {images_fps}")
    try:
        DataSet.SetGeoTransform(geo_transform)
        DataSet.SetProjection(projection)
        # GDAL bands are 1-based and take (rows, cols) arrays.
        for j, band in enumerate(np.moveaxis(pr_mask, -1, 0), 1):
            DataSet.GetRasterBand(j).WriteArray(band)
    finally:
        DataSet = None  # dereferencing flushes and closes the GeoTIFF
def main(processes=None):
    """Predict and export every tile in x_pred_dir using a process pool.

    Args:
        processes: number of worker processes. None (the default) lets
            multiprocessing use os.cpu_count().
    """
    pred_dataset = Dataset_pred(x_pred_dir,)
    dir_images = os.listdir(x_pred_dir)
    # Pair tile i with file name i by indexing, the same way the working
    # sequential loop does, instead of assuming Dataset_pred is iterable.
    tasks = [(pred_dataset[i], name) for i, name in enumerate(dir_images)]
    # NOTE(review): every worker calls process_image, which uses its own copy
    # of the global `model`. Deep-learning runtimes such as TensorFlow are
    # generally not fork-safe, and one GPU shared by several processes is
    # rarely faster. If this hangs or does not speed things up, predict in
    # batches in the parent process and parallelize only the GeoTIFF writing.
    with mp.Pool(processes) as pool:
        pool.starmap(process_image, tasks)


if __name__ == '__main__':
    main()