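I am taking an MIT course on algorithms. In the very first lecture the professor presented the following problem:
A peak in a 2D array is a value such that all 4 of its neighbours are less than or equal to it, i.e. for a[i][j] to be a local maximum,
a[i+1][j] <= a[i][j]
&& a[i-1][j] <= a[i][j]
&& a[i][j+1] <= a[i][j]
&& a[i][j-1] <= a[i][j]
Now, given an NxN 2D array, find a peak in the array.
This can be solved easily in O(N^2) time by iterating over all the elements and returning a peak.
However, it can be optimized to run in O(N log N) time by using the divide-and-conquer solution explained here.
But they said that there exists an O(N) time algorithm that solves this problem. Please suggest how we can solve this problem in O(N) time.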
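For reference, here is a minimal Python sketch of the peak definition and of the O(N^2) scan mentioned above. The names is_peak and find_peak_brute_force are illustrative and do not come from the course code, and cells on the border are assumed to be compared only against the neighbours that actually exist.

```python
def is_peak(a, i, j):
    """Return True if a[i][j] is >= each of its existing 4 neighbours."""
    n_rows, n_cols = len(a), len(a[0])
    for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
        ni, nj = i + di, j + dj
        # Neighbours outside the array are ignored (assumption of this sketch).
        if 0 <= ni < n_rows and 0 <= nj < n_cols and a[ni][nj] > a[i][j]:
            return False
    return True


def find_peak_brute_force(a):
    """Scan every cell and return the first (row, col) that is a peak."""
    for i in range(len(a)):
        for j in range(len(a[0])):
            if is_peak(a, i, j):
                return (i, j)
    return None  # only reached when the array has no cells


grid = [[1, 2, 3],
        [4, 9, 5],
        [6, 7, 8]]
print(find_peak_brute_force(grid))  # (1, 1)
```

Every non-empty array has at least one peak (its global maximum), so the scan only returns None for empty input.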
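PS (for those who know Python): the course staff have explained an approach here (Problem 1-5, Peak-Finding Proof) and have also provided some Python code in their problem sets. But the approach explained is totally non-obvious and very hard to decipher, and the Python code is equally confusing. So I have copied the main part of the code below, for those who know Python and can tell from the code which algorithm is being used.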
def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None):
    # if it's empty, we're done
    if problem.numRow <= 0 or problem.numCol <= 0:
        return None

    subproblems = []
    divider = []

    if rowSplit:
        # the recursive subproblem will involve half the number of rows
        mid = problem.numRow // 2

        # information about the two subproblems
        (subStartR1, subNumR1) = (0, mid)
        (subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1))
        (subStartC, subNumC) = (0, problem.numCol)

        subproblems.append((subStartR1, subStartC, subNumR1, subNumC))
        subproblems.append((subStartR2, subStartC, subNumR2, subNumC))

        # get a list of all locations in the dividing row
        divider = crossProduct([mid], range(problem.numCol))
    else:
        # the recursive subproblem will involve half the number of columns
        mid = problem.numCol // 2

        # information about the two subproblems
        (subStartR, subNumR) = (0, problem.numRow)
        (subStartC1, subNumC1) = (0, mid)
        (subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1))

        subproblems.append((subStartR, subStartC1, subNumR, subNumC1))
        subproblems.append((subStartR, subStartC2, subNumR, subNumC2))

        # get a list of all locations in the dividing column
        divider = crossProduct(range(problem.numRow), [mid])

    # find the maximum in the dividing row or column
    bestLoc = problem.getMaximum(divider, trace)
    neighbor = problem.getBetterNeighbor(bestLoc, trace)

    # update the best we've seen so far based on this new maximum
    if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
        bestSeen = neighbor
        if not trace is None: trace.setBestSeen(bestSeen)

    # return when we know we've found a peak
    if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
        if not trace is None: trace.foundPeak(bestLoc)
        return bestLoc

    # figure out which subproblem contains the largest number we've seen so
    # far, and recurse, alternating between splitting on rows and splitting
    # on columns
    sub = problem.getSubproblemContaining(subproblems, bestSeen)
    newBest = sub.getLocationInSelf(problem, bestSeen)
    if not trace is None: trace.setProblemDimensions(sub)
    result = algorithm4(sub, newBest, not rowSplit, trace)
    return problem.getLocationInSelf(sub, result)

# Helper method
def crossProduct(list1, list2):
    """
    Returns all pairs with one item from the first list and one item from
    the second list.  (Cartesian product of the two lists.)

    The code is equivalent to the following list comprehension:
        return [(a, b) for a in list1 for b in list2]
    but for easier reading and analysis, we have included more explicit code.
    """
    answer = []
    for a in list1:
        for b in list2:
            answer.append((a, b))
    return answer

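1. Let's assume that the width of the array is greater than its height; otherwise we split in the other direction.
2. Split the array into three parts: the central column, the left side and the right side.
3. Go through the central column and its two neighbouring columns and look for the maximum.
   - If it is in the central column, it is our peak.
   - If it is on the left side, run this algorithm on the subarray left_side + central_column.
   - If it is on the right side, run this algorithm on the subarray right_side + central_column.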
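Why this works:
For the case where the maximum element is in the central column, it is obvious. If it is not, we can start from that maximum and keep stepping to larger and larger elements; such a walk will definitely never cross the central column, because every value on it is smaller than the starting maximum. So a peak definitely exists in the corresponding half.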
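The "keep climbing" argument can be illustrated with a small sketch. The function name greedy_ascent and the sample grid are made up for this illustration. It is not part of the algorithm itself; it only demonstrates that an uphill walk must stop at a peak and that, when it starts from a value larger than everything in the central column, it never steps onto that column.

```python
def greedy_ascent(a, start):
    """Walk from `start` to a strictly larger neighbour until none exists."""
    i, j = start
    n_rows, n_cols = len(a), len(a[0])
    while True:
        best = (i, j)
        for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            ni, nj = i + di, j + dj
            inside = 0 <= ni < n_rows and 0 <= nj < n_cols
            if inside and a[ni][nj] > a[best[0]][best[1]]:
                best = (ni, nj)
        if best == (i, j):
            return (i, j)  # no neighbour is larger, so this cell is a peak
        i, j = best        # values strictly increase, so the walk terminates


# Column 1 plays the role of the central column; all its values are small.
grid = [[5, 1, 0],
        [6, 1, 0],
        [7, 1, 2]]
print(greedy_ascent(grid, (0, 0)))  # (2, 0)
```

The walk visits 5, 6 and 7 and stops in the left half without ever touching column 1.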
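Why this is O(n):
Step 3 takes a number of iterations at most proportional to the larger dimension of the current subarray, and that larger dimension at least halves on every two algorithm steps. This gives a sum of the form n + n/2 + n/4 + ..., which is O(n). Important detail: we always split along the larger dimension. For square arrays this means that the split directions alternate. This is the difference from the last attempt in the PDF you linked to.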
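Written out, the bound is a geometric series. This is only a sketch under two assumptions that the answer does not state explicitly: scanning costs at most c per visited cell for some constant c (a symbol introduced here, which also absorbs the factor 3 for the three scanned columns), and n is the larger dimension of the original array.

```latex
T(n) \;\le\; \sum_{k=0}^{\infty} 2c\,\frac{n}{2^{k}} \;=\; 4cn \;=\; O(n)
```

The factor 2 appears because the larger dimension is only guaranteed to halve every two steps, so each size level n/2^k is charged at most twice.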
Note: I am not sure whether it exactly matches the algorithm in the code you gave; it may or may not be a different approach.
To see why it is Θ(n):
The calculation steps were given in a picture (image not preserved).
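Looking at the algorithm implementation:
1) Start with either 1a) or 1b).
1a) Set up the left half, the divider and the right half.
1b) Set up the top half, the divider and the bottom half.
2) Find the global maximum on the divider. [Θ(n)]
3) Look at the values of its neighbours, and record the largest node ever visited as the bestSeen node. [Θ(1)]
# update the best we've seen so far based on this new maximum
if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
    bestSeen = neighbor
    if not trace is None: trace.setBestSeen(bestSeen)
4) Check whether the global maximum is at least as large as bestSeen and as its neighbours. [Θ(1)]
// Step 4 is the main key to why this algorithm works
# return when we know we've found a peak
if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
    if not trace is None: trace.foundPeak(bestLoc)
    return bestLoc
5) If 4) is true, return the global maximum as the 2-D peak.
Otherwise, if this round was 1a), choose the half that contains bestSeen and go back to step 1b);
otherwise, choose the half that contains bestSeen and go back to step 1a).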
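The steps above can be condensed into a self-contained sketch that works on a plain list of lists instead of the course's problem object. The name find_peak_linear and the window variables r0, r1, c0, c1 are assumptions made for this sketch; it follows the same divider / bestSeen logic but is not the course staff's code.

```python
def find_peak_linear(a):
    """Return (row, col) of a peak of the rectangular grid `a`, or None if empty."""
    if not a or not a[0]:
        return None
    val = lambda p: a[p[0]][p[1]]
    r0, r1 = 0, len(a)      # current window: rows r0 <= r < r1
    c0, c1 = 0, len(a[0])   # current window: columns c0 <= c < c1
    best_seen = None
    row_split = True
    while r0 < r1 and c0 < c1:
        # Step 1: pick the dividing row (1b) or dividing column (1a).
        if row_split:
            mid = (r0 + r1) // 2
            divider = [(mid, c) for c in range(c0, c1)]
        else:
            mid = (c0 + c1) // 2
            divider = [(r, mid) for r in range(r0, r1)]
        # Step 2: maximum on the divider.
        best_loc = max(divider, key=val)
        # Step 3: best neighbour inside the window, then update best_seen.
        r, c = best_loc
        neighbor = best_loc
        for p in ((r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)):
            if r0 <= p[0] < r1 and c0 <= p[1] < c1 and val(p) > val(neighbor):
                neighbor = p
        if best_seen is None or val(neighbor) > val(best_seen):
            best_seen = neighbor
        # Step 4: the divider maximum beats its neighbours and best_seen.
        if neighbor == best_loc and val(best_loc) >= val(best_seen):
            return best_loc
        # Step 5: keep the half containing best_seen and flip the split direction.
        if row_split:
            if best_seen[0] < mid:
                r1 = mid
            else:
                r0 = mid + 1
        else:
            if best_seen[1] < mid:
                c1 = mid
            else:
                c0 = mid + 1
        row_split = not row_split
    return None  # defensive; not expected for non-empty input


grid = [[1, 2, 3],
        [4, 9, 5],
        [6, 7, 8]]
print(find_peak_linear(grid))  # (1, 1)
```

At step 5 best_seen is strictly larger than every value on the divider, so it always lies in exactly one of the two halves and the window never becomes empty before a peak is returned.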
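For an intuitive picture of why this algorithm works: it is like grabbing the side that holds the greater value and repeatedly shrinking the boundaries, until you finally arrive at the bestSeen value.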
# Visualized simulation
(Images not preserved: round 1, round 2, round 3, round 4, round 5, round 6, and the final state.)
For this 10*10 matrix, we used only 6 steps to find the 2-D peak, which makes it quite convincing that the algorithm is indeed Θ(n).
By Falcon
Here is Java code that implements @maxim1000's algorithm. The following code finds a peak in a 2D array in linear time.
import java.util.*;

class Ideone{
    public static void main (String[] args) throws java.lang.Exception{
        new Ideone().run();
    }
    int N , M ;
    void run(){
        N = 1000;
        M = 100;
        // arr is a random NxM array
        int[][] arr = randomArray();
        long start = System.currentTimeMillis();
        // for(int i=0; i<N; i++){   // TO print the array.
        //     System.out.println(Arrays.toString(arr[i]));
        // }
        System.out.println(findPeakLinearTime(arr));
        long end = System.currentTimeMillis();
        System.out.println("time taken : " + (end-start));
    }
    int findPeakLinearTime(int[][] arr){
        int rows = arr.length;
        int cols = arr[0].length;
        return kthLinearColumn(arr, 0, cols-1, 0, rows-1);
    }
    // helper function that splits on the middle Column
    int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loCol==hiCol){
            // only one column left: take its maximum within the row range
            int max = arr[loRow][loCol];
            int foundRow = loRow;
            for(int row = loRow; row<=hiRow; row++){
                if(max < arr[row][loCol]){
                    max = arr[row][loCol];
                    foundRow = row;
                }
            }
            if(!correctPeak(arr, foundRow, loCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return max;
        }
        int midCol = (loCol+hiCol)/2;
        // start from the middle column itself, so that max is always one of its values
        int max = arr[loRow][midCol];
        for(int row=loRow; row<=hiRow; row++){
            max = Math.max(max, arr[row][midCol]);
        }
        boolean centralMax = true;
        boolean rightMax = false;
        boolean leftMax = false;

        // compare against the column to the left of the middle column
        if(midCol-1 >= 0){
            for(int row = loRow; row<=hiRow; row++){
                if(arr[row][midCol-1] > max){
                    max = arr[row][midCol-1];
                    centralMax = false;
                    leftMax = true;
                }
            }
        }

        // compare against the column to the right of the middle column
        if(midCol+1 < M){
            for(int row=loRow; row<=hiRow; row++){
                if(arr[row][midCol+1] > max){
                    max = arr[row][midCol+1];
                    centralMax = false;
                    leftMax = false;
                    rightMax = true;
                }
            }
        }

        if(centralMax) return max;
        if(rightMax)  return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow);
        if(leftMax)   return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow);
        throw new RuntimeException("INCORRECT CODE");
    }
    // helper function that splits on the middle Row
    int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loRow==hiRow){
            // only one row left: take its maximum within the column range
            // (the array is indexed as arr[row][col])
            int ans = arr[loRow][loCol];
            int foundCol = loCol;
            for(int col=loCol; col<=hiCol; col++){
                if(arr[loRow][col] > ans){
                    ans = arr[loRow][col];
                    foundCol = col;
                }
            }
            if(!correctPeak(arr, loRow, foundCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return ans;
        }
        boolean centralMax = true;
        boolean upperMax = false;
        boolean lowerMax = false;

        int midRow = (loRow+hiRow)/2;
        int max = arr[midRow][loCol];

        for(int col=loCol; col<=hiCol; col++){
            max = Math.max(max, arr[midRow][col]);
        }

        // compare against the row above the middle row
        if(midRow-1>=0){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow-1][col] > max){
                    max = arr[midRow-1][col];
                    upperMax = true;
                    centralMax = false;
                }
            }
        }

        // compare against the row below the middle row
        if(midRow+1<N){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow+1][col] > max){
                    max = arr[midRow+1][col];
                    lowerMax = true;
                    centralMax = false;
                    upperMax = false;
                }
            }
        }

        if(centralMax) return max;
        if(lowerMax)   return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow);
        if(upperMax)   return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1);
        throw new RuntimeException("Incorrect code");
    }
    int[][] randomArray(){
        int[][] arr = new int[N][M];
        for(int i=0; i<N; i++)
            for(int j=0; j<M; j++)
                arr[i][j] = (int)(Math.random()*1000000000);
        return arr;
    }
    // Function that checks whether arr[row][col] is a peak or not
    boolean correctPeak(int[][] arr, int row, int col){
        if(row-1>=0 && arr[row-1][col]>arr[row][col]) return false;
        if(row+1<N && arr[row+1][col]>arr[row][col]) return false;
        if(col-1>=0 && arr[row][col-1]>arr[row][col]) return false;
        if(col+1<M && arr[row][col+1]>arr[row][col]) return false;
        return true;
    }
}