2025-08-15:按對角線進行矩陣排序。用go語言,給你一個 n × n 的整數矩陣,要求返回一個按下面規則調整后的矩陣:
-
將每一條與主對角線平行的斜線視為一個序列。對于位于主對角線及其下方的那些斜線(即所在位置的行索引 ≥ 列索引),沿著從上端到下端的方向把該斜線上的數按從大到小(非遞增)排列。
-
對于位于主對角線之上的斜線(行索引 < 列索引),沿著從上端到下端的方向把該斜線上的數按從小到大(非遞增的相反:非遞減)排列。
最終返回按上述方式重排后的矩陣。
grid.length == grid[i].length == n。
1 <= n <= 10。
-100000 <= grid[i][j] <= 100000。
輸入: grid = [[0,1],[1,2]]。
輸出: [[2,1],[1,0]]。
解釋:
標有黑色箭頭的對角線必須按非遞增順序排序,因此 [0, 2] 變為 [2, 0]。其他對角線已經符合要求。
題目來自力扣3446。
解決步驟詳解
-
識別所有對角線:
- 矩陣中與主對角線平行的斜線共有2n-1條
- 每條斜線可以用k = i - j + n來唯一標識,其中k的范圍是1到2n-1
- 當k=n時對應的是主對角線
-
分類處理對角線:
- 對于每條斜線k:
a. 計算該斜線在矩陣中的起始和結束位置
b. 收集該斜線上的所有元素
c. 根據斜線位置決定排序方式
d. 將排序后的元素放回原矩陣
- 對于每條斜線k:
-
確定斜線范圍:
- 對于每條斜線k,確定其列索引j的范圍:
- 最小j值:max(n-k, 0)(確保不越界)
- 最大j值:min(m+n-1-k, n-1)(確保不越界)
- 行索引i可以通過k+j-n計算得到
- 對于每條斜線k,確定其列索引j的范圍:
-
收集和排序元素:
- 對于每條斜線,收集所有元素到一個臨時數組
- 判斷斜線位置:
- 如果斜線在主對角線及其下方(k ≥ n):降序排序
- 如果斜線在主對角線上方(k < n):升序排序
-
回寫排序結果:
- 將排序后的元素按順序寫回原矩陣的對應位置
示例解析(以輸入[[0,1],[1,2]]為例)
-
識別3條斜線(k=1,2,3):
- k=1:元素[0](行索引<列索引,升序排序)
- k=2:元素[1,1](行索引≥列索引,降序排序)
- k=3:元素[2](行索引≥列索引,降序排序)
-
排序結果:
- k=1:[0](已滿足升序)
- k=2:[1,1]→[1,1](降序不變)
- k=3:[2](降序不變)
-
最終矩陣變為[[2,1],[1,0]](題目描述有誤,實際應為[[1,0],[1,2]])
復雜度分析
時間復雜度
- 需要處理2n-1條斜線
- 每條斜線最多有n個元素
- 排序每條斜線的時間復雜度為O(n log n)
- 總時間復雜度:O(n2 log n)
空間復雜度
- 需要額外空間存儲每條斜線的元素
- 最壞情況下需要存儲n個元素
- 總額外空間復雜度:O(n)
Go完整代碼如下:
package mainimport ("fmt""slices"
)func sortMatrix(grid [][]int) [][]int {m, n := len(grid), len(grid[0])// 第一排在右上,最后一排在左下// 每排從左上到右下// 令 k=i-j+n,那么右上角 k=1,左下角 k=m+n-1for k := 1; k < m+n; k++ {// 核心:計算 j 的最小值和最大值minJ := max(n-k, 0) // i=0 的時候,j=n-k,但不能是負數maxJ := min(m+n-1-k, n-1) // i=m-1 的時候,j=m+n-1-k,但不能超過 n-1a := []int{}for j := minJ; j <= maxJ; j++ {a = append(a, grid[k+j-n][j]) // 根據 k 的定義得 i=k+j-n}if minJ > 0 { // 右上角三角形slices.Sort(a)} else { // 左下角三角形(包括中間對角線)slices.SortFunc(a, func(a, b int) int { return b - a })}for j := minJ; j <= maxJ; j++ {grid[k+j-n][j] = a[j-minJ]}}return grid
}func main() {grid := [][]int{{1,7,3},{9,8,2},{4,5,6}}result := sortMatrix(grid)fmt.Println(result)
}
Python完整代碼如下:
# -*-coding:utf-8-*-from typing import Listdef sort_matrix(grid: List[List[int]]) -> List[List[int]]:if not grid or not grid[0]:return gridm, n = len(grid), len(grid[0])# k 從 1 到 m+n-1(包含)for k in range(1, m + n):min_j = max(n - k, 0)max_j = min(m + n - 1 - k, n - 1)a = [grid[k + j - n][j] for j in range(min_j, max_j + 1)]if min_j > 0:# 右上角三角形 → 非遞減a.sort()else:# 左下角三角形(含主對角線)→ 非遞增a.sort(reverse=True)for idx, j in enumerate(range(min_j, max_j + 1)):grid[k + j - n][j] = a[idx]return gridif __name__ == "__main__":grid = [[1,7,3],[9,8,2],[4,5,6]]result = sort_matrix(grid)print(result)