
N-wise Iteration in Python

A few hours ago I ran into a problem where I had to look ahead n elements in a list to do some calculations. My first thought: just take the current index and grab all elements up to i+n. I started writing...

for i in range(len(iterable)):
---- SNAP ----

Stop. This is awfully unpythonic. There has to be a better way! Browsing the itertools recipes, I found the pairwise function (since Python 3.10 it is also available directly as itertools.pairwise):

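from itertools import tee

def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)  # two independent iterators over the same input
    next(b, None)         # advance the second one by a single element
    return zip(a, b)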
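To see the staggering at work, here is the same recipe run on a one-pass generator instead of a list. The generator `numbers()` is a hypothetical input made up for illustration; the point is that `tee` lets both iterators read the same stream even though it can only be consumed once.

```python
from itertools import tee


def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)
    next(b, None)  # advance the second iterator by one element
    return zip(a, b)


def numbers():
    # hypothetical one-pass source: it can only be iterated once
    yield from [10, 20, 30, 40]


pairs = list(pairwise(numbers()))
print(pairs)  # [(10, 20), (20, 30), (30, 40)]
```

With a single-element input, `next(b, None)` consumes the only item from `b`, so `zip` yields nothing at all.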

Perfect! Now I just have to generalize it to n copies of the iterator, where the k-th copy is advanced k steps, so that the first sits at i, the second at i+1, and so on up to i+n-1.

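from itertools import tee

def nwise(iterable, n=1):
    iterators = tee(iterable, n)  # n independent copies of the input
    for i in range(n):
        for _ in range(i):
            next(iterators[i], None)  # shift the i-th copy i elements ahead
    return zip(*iterators)  # yields tuples of n consecutive items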
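The inner loop of `next()` calls can also be expressed with `itertools.islice`, which skips the first i elements of each copy lazily. This is a swapped-in variant, not the version above; the name `nwise_islice` is hypothetical, and the loop-based `nwise` is repeated so the two can be compared side by side.

```python
from itertools import islice, tee


def nwise(iterable, n=1):
    # the loop-based version, kept for comparison
    iterators = tee(iterable, n)
    for i in range(n):
        for _ in range(i):
            next(iterators[i], None)
    return zip(*iterators)


def nwise_islice(iterable, n=1):
    # hypothetical variant: same output, but each copy is offset
    # with islice instead of explicit next() calls
    iterators = tee(iterable, n)
    # the i-th copy skips its first i elements; the skipping happens
    # lazily, on the first next() call that zip makes
    shifted = [islice(it, i, None) for i, it in enumerate(iterators)]
    return zip(*shifted)


l = [1, 2, 3, 4, 5, 6]
print(list(nwise_islice(l, n=3)))  # [(1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)]
```

Both produce the same windows; the islice version just moves the offsetting into a lazy wrapper, which some readers may find easier to scan.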

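There we go. Even though it walks the iterable with n staggered iterators, it stays memory-efficient: tee only buffers the elements that the leading copy has already read but the trailing copy has not (roughly n of them), and zip hands out the tuples lazily instead of building a list. Lazy iterators are awesome. This gives you output like this: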
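One way to check that nothing is materialized up front is to feed nwise an infinite iterator. `itertools.count(1)` never ends, so if the function built a full list first this would hang; instead `islice` pulls just the first three windows.

```python
from itertools import count, islice, tee


def nwise(iterable, n=1):
    iterators = tee(iterable, n)
    for i in range(n):
        for _ in range(i):
            next(iterators[i], None)
    return zip(*iterators)


# count(1) is infinite: 1, 2, 3, ...
windows = nwise(count(1), n=4)
# only the windows we actually ask for are ever computed
first_three = list(islice(windows, 3))
print(first_three)  # [(1, 2, 3, 4), (2, 3, 4, 5), (3, 4, 5, 6)]
```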

In [4]: l = [1,2,3,4,5,6]

In [5]: list(nwise(l, n=3))
Out[5]: [(1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)]
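A list of length L yields L - n + 1 windows when n <= L, which matches the four tuples above (6 - 3 + 1). A short sketch of the edge cases: if n is larger than the input, one of the shifted copies is already exhausted and zip stops immediately; with the default n=1 you get 1-tuples.

```python
from itertools import tee


def nwise(iterable, n=1):
    iterators = tee(iterable, n)
    for i in range(n):
        for _ in range(i):
            next(iterators[i], None)
    return zip(*iterators)


too_long = list(nwise([1, 2, 3], n=5))      # window longer than the input
singles = list(nwise([1, 2, 3], n=1))       # n=1 (the default) gives 1-tuples
n_windows = len(list(nwise(range(10), n=4)))  # 10 - 4 + 1 windows
print(too_long, singles, n_windows)  # [] [(1,), (2,), (3,)] 7
```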

Note that the list() call is only there to consume the iterator so the result can be printed. Here's a quick one-liner that counts how many windows of a fixed length (42) in a list sum up to a certain value (1337):

sum(1 for seq in nwise(l, n=42) if sum(seq) == 1337)
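Here is the one-liner on the small list from above with numbers you can check by hand: for n=3 the window sums are 6, 9, 12 and 15, so exactly one window sums to 9. Because the one-liner re-sums every tuple, it does about len(l) * n additions; for long windows, a running total over a deque gives the same count in one pass. `count_window_sums` is a hypothetical helper, not part of the original post, and it assumes integer values (with floats the running total can drift) and n >= 1.

```python
from collections import deque
from itertools import tee


def nwise(iterable, n=1):
    iterators = tee(iterable, n)
    for i in range(n):
        for _ in range(i):
            next(iterators[i], None)
    return zip(*iterators)


def count_window_sums(iterable, n, target):
    # hypothetical helper: same count as the one-liner, but keeps a
    # running total instead of re-summing every window (assumes n >= 1)
    window = deque()
    total = 0
    hits = 0
    for x in iterable:
        window.append(x)
        total += x
        if len(window) > n:
            total -= window.popleft()  # drop the element leaving the window
        if len(window) == n and total == target:
            hits += 1
    return hits


l = [1, 2, 3, 4, 5, 6]
naive = sum(1 for seq in nwise(l, n=3) if sum(seq) == 9)
running = count_window_sums(l, 3, 9)
print(naive, running)  # 1 1
```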