在 Dijkstra 算法中将哪种数据类型用作队列?



我正在尝试用Java实现Dijkstra的算法(自学)。我使用维基百科提供的伪代码(链接)。现在算法快结束了,我应该decrease-key v in Q;。我想我应该用BinaryHeap或类似的东西实现Q?在这里使用的正确(内置)数据类型是什么?

private void dijkstra(int source) {
        int[] dist = new int[this.adjacencyMatrix.length];
        int[] previous = new int[this.adjacencyMatrix.length];
        Queue<Integer> q = new LinkedList<Integer>();
        for (int i = 0; i < this.adjacencyMatrix.length; i++) {
            dist[i] = this.INFINITY;
            previous[i] = this.UNDEFINED;
            q.add(i);
        }
        dist[source] = 0;
        while(!q.isEmpty()) {
            // get node with smallest dist;
            int u = 0;
            for(int i = 0; i < this.adjacencyMatrix.length; i++) {
                if(dist[i] < dist[u])
                    u = i;
            }
            // break if dist == INFINITY
            if(dist[u] == this.INFINITY) break;
            // remove u from q
            q.remove(u);
            for(int i = 0; i < this.adjacencyMatrix.length; i++) {
                if(this.adjacencyMatrix[u][i] == 1) {
                    // in a unweighted graph, this.adjacencyMatrix[u][i] always == 1;
                    int alt = dist[u] + this.adjacencyMatrix[u][i]; 
                    if(alt < dist[i]) {
                        dist[i] = alt;
                        previous[i] = u;
                        // here's where I should "decrease the key"
                    }
                }
            }
        }
    }

最简单的方法是使用优先级队列,而不关心优先级队列中先前添加的键。这意味着每个节点将在队列中多次出现,但这丝毫不会影响算法。如果你看了一下,所有被替换的节点版本稍后都会被拾取,到那时,最近的距离已经确定。

维基百科中的if alt < dist[v]:复选框使这项工作发挥了作用。运行时只会因此而降低一点,但如果您需要非常快的版本,则必须进一步优化。

注意:

与任何优化一样,这一优化应该小心处理,可能会导致奇怪且难以发现的错误(例如,请参见此处)。对于大多数情况,只使用移除和重新插入应该是可以的,但我在这里提到的技巧是,如果Dijkstra实现是瓶颈,可以稍微加快代码的速度。

最重要的是:在尝试之前,请确保优先级队列如何处理优先级。队列中的实际优先级永远不应该改变,否则可能会弄乱队列的不变量,这意味着存储在队列中的项目可能不再可检索。例如,在Java中,优先级与对象一起存储,因此您确实需要一个额外的包装器:

这将不起作用:

import java.util.PriorityQueue;
// Store node information and priority together
class Node implements Comparable<Node> {
  public int prio;
  public Node(int prio) { this.prio = prio; }
  public int compareTo(Node n) {
     return Integer.compare(this.prio, n.prio);
  }
}
...
...
PriorityQueue<Node> q = new PriorityQueue<Node>();
n = new Node(10);
q.add(n)
...
// let's update the priority
n.prio = 0;
// re-add
q.add(n);
// q may be broken now

因为在n.prio=0中,您还更改了队列中对象的优先级。然而,这将很好:

import java.util.PriorityQueue;
// Only node information
class Node {
  // Whatever you need for your graph
  public Node() {}
}
class PrioNode {
   public Node n;
   public int prio;
   public PrioNode(Node n, int prio) {
     this.n = n;
     this.prio = prio;
   }
   public int compareTo(PrioNode p) {
      return Integer.compare(this.prio, p.prio);
   }
}
...
...
PriorityQueue<PrioNode> q = new PriorityQueue<PrioNode>();
n = new Node();
q.add(new PrioNode(n,10));
...
// let's update the priority and re-add
q.add(new PrioNode(n,0));
// Everything is fine, because we have not changed the value
// in the queue.

您可以使用TreeSet(在C++中,您可以使用std::set)为Dijkstra实现优先级队列。TreeSet表示一个集合,但我们也允许描述集合中项目的顺序。您需要将节点存储在集合中,并使用节点的距离对节点进行排序。距离最小的节点将位于集合的开头。

class Node {
    public int id;   // store unique id to distinguish elements in the set
    public int dist; // store distance estimates in the Node itself
    public int compareTo(Node other) {
        // TreeSet implements the Comparable interface.
        // We need tell Java how to compare two Nodes in the TreeSet.
        // For Dijstra, we want the node with the _smallest_ distance
        // to be at the front of the set.
        // If two nodes have same distance, choose node with smaller id
        if (this.dist == other.dist) {
            return Integer.compare(this.id, other.id);
        } else {
            return Integer.compare(this.dist, other.dist);
        }
    }
}
// ...
TreeSet<Node> nodes = new TreeSet<Node>();

提取最小操作通过以下方式实现,并花费O(lgn)最坏情况时间:

Node u = nodes.pollFirst();

使用减少键操作,我们移除具有旧键的节点(旧距离估计),并添加具有较小键的新节点(新的、更好的距离估计)。两种操作都需要O(lgn)最坏情况时间。

nodes.remove(v);
v.dist = /* some smaller key */
nodes.add(v);

一些额外的注意事项:

  • 上面的实现非常简单,而且由于这两个操作都是n的对数运算,因此总体而言,运行时间将为O((n+e)lgn)。这被认为对于Dijkstra的一个基本实现是有效的。参见CLRS书籍(ISBN:978-0-262-03384-8)第19章来证明这种复杂性。

  • 尽管大多数教科书都会为Dijkstra、Prim、a*等使用优先级队列,但不幸的是,Java和C++实际上都没有实现具有相同O(lgn)减少键操作的优先级队列!

  • PriorityQueue在Java中确实存在,但remove(Object o)方法是而不是对数的,因此您的递减键运算将是O(n)而不是O(lgn),并且(渐进地)您将获得较慢的Dikjstra!

  • 要从头开始构建TreeSet(使用for循环),需要花费时间O(nlgn),与从n个项目初始化堆/优先级队列的O(n)最坏情况相比,这要慢一些。然而,Dijkstra的主循环花费时间O(nlgn+elgn),这支配了该初始化时间。因此,对于Dijkstra来说,初始化TreeSet不会导致任何显著的放缓。

  • 我们不能使用HashSet,因为我们关心键的顺序——我们希望能够首先拉出最小的键。这为我们提供了具有最佳距离估计的节点!

  • Java中的TreeSet是使用红黑树实现的,这是一种自平衡的二进制搜索树。这就是为什么这些操作具有对数最坏情况时间的原因。

  • 您使用ints来表示图形节点,这很好,但当您引入Node类时,您需要一种方法来关联这两个实体。我建议在构建图形时构建HashMap<Integer, Node>,这将有助于跟踪哪个int对应于哪个Node。`

建议的PriorityQueue不提供减少键操作。但是,可以通过移除元素,然后使用新密钥重新插入元素来模拟它

这不应该增加算法的渐进运行时间,尽管通过内置支持可以使其稍微更快

EDIT:这确实增加了渐进运行时间,因为对于堆,减少键应该是O(log n),但remove(Object)O(n)Java中似乎没有任何内置的优先级队列支持高效的减少键

根据wiki文章的优先级队列。这表明现在的经典实现是使用"由Fibonacci堆实现的最小优先级队列"。

是的,Java没有通过PriorityQueue为最小堆提供递减键,因此删除操作将是O(N),可以优化为logN。

我已经用递减密钥实现了Min Heap(实际上是递减密钥和递增密钥,但这里只有递减密钥就足够了)。实际数据结构是最小堆映射(HashMap存储所有节点的索引,并帮助通过当前顶点更新当前顶点的邻居的最小路径值)

我用泛型实现了优化的解决方案,它花了我大约3-4个小时的时间来编码(我的第一次),时间复杂度是O(logV.E)

希望这会有所帮助!

 package algo;
 import java.util.*;
 public class Dijkstra {
/*
 *
 * @author nalin.sharma
 *
 */
/**
 *
 * Heap Map Data Structure
 * Heap stores the Nodes with their weight based on comparison on Node's weight
 * HashMap stores the Node and its index for O(1) look up of node in heap
 *
 *
 *
 *
 * Example -:
 *
 *                                   7
 *                         [2]----------------->[4]
 *                       ^  |                   ^ 
 *                     /   |                   |    1
 *                 2 /    |                   |     v
 *                 /     |                   |       [6]
 *               /      | 1               2 |       ^
 *             /       |                   |      /
 *          [1]       |                   |     /
 *                  |                   |    / 5
 *            4    |                   |   /
 *               v v                   |  /
 *                [3]---------------->[5]
 *                         3
 *
 *        Minimum distance from source 1
 *         v  | d[v] | path
 *         --- ------  ---------
 *         2 |  2  |  1,2
 *         3 |  3  |  1,2,3
 *         4 |  8  |  1,2,3,5,4
 *         5 |  6  |  1,2,3,5
 *         6 |  9  |  1,2,3,4,6
 *
 *
 *
 *     Below is the Implementation -:
 *
 */
static class HeapMap<T> {
    int size, ind = 0;
    NodeWeight<T> arr [];
    Map<T,Integer> map;
    /**
     *
     * @param t is the Node(1,2,3..or A,B,C.. )
     * @return the index of element in min heap
     */
    int index(T t) {
        return map.get(t);
    }
    /**
     *
     * @param index is the Node(1,2,3..or A,B,C.. )
     * @return Node and its Weight
     */
    NodeWeight<T> node(int index) {
        return arr[index];
    }
    /**
     *
     * @param <T> Node of type <T> and its weight
     */
    static class NodeWeight<T> {
        NodeWeight(T v, int w) {
            nodeVal = v;
            weight = w;
        }
        T nodeVal;
        int weight;
        List<T> path = new ArrayList<>();
    }
    public HeapMap(int s) {
        size = s;
        arr = new NodeWeight[size + 1];
        map = new HashMap<>();
    }
    private void updateIndex(T key, int newInd) {
        map.put(key, newInd);
    }
    private void shiftUp(int i) {
        while(i > 1) {
            int par = i / 2;
            NodeWeight<T> currNodeWeight = arr[i];
            NodeWeight<T> parentNodeWeight = arr[par];
            if(parentNodeWeight.weight > currNodeWeight.weight) {
                updateIndex(parentNodeWeight.nodeVal, i);
                updateIndex(currNodeWeight.nodeVal, par);
                swap(par,i);
                i = i/2;
            }
            else {
                break;
            }
        }
    }
    /**
     *
     * @param nodeVal
     * @param newWeight
     * Based on if the value introduced is higher or lower shift down or shift up operations are performed
     *
     */
    public void update(T nodeVal, int newWeight) {
        int i = ind;
        NodeWeight<T> nodeWeight = arr[map.get(nodeVal)];
        int oldWt = nodeWeight.weight;
        nodeWeight.weight = newWeight;
        if(oldWt < newWeight) {
            shiftDown(map.get(nodeVal));
        }
        else if(oldWt > newWeight) {
            shiftUp(map.get(nodeVal));
        }
    }
    /**
     *
     * @param nodeVal
     * @param wt
     *
     * Typical insertion in Min Heap and storing its element's indices in HashMap for fast lookup
     */
    public void insert(T nodeVal, int wt) {
        NodeWeight<T> nodeWt = new NodeWeight<>(nodeVal, wt);
        arr[++ind] = nodeWt;
        updateIndex(nodeVal, ind);
        shiftUp(ind);
    }
    private void swap(int i, int j) {
        NodeWeight<T> tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }
    private void shiftDown(int i) {
        while(i <= ind) {
            int current = i;
            int lChild = i * 2;
            int rChild = i * 2 + 1;
            if(rChild <= ind) {
                int childInd = (arr[lChild].weight < arr[rChild].weight) ? lChild : rChild;
                if(arr[childInd].weight < arr[i].weight) {
                    updateIndex(arr[childInd].nodeVal, i);
                    updateIndex(arr[i].nodeVal, childInd);
                    swap(childInd, i);
                    i = childInd;
                }
            }
            else if(lChild <= ind && arr[lChild].weight < arr[i].weight) {
                updateIndex(arr[lChild].nodeVal, i);
                updateIndex(arr[i].nodeVal, lChild);
                swap(lChild, i);
                i = lChild;
            }
            if(current == i) {
                break;
            }
        }
    }
    /**
     *
     * @return
     *
     * Typical deletion in Min Heap and removing its element's indices in HashMap
     *
     */
    public NodeWeight<T> remove() {
        if(ind == 0) {
            return null;
        }
        map.remove(arr[1].nodeVal);
        NodeWeight<T> out = arr[1];
        out.path.add(arr[1].nodeVal);
        arr[1] = arr[ind];
        arr[ind--] = null;
        if(ind > 0) {
            updateIndex(arr[1].nodeVal, 1);
            shiftDown(1);
        }
        return out;
    }
}
/**
 *
 *  Graph representation -: It is an Map(T,Node<T>) of Map(T(neighbour), Integer(Edge's weight))
 *
 */
static class Graph<T> {
    void init(T... t) {
        for(T z: t) {
            nodes.put(z, new Node<>(z));
        }
    }
    public Graph(int s, T... t) {
        size = s;
        nodes = new LinkedHashMap<>(size);
        init(t);
    }
    /**
     *
     * Node class
     *
     */
    static class Node<T> {
        Node(T v) {
            val = v;
        }
        T val;
        //Set<Edge> edges = new HashSet<>();
        Map<T, Integer> edges = new HashMap<>();
    }
    /*static class Edge {
        Edge(int to, int w) {
            target = to;
            wt = w;
        }
        int target;
        int wt;
        }
    }*/
    int size;
    Map<T, Node<T>> nodes;
    void addEdge(T from, T to, int wt) {
        nodes.get(from).edges.put(to, wt);
    }
}
/**
 *
 * @param graph
 * @param from
 * @param heapMap
 * @param <T>
 *
 * Performs initialisation of all the nodes from the start vertex
 *
 */
    private static <T> void init(Graph<T> graph, T from, HeapMap<T> heapMap) {
    Graph.Node<T> fromNode = graph.nodes.get(from);
    graph.nodes.forEach((k,v)-> {
            if(from != k) {
                heapMap.insert(k, fromNode.edges.getOrDefault(k, Integer.MAX_VALUE));
            }
        });
    }

static class NodePathMinWeight<T> {
    NodePathMinWeight(T n, List<T> p, int c) {
        node = n;
        path = p;
        minCost= c;
    }
    T node;
    List<T> path;
    int minCost;
}
/**
 *
 * @param graph
 * @param from
 * @param <T>
 * @return
 *
 * Repeat the below process for all the vertices-:
 * Greedy way of picking the current shortest distance and updating its neighbors distance via this vertex
 *
 * Since Each Vertex V has E edges, the time Complexity is
 *
 * O(V.logV.E)
 * 1. selecting vertex with shortest distance from source in logV time -> O(logV) via Heap Map Data structure
 * 2. Visiting all E edges of this vertex and updating the path of its neighbors if found less via this this vertex. -> O(E)
 * 3. Doing operation step 1 and step 2 for all the vertices -> O(V)
 *
 */
    static <T> Map<T,NodePathMinWeight<T>> dijkstra(Graph<T> graph, T from) {
    Map<T,NodePathMinWeight<T>> output = new HashMap<>();
    HeapMap<T> heapMap = new HeapMap<>(graph.nodes.size());
    init(graph, from, heapMap);
    Set<T> isNotVisited = new HashSet<>();
    graph.nodes.forEach((k,v) -> isNotVisited.add(k));
    isNotVisited.remove(from);
    while(!isNotVisited.isEmpty()) {
        HeapMap.NodeWeight<T> currNodeWeight = heapMap.remove();
        output.put(currNodeWeight.nodeVal, new NodePathMinWeight<>(currNodeWeight.nodeVal, currNodeWeight.path, currNodeWeight.weight));
        //mark visited
        isNotVisited.remove(currNodeWeight.nodeVal);
        //neighbors
        Map<T, Integer> neighbors = graph.nodes.get(currNodeWeight.nodeVal).edges;
        neighbors.forEach((k,v) -> {
            int ind = heapMap.index(k);
            HeapMap.NodeWeight<T> neighbor = heapMap.node(ind);
            int neighborDist = neighbor.weight;
            int currentDistance = currNodeWeight.weight;
            if(currentDistance + v < neighborDist) {
                //update
                neighbor.path = new ArrayList<>(currNodeWeight.path);
                heapMap.update(neighbor.nodeVal, currentDistance + v);
            }
        });
    }
    return output;
}
public static void main(String[] args) {
    Graph<Integer> graph = new Graph<>(6,1,2,3,4,5,6);
    graph.addEdge(1,2,2);
    graph.addEdge(1,3,4);
    graph.addEdge(2,3,1);
    graph.addEdge(2,4,7);
    graph.addEdge(3,5,3);
    graph.addEdge(5,6,5);
    graph.addEdge(4,6,1);
    graph.addEdge(5,4,2);
    Integer source = 1;
    Map<Integer,NodePathMinWeight<Integer>> map = dijkstra(graph,source);
    map.forEach((k,v) -> {
        v.path.add(0,source);
        System.out.println("source vertex :["+source+"] to vertex :["+k+"] cost:"+v.minCost+" shortest path :"+v.path);
    });
}

}

输出-:

源顶点:[1]到顶点:[2]成本:2最短路径:[1,2]

源顶点:[1]到顶点:[3]成本:3最短路径:[1,2,3]

源顶点:[1]到顶点:[4]成本:8最短路径:[1,2,3,5,4]

源顶点:[1]到顶点:[5]成本:6最短路径:[1,2,3,5]

源顶点:[1]到顶点:[6]成本:9最短路径:[1,2,3,5,4,6]

最新更新