杰瑞科技汇

Python模块itertools有哪些核心功能?

什么是 itertools

itertools 是 Python 的一个内置模块,它提供了一系列用于创建和操作迭代器的函数,迭代器是一种对象,它可以逐个地产生值,而不会一次性在内存中生成所有值,这使得 itertools 在处理大型数据集或无限序列时特别高效,因为它惰性求值,只在需要时才计算下一个值。

Python模块itertools有哪些核心功能?-图1
(图片来源网络,侵删)

你可以把 itertools 想象成一个“乐高积木箱”,里面装满了各种可以拼接、组合、变换迭代器的工具,掌握它,能让你的代码更简洁、更高效、更具 Pythonic 风格。


itertools 的核心功能

itertools 的函数大致可以分为三类:

  1. 无限迭代器: 可以无限地生成值。
  2. 有限迭代器: 基于一个现有的可迭代对象,生成有限的新序列。
  3. 组合迭代器: 将多个可迭代对象以特定方式组合。

无限迭代器

这类函数会一直运行下去,通常需要配合 break 或其他限制条件来使用。

itertools.count(start=0, step=1)

start 开始,以 step 为步长无限递增。

Python模块itertools有哪些核心功能?-图2
(图片来源网络,侵删)
import itertools
# 从 10 开始,每次加 5
counter = itertools.count(start=10, step=5)
# 使用 islice 来限制输出,否则会无限循环
for i in itertools.islice(counter, 5):
    print(i)  # 输出: 10 15 20 25 30

itertools.cycle(iterable)

将传入的可迭代对象的元素无限循环。

import itertools
# 循环打印 'ABCD'
letters = itertools.cycle('ABCD')
for i in itertools.islice(letters, 8):
    print(i)  # 输出: A B C D A B C D

itertools.repeat(object, times=None)

无限重复一个对象,如果指定 times,则重复指定次数。

import itertools
# 无限重复 5
repeater = itertools.repeat(5)
print(next(repeater))  # 输出: 5
print(next(repeater))  # 输出: 5
# 重复 'Hi' 3 次
hi_repeater = itertools.repeat('Hi', 3)
for item in hi_repeater:
    print(item)  # 输出: Hi Hi Hi

有限迭代器

这类函数基于一个已有的可迭代对象,生成新的、有限的序列。

itertools.chain(*iterables)

将多个可迭代对象“连接”成一个,按顺序逐个产出它们的元素。

Python模块itertools有哪些核心功能?-图3
(图片来源网络,侵删)
import itertools
a = [1, 2, 3]
b = 'ABC'
c = (7, 8, 9)
# 将 a, b, c 连接成一个迭代器
chained = itertools.chain(a, b, c)
for item in chained:
    print(item) 
# 输出: 1 2 3 A B C 7 8 9

itertools.compress(data, selectors)

根据 selectors 中的“真/假”值来筛选 data,只有当 selectors 中对应位置的元素为真时,才会产出 data 中对应的元素。

import itertools
data = 'ABCDEFGH'
selectors = [1, 0, 1, 0, 1, 0, 1, 0] # 1表示保留,0表示舍弃
# 只保留 data 中 selectors 为 1 的位置
filtered = itertools.compress(data, selectors)
for item in filtered:
    print(item) # 输出: A C E G

itertools.groupby(iterable, key=None)

这是一个非常强大的函数,它将连续的、相同的元素分组,返回一个迭代器,产出 (key, group_iterator) 对。

重要提示: groupby 只对连续的、相同的键进行分组,如果需要先排序,请务必先对可迭代对象进行排序。

import itertools
# 示例1:按奇偶性分组
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# 按奇偶性分组的 key 函数
def is_even(n):
    return n % 2 == 0
# 先排序!虽然这里已经是排序好的,但这是一个好习惯
sorted_numbers = sorted(numbers, key=is_even)
for key, group in itertools.groupby(sorted_numbers, key=is_even):
    print(f"Key: {key}, Group: {list(group)}")
# 输出:
# Key: False, Group: [1, 3, 5, 7, 9]
# Key: True, Group: [2, 4, 6, 8, 10]
# 示例2:按首字母分组
data = ['apple', 'ant', 'banana', 'bat', 'cat']
sorted_data = sorted(data) # 必须先按首字母排序
for key, group in itertools.groupby(sorted_data):
    print(f"Key: {key}, Group: {list(group)}")
# 输出:
# Key: apple, Group: ['apple', 'ant']
# Key: banana, Group: ['banana', 'bat']
# Key: cat, Group: ['cat']

itertools.islice(iterable, stop)itertools.islice(iterable, start, stop[, step])

类似于 list 的切片操作,但它作用于任何可迭代对象,返回一个迭代器,而不是一个列表,这对于获取迭代器的一部分而不将其全部加载到内存中非常有用。

import itertools
# 创建一个无限迭代器
count = itertools.count()
# 获取从第 2 个到第 9 个(不包括 9),步长为 2 的元素
sliced = itertools.islice(count, 2, 9, 2)
for item in sliced:
    print(item) # 输出: 2 4 6 8

itertools.starmap(function, iterable)

map 函数的“兄弟”。map 接受 function, arg1, arg2, ...,而 starmap 接受 function 和一个可迭代对象,该可迭代对象的元素应该是可解包的(如元组或列表)。

import itertools
# 计算点的欧几里得距离
points = [(1, 1), (2, 2), (3, 3), (4, 4)]
def distance(p):
    return (p[0]**2 + p[1]**2)**0.5
# 使用 map
# map(distance, points) -> distance((1,1)), distance((2,2)), ...
# 使用 starmap
# starmap 会将每个元组解包成函数的参数
# distance((1, 1)) -> distance(1, 1)
distances = itertools.starmap(distance, points)
for d in distances:
    print(d) # 输出: 1.414... 2.828... 4.242... 5.656...

itertools.takewhile(predicate, iterable)

从可迭代对象的开头开始,产出元素,直到 predicate 函数返回 False 为止。

import itertools
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# 产出所有小于 5 的数字
taken = itertools.takewhile(lambda x: x < 5, numbers)
for item in taken:
    print(item) # 输出: 1 2 3 4

itertools.dropwhile(predicate, iterable)

takewhile 相反,它会丢弃开头的元素,直到 predicate 函数返回 False,然后产出剩下的所有元素。

import itertools
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# 丢弃所有小于 5 的数字,然后输出剩下的
dropped = itertools.dropwhile(lambda x: x < 5, numbers)
for item in dropped:
    print(item) # 输出: 5 6 7 8 9 10

组合迭代器

这类函数用于计算输入可迭代对象元素的组合、排列等。

itertools.product(*iterables, repeat=1)

计算输入可迭代对象的笛卡尔积,相当于嵌套的 for 循环。

import itertools
# 两个集合的笛卡尔积
a = [1, 2]
b = ['a', 'b']
print(list(itertools.product(a, b)))
# 输出: [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
# 掷两个骰子的所有可能结果
dice = list(itertools.product(range(1, 7), repeat=2))
print(len(dice)) # 输出: 36

itertools.permutations(iterable, r=None)

计算输入可迭代对象中所有长度为 r排列,顺序不同则视为不同的排列。

import itertools
letters = 'ABC'
# 长度为 2 的所有排列
print(list(itertools.permutations(letters, 2)))
# 输出: [('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'C'), ('C', 'A'), ('C', 'B')]

itertools.combinations(iterable, r)

计算输入可迭代对象中所有长度为 r组合,顺序不同但元素相同视为同一个组合。

import itertools
letters = 'ABC'
# 长度为 2 的所有组合
print(list(itertools.combinations(letters, 2)))
# 输出: [('A', 'B'), ('A', 'C'), ('B', 'C')]

itertools.combinations_with_replacement(iterable, r)

计算输入可迭代对象中所有长度为 r可重复组合,元素可以重复选择。

import itertools
letters = 'ABC'
# 长度为 2 的所有可重复组合
print(list(itertools.combinations_with_replacement(letters, 2)))
# 输出: [('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'B'), ('B', 'C'), ('C', 'C')]

实际应用场景

  1. 处理大型日志文件:逐行读取日志文件,而不是一次性加载到内存。

    with open('huge_log_file.log') as f:
        # 只处理包含 'ERROR' 的前 100 行
        error_lines = itertools.islice(
            itertools.takewhile(lambda line: line_count < 100, f),
            None
        )
        for line in error_lines:
            if 'ERROR' in line:
                process(line)
  2. 数据分页:从数据库查询返回一个迭代器,然后使用 islice 获取特定页的数据。

    def get_page(page_size, page_number):
        all_items = db.get_all_items() # 假设返回一个迭代器
        start = page_size * (page_number - 1)
        stop = start + page_size
        return list(itertools.islice(all_items, start, stop))
  3. 生成测试数据:使用 product 生成所有可能的输入组合进行测试。

    users = ['admin', 'guest']
    permissions = ['read', 'write', 'execute']
    for user, perm in itertools.product(users, permissions):
        print(f"Testing {user} with {perm} permission...")
  4. 统计分组:使用 groupby 对数据进行分组统计。

    sales = [('North', 100), ('South', 200), ('North', 150), ('East', 50)]
    # 先按地区排序
    sorted_sales = sorted(sales, key=lambda x: x[0])
    for region, sales_data in itertools.groupby(sorted_sales, key=lambda x: x[0]):
        total_sales = sum(sale[1] for sale in sales_data)
        print(f"Region: {region}, Total Sales: {total_sales}")

函数类别 核心函数 描述
无限迭代器 count, cycle, repeat 产生无限序列
有限迭代器 chain, compress, groupby, islice, starmap, takewhile, dropwhile 基于现有序列创建、筛选、切片或分组
组合迭代器 product, permutations, combinations, combinations_with_replacement 计算排列、组合和笛卡尔积

itertools 是 Python 标准库中的一颗明珠,它提供的工具函数不仅高效,而且能让你的代码逻辑更加清晰和优雅,当你需要处理序列、生成器或任何可迭代对象时,itertools 应该是你的首选工具之一。

分享:
扫描分享到社交APP
上一篇
下一篇