根據這道題總結一下快速排序和堆排序,再根據這兩種方法寫這道題。
給定整數數組 nums
和整數 k
,請返回數組中第 k
個最大的元素。
請注意,你需要找的是數組排序后的第 k
個最大的元素,而不是第 k
個不同的元素。
你必須設計并實現時間復雜度為 O(n)
的算法解決此問題。
示例 1:
輸入: [3,2,1,5,6,4]
, k = 2
輸出: 5
示例 2:
輸入: [3,2,3,1,2,4,5,5,6]
, k = 4
輸出: 4
提示:
1 <= k <= nums.length <= 105
-104 <= nums[i] <= 104
我們首先給出快速排序的代碼,快速排序的思路是先選取一個基準值,然后把小于基準值的放到基準值左邊,把大于基準值的放到基準值右邊,這樣就會變成三部分(基準值左邊部分、基準值、基準值右邊部分),對基準值左右再遞歸進行這個步驟。代碼分三部分:快速排序輔助分區部分、排序部分和主函數,分區部分就是把比基準值小的放左邊,比基準值大的放右邊,然后把基準值放中間,排序部分就是遞歸排序。
#include <iostream>
#include <vector>
#include <utility> // for std::swap// 快速排序的輔助函數,進行分區
int partition(std::vector<int> &nums, int low, int high) {// 選擇最左側的元素作為基準值(pivot)int pivot = nums[low];int i = low + 1; // i指針用來記錄比基準值小的區域的最后一個元素的位置int j = high; // j指針用來記錄比基準值大的區域的第一個元素的位置// 循環進行分區操作while(true) {// 從左向右找,找到大于等于基準值的元素while (nums[i] < pivot) {i++;}// 從右向左找,找到小于等于基準值的元素while (nums[j] > pivot) {j--;}if (i < j) {std::swap(nums[i], nums[j]);} else {// 完成分區,左邊全是小于等于基準值,右邊全是大于等于基準值break;}}// 交換基準值到分區的中間std::swap(nums[low], nums[j]);// 返回基準值的最終位置return i;
}// 快速排序的遞歸函數
void quickSort(std::vector<int> &nums, int low, int high) {if (low < high) {// 分區操作int pivotIndex = partition(nums, low, high);// 對基準值左邊的子序列進行快速排序quickSort(nums, low, pivotIndex - 1);// 對基準值右邊的子序列進行快速排序quickSort(nums, pivotIndex + 1, high);}
}int main() {std::vector<int> nums = {10, 7, 8, 9, 1, 5};int n = nums.size();quickSort(nums, 0, n - 1);for (int num : nums) {std::cout << num << " ";}return 0;
}
運行結果(每一步分區的過程)為:
6 7 8 9 1 5 3 3 6 1 10
6 1 3 3 1 5 6 9 8 7 10
5 1 3 3 1 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 7 8 9 10
1 1 3 3 5 6 6 7 8 9 10
1 1 3 3 5 6 6 7 8 9 10
快速排序的時間復雜度是O(nlogn)
。
基于快速排序可以寫出做這道題的快速選擇方法的代碼,與快速排序一樣,需要先分區,之后確定了基準值最終所在的位置,然后不需要進行排序操作,只需要知道第k大的元素是在基準值左邊還是右邊,然后在那個分區找就可以了,也是遞歸來查找,這個就是在快排的過程中直接找到了,所以不需要進行完整的快排,因此復雜度變低:
#include <iostream>
#include <vector>
#include <utility> // for std::swap// 快速排序的輔助函數,進行分區
int partition(std::vector<int> &nums, int low, int high) {// 選擇最左側的元素作為基準值(pivot)int pivot = nums[low];int i = low + 1; // i指針用來記錄比基準值小的區域的最后一個元素的位置int j = high; // j指針用來記錄比基準值大的區域的第一個元素的位置// 循環進行分區操作while(true) {// 從左向右找,找到大于等于基準值的元素while (nums[i] < pivot) {i++;}// 從右向左找,找到小于等于基準值的元素while (nums[j] > pivot) {j--;}if (i < j) {std::swap(nums[i], nums[j]);} else {// 完成分區,左邊全是小于等于基準值,右邊全是大于等于基準值break;}}// 交換基準值到分區的中間std::swap(nums[low], nums[j]);// 返回基準值的最終位置return j;
}// 快速排序的遞歸函數
int quickSelect(std::vector<int> &nums, int low, int high, int kIndex) {if (low == high) {// 當子數組只有一個元素時,返回該元素return nums[low];}int pivotIndex = partition(nums, low, high);if (kIndex <= pivotIndex) {// 第k大的元素索引在左側子數組中return quickSelect(nums, low, pivotIndex, kIndex);} else {// 第k大的元素索引在右側子數組中return quickSelect(nums, pivotIndex + 1, high, kIndex);}
}int main() {std::vector<int> nums = {10, 5, 3, 2, 1, 6, 8, 7};int n = nums.size();int k = 3;// 第k大的元素的索引是k-1int kIndex = k - 1;int ans = quickSelect(nums, 0, n - 1, n - 1 - kIndex);std::cout << "The ans is " << ans << std::endl;return 0;
}
注意,當求第k
大的元素時,傳入的是索引k-1
,當求第k
小的元素(第n-k+1
大)時,傳入索引n-k
(即n-1-kIndex
)。這個方法時間復雜度是O(n)
。
下面來總結一下堆排序和這道題,我們給出堆排序的代碼:
#include <iostream>
#include <vector>
#include <algorithm> // for std::swap// 自上向下調整堆,保證堆的性質
void heapify(std::vector<int> &nums, int n, int i) {int largest = i; // 初始時假設當前節點為最大值int left = 2 * i + 1; // 左子節點int right = 2 * i + 2; // 右子節點// 如果左子節點存在且大于當前節點,更新最大值節點if (left < n && nums[left] > nums[largest]) {largest = left;}// 如果右子節點存在且大于當前節點,更新最大值節點if (right < n && nums[right] > nums[largest]) {largest = right;}// 如果最大值節點發生了變化,交換當前節點和最大值節點的值,并繼續調整if (largest != i) {std::swap(nums[i], nums[largest]);heapify(nums, n, largest);}
}// 堆排序
void heapSort(std::vector<int> &nums) {int n = nums.size();// 從最后一個非葉子節點開始建堆,即從 (n/2 - 1) 節點開始for (int i = n / 2 - 1; i >= 0; i--) {heapify(nums, n, i);}// 從最后一個元素開始,交換元素并進行調整堆操作for (int i = n - 1; i > 0; i--) {std::swap(nums[0], nums[i]); // 將當前堆的最大值放到數組末尾heapify(nums, i, 0); // 調整堆,新的堆大小為 i}
}int main() {std::vector<int> nums = {16, 10, 8, 7, 2, 3, 4, 1, 9, 14};heapSort(nums);std::cout << "Sorted array: ";for (int num : nums) {std::cout << num << " ";}return 0;
}
堆排序有三個重要部分:維護堆的性質,建堆,排序。以大根堆為例,這是一顆完全二叉樹,父節點的值大于子節點的值,下標為i
的節點的父節點下標是(i - 1) / 2
(整數除法),下標為i
的節點的左孩子下標是i * 2 + 1
,右孩子下標是i * 2 + 2
,因此,假如有n
個元素,那么堆的最后一個非葉子節點的下標是n / 2 - 1
。
- 維護堆的性質,即為保證父節點值大于子節點值,從上而下調整,比如當前
i
節點不滿足這個性質,那么交換i
節點和它的左右孩子中最大的那個,然后再判斷子節點那里是否滿足堆的性質(之所以需要這樣是因為如果進行了交換,那么子節點那里可能會發生變化,比如3 6 5 2 4
這個情況,首先3
和6
進行了交換,變成了6 3 5 2 4
,那么3 2 4
那個部分(之前是6 2 4
)就需要再次進行交換)。 - 建堆,即從最后一個非葉子節點開始,自下而上維護堆的性質,直到根節點。
- 堆排序,將當前堆的最大值放到數組末尾,然后把它排除出去,再從根向下進行堆的維護,新的堆的大小為
n-1
,重復這個過程,直到只剩一個元素。
運行結果為:
create heap
16 14 8 9 10 3 4 1 7 2
sort heap for 9 nums 14 10 8 9 2 3 4 1 7 16
sort heap for 8 nums 10 9 8 7 2 3 4 1 14 16
sort heap for 7 nums 9 7 8 1 2 3 4 10 14 16
sort heap for 6 nums 8 7 4 1 2 3 9 10 14 16
sort heap for 5 nums 7 3 4 1 2 8 9 10 14 16
sort heap for 4 nums 4 3 2 1 7 8 9 10 14 16
sort heap for 3 nums 3 1 2 4 7 8 9 10 14 16
sort heap for 2 nums 2 1 3 4 7 8 9 10 14 16
sort heap for 1 nums 1 2 3 4 7 8 9 10 14 16
sorted heap
1 2 3 4 7 8 9 10 14 16
可以看到16, 10, 8, 7, 2, 3, 4, 1, 9, 14
經過建堆過程(自最后一個非葉子節點向上維護堆),變成了16 14 8 9 10 3 4 1 7 2
,然后需要進行堆排序,將16
和14
交換,然后不管16
了,這個時候它是最后一個元素,再從根向下維護堆,得到14 10 8 9 2 3 4 1 7 16
,然后再將14
和7
交換,進行相同的步驟,最后排序成功。
有了堆排序的基礎,我們利用堆排序解決數組中的第K
個最大元素的問題,事實上,在堆排序取最大值的過程中,已經體現出來了,在第一次取16
,這就是第1
大的元素,第二次取14
就是第2
大的元素,那么我們想得到第k
大元素的值,只需要設置堆排序的停止條件為i > n - k
,然后這時候的nums[0]
(即根節點值)為第k
大的元素。如果我們想得到第’k’小的元素,那么就取第n-k+1
大的元素。
詳細代碼如下:
#include <iostream>
#include <vector>
#include <algorithm> // for std::swap// 自上向下調整堆,保證堆的性質
void heapify(std::vector<int> &nums, int n, int i) {int largest = i; // 初始時假設當前節點為最大值int left = 2 * i + 1; // 左子節點int right = 2 * i + 2; // 右子節點// 如果左子節點存在且大于當前節點,更新最大值節點if (left < n && nums[left] > nums[largest]) {largest = left;}// 如果右子節點存在且大于當前節點,更新最大值節點if (right < n && nums[right] > nums[largest]) {largest = right;}// 如果最大值節點發生了變化,交換當前節點和最大值節點的值,并繼續調整if (largest != i) {std::swap(nums[i], nums[largest]);heapify(nums, n, largest);}
}// 堆排序取數
int heapSelect(std::vector<int> &nums, int k) {int n = nums.size();// 從最后一個非葉子節點開始建堆,即從 (n/2 - 1) 節點開始for (int i = n / 2 - 1; i >= 0; i--) {heapify(nums, n, i);}std::cout << "create heap" << std::endl;for (int num : nums) {std::cout << num << " ";}std::cout << "\n";// 從最后一個元素開始,交換元素并進行調整堆操作for (int i = n - 1; i > n - k; i--) {std::swap(nums[0], nums[i]); // 將當前堆的最大值放到數組末尾heapify(nums, i, 0); // 調整堆,新的堆大小為 istd::cout << "sort heap for " << i << " nums" << " ";for (int num : nums) {std::cout << num << " ";}std::cout << "\n";}std::cout << "sorted heap" << std::endl; for (int num : nums) {std::cout << num << " ";}return nums[0];
}int main() {std::vector<int> nums = {16, 10, 8, 7, 2, 3, 4, 1, 9, 14};int n = nums.size();int k = 4;int ans = heapSelect(nums, n - k + 1);std::cout << "ans=" << ans << std::endl;return 0;
}
運行結果:
create heap
16 14 8 9 10 3 4 1 7 2
sort heap for 9 nums 14 10 8 9 2 3 4 1 7 16
sort heap for 8 nums 10 9 8 7 2 3 4 1 14 16
sort heap for 7 nums 9 7 8 1 2 3 4 10 14 16
sort heap for 6 nums 8 7 4 1 2 3 9 10 14 16
sort heap for 5 nums 7 3 4 1 2 8 9 10 14 16
sort heap for 4 nums 4 3 2 1 7 8 9 10 14 16
sorted heap
4 3 2 1 7 8 9 10 14 16 ans=4
時間復雜度是O(nlogn)
,建堆的復雜度是O(n)
,刪除堆頂元素的復雜度是O(klogn)
,所以總共的時間復雜度是O(n+klogn)=O(nlogn)
。