每日一题:最大堆的实现

时间:2022-07-23
本文章向大家介绍每日一题:最大堆的实现,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

很久没有做题目了,今天学习下最大堆和最小堆这种数据结构。主要涉及知识点: 1、如何构建一个最大(小)堆 2、如何获取最大(小)元素 3、实现获取无序数组中第k大的数字,对应leetcode:https://leetcode.com/problems/kth-largest-element-in-an-array/

coding…

文中均以最大堆为例,最小堆的原理类似

什么是最大堆

定义很简单: 1、它是一棵二叉树,并且是一棵完成二叉树 2、各个子树的根结点都比孩子结点要大,所以整棵树的根结点即为所有数中最大的那个数

堆的构建

这里我们采用数组来实现一个最大堆。为什么会选用数组呢,主要原因如下图所示:

用数组构建最大堆的构建两种构建方式,一种是循环插入,即一个一个插入,每次插入后的结点都保持最大堆的形式;而另外一种则是先把数据按数据顺序插入,然后从第一个叶子结点开始往上调整。

我们先来看第一种构建方式,分如下几步: 1、将当前插入的元素直接放至数组的末尾 2、从插入值开始往上找,如果父结点值比当前数值小,则交换,直到找到比他大的或者没有结点可找之后结束

实现代码如下:

class MaxHeap():
    """
    最大堆
    """
    def __init__(self):
        self._count = 0
        self._data = [0]

    def shift_up(self, index):
        while index > 1 and self._data[index // 2] < self._data[index]:
            self._data[index // 2], self._data[index] = self._data[index], self._data[index // 2]
            index //= 2

    def insert(self, value):
        self._data.append(value)
        self._count += 1
        self.shift_up(self._count)

    def print(self):
        print(self._data[1:])

if __name__ == "__main__":
    heap = MaxHeap()
    nums = [10, 20, 9, 4, 5, 30]
    for num in nums:
        heap.insert(num)

    heap.print()  # [30, 10, 20, 4, 5, 9]

接着我们来看第二种构建方式,实现步骤如下: 1、直接将整个数据填入数组中 2、从第一个非叶结点开始,向上走,每次与自己的左、右结点比较,调整位置,走到调整到根结点为止

实现代码如下:

class MaxHeap():
    """
    最大堆
    """

    def __init__(self, nums):
        self._data = [0]
        self._count = 0

        for num in nums:
            self._data.append(num)
            self._count += 1

        for i in range(self._count // 2, 0, -1):
            self.shift_down(i)

    def shift_down(self, index):
        while index * 2 <= self._count:
            m = index * 2
            if m + 1 <= self._count and self._data[m + 1] > self._data[m]:
                m = index * 2 + 1

            if self._data[index] > self._data[m]:
                break

            self._data[index], self._data[m] = self._data[m], self._data[index]
            index = m

    def print(self):
        print(self._data[1:])


if __name__ == "__main__":
    nums = [10, 20, 9, 4, 5, 30]
    heap = MaxHeap(nums)
    heap.print()

两种实现方式复杂度不一样,第一种复杂度为 nO(logN),第二种为 O(n),具体分析原因这里不做解释,特别是第二种,需要通过一定的数学归纳方法来得到。可以暂时记忆为主。可以看到第二种的时间复杂度是要优于第一种的。

获取最大值

获取最大值,同样分两个步骤: 1、取出根结点 2、将最后一个结点与根结点交换,然后从根结点开始,与左右结点比较,直到符合条件

实现代码如下:

class MaxHeap():
    """
    最大堆
    """
    def __init__(self):
        self._count = 0
        self._data = [0]

    def shift_up(self, index):
        while index > 1 and self._data[index // 2] < self._data[index]:
            self._data[index // 2], self._data[index] = self._data[index], self._data[index // 2]
            index //= 2

    def insert(self, value):
        self._data.append(value)
        self._count += 1
        self.shift_up(self._count)

    def print(self):
        print(self._data[1:])

    def extract_max(self):
        cur_max = self._data[1]
        self._data[1], self._data[self._count] = self._data[self._count], self._data[1]
        self._count -= 1
        self.shift_down(1)
        return cur_max

    def shift_down(self, index):
        while index * 2 <= self._count:
            m = index * 2
            if m + 1 <= self._count and self._data[m + 1] > self._data[m]:
                m = index * 2 + 1 

            if self._data[index] > self._data[m]:
                break

            self._data[index], self._data[m] = self._data[m], self._data[index]
            index = m

if __name__ == "__main__":
    heap = MaxHeap()
    nums = [10, 20, 9, 4, 5, 30]
    for num in nums:
        heap.insert(num)

    heap.print()

    print(heap.extract_max())

实践

像我们平常遇到的查找第 n 项最大值就可以用最大堆来实现,如果用上面说的第二种构建方式,时间复杂度可优化为 O(n)。

这里直接上代码:

# https://leetcode.com/problems/kth-largest-element-in-an-array/

class MaxHeap():
    """
    最大堆
    """
    def __init__(self):
        self._count = 0
        self._data = [0]

    def insert_arr(self, nums):
        for num in nums:
            self._data.append(num)
            self._count += 1

        for i in range(self._count // 2, 0, -1):
            self.shift_down(i)


    def shift_up(self, index):
        while index > 1 and self._data[index // 2] < self._data[index]:
            self._data[index // 2], self._data[index] = self._data[index], self._data[index // 2]
            index //= 2

    def shift_down(self, index):
        while index * 2 <= self._count:
            m = index * 2
            if m + 1 <= self._count and self._data[m + 1] > self._data[m]:
                m = index * 2 + 1 

            if self._data[index] > self._data[m]:
                break

            self._data[index], self._data[m] = self._data[m], self._data[index]
            index = m

    def insert(self, value):
        self._data.append(value)
        self._count += 1
        self.shift_up(self._count)


    def extract_max(self):
        cur_max = self._data[1]
        self._data[1], self._data[self._count] = self._data[self._count], self._data[1]
        self._count -= 1
        self.shift_down(1)
        return cur_max

class Solution(object):
    """
    实现思路:使用堆
    """
    def findKthLargest(self, nums, k):
        """
        :type nums: List[int]
        :type k: int
        :rtype: int
        """
        max_head = MaxHeap()
        max_head.insert_arr(nums)

        result = None
        for i in range(k):
            result = max_head.extract_max()
        return result

if __name__ == "__main__":
    s = Solution()
    print(s.findKthLargest([3,2,3,1,2,4,5,5,6], 4))