跳到主要内容

前 K 个高频元素(347)

题目描述

给定一个非空的整数数组,返回其中出现频率前 k高的元素。

样例

Input: nums = [1,1,1,2,2,3], k = 2
Output: [1,2]

题解

哈希表统计频率 + 排序筛选出前 k 高的元素

时间复杂度主要决定于排序的复杂度,$O(mlog_m)$ ,m 为非重复元素个数

排序部分可以通过堆排(collections.Counter.most_common源码实现方式),也可以通过快排找前K个最大元素。当数据的范围不是很大的时候,使用 桶排 也是一个比较合适的方法,$O(p)$ p的大小由频率决定。

  • 堆排
  • 快排变形
  • 桶排

Python 样例

通过 collections 接口调用,most_common 的源码是使用堆排实现

import collections
class Solution:
def topKFrequent(self, nums: List[int], k: int) -> List[int]:
return [ i for i, j in collections.Counter(nums).most_common(k)]

通过 heapq 实现堆排

import heapq
class Solution:
def topKFrequent(self, nums: List[int], k: int) -> List[int]:
counts = {}
for num in nums:
counts[num] = counts[num] + 1 if num in counts else 0

heap = []
for key, value in counts.items():
heapq.heappush(heap, (-value, key))
return [heapq.heappop(heap)[1] for _ in range(k)]

通过 变形快排 实现前K个元素

from collections import defaultdict
def quick_sort(nums, l, r, k):
if l >= r:
return
low, high = l + 1, r
pivot = nums[l]
while True:
while low < r and nums[low] <= pivot:
low += 1
while high > l and nums[high] >= pivot:
high -= 1
if low >= high:
break
nums[low], nums[high] = nums[high], nums[low]
nums[l], nums[high] = nums[high], nums[l]
# high 是最终的位置
if high == k:
return
elif high < k:
quick_sort(nums, high + 1, r, k)
else:
quick_sort(nums, l, high - 1, k)

class Solution:
def topKFrequent(self, nums: List[int], k: int) -> List[int]:
counts = defaultdict(int) # 有默认值的 dict
for num in nums:
counts[num] += 1 counts[num] + 1

countsList = [(-v, k) for k, v in counts.items()]
quick_sort(countsList, 0, len(countsList) - 1, k)
return [num[1] for num in countsList[:k]]

桶排 实现

class Solution:
def topKFrequent(self, nums: List[int], k: int) -> List[int]:
counts = {}
max_counts = 0
for num in nums:
counts[num] = counts[num] + 1 if num in counts else 1
max_counts = max(max_counts, counts[num])

# bucket 排
bucket = [[] for _ in range(max_counts + 1)]
for key, value in counts.items():
bucket[value].append(key)

# 结果
ans = []
for i in range(max_counts, -1, -1):
for count in bucket[i]:
ans.append(count)
if len(ans) == k:
return ans