Higher order functions

See Built-in functions.

The map function

map was what they used before they invented list comprehensions. map returns an iterator, which we can easily turn into a list if we want to print the resulting numbers.

tens = [10, 20, 30, 40, 50]

numbers = [ten + 1 for ten in tens]
print(numbers)
tens = [10, 20, 30, 40, 50]

it = map(lambda ten: ten + 1, tens)
numbers = list(it)
print(numbers)
[11, 21, 31, 41, 51]
[11, 21, 31, 41, 51]
tens = [10, 20, 30, 40, 50]
ones = [ 1,  2,  3,  4,  5]

numbers = [ten + one for ten, one in zip(tens, ones)]
print(numbers)
tens = [10, 20, 30, 40, 50]
ones = [ 1,  2,  3,  4,  5]

it = map(lambda ten, one: ten + one, tens, ones)
numbers = list(it)
print(numbers)
[11, 22, 33, 44, 55]
[11, 22, 33, 44, 55]
tens = [10,   20,   30,   40,   50  ]
ones = [ 1,    2,    3,    4,    5  ]
fras = [  .1,   .2,   .3,   .4,   .5]   #fractions

numbers = [ten + one + fra for ten, one, fra in zip(tens, ones, fras)]
print(numbers)
tens = [10,   20,   30,   40,   50  ]
ones = [ 1,    2,    3,    4,    5  ]
fras = [  .1,   .2,   .3,   .4,   .5]   #fractions

it = map(lambda ten, one, fra: ten + one + fra, tens, ones, fras)
numbers = list(it)
print(numbers)
[11.1, 22.2, 33.3, 44.4, 55.5]
[11.1, 22.2, 33.3, 44.4, 55.5]

itertools.starmap

Use map if you have two or more parallel lists. Use itertools.starmap if you have one list with two or more columns. The two situations are very similar.

import itertools
import operator

argumentLists = [
    [10, 1],
    [20, 2],
    [30, 3],
    [40, 4],
    [50, 5]
]

numbers = [operator.add(*argumentList) for argumentList in argumentLists]
print(numbers)

it = itertools.starmap(operator.add, argumentLists)
numbers = list(it)
print(numbers)
[11, 22, 33, 44, 55]
[11, 22, 33, 44, 55]
import itertools

listOfArgs = [
    [ 2, 5],
    [ 3, 2],
    [10, 3]
]

numbers = [pow(*args) for args in listOfArgs]
print(numbers)

it = itertools.starmap(pow, listOfArgs)
numbers = list(it)
print(numbers)
[32, 9, 1000]
[32, 9, 1000]

The filter function

filter was what they used before they invented list comprehensions containing if. The lambda function passed to filter should return True or False. See also itertools.filterfalse.

#Every name must be followed by a comma.  Without one, Python silently joins
#two adjacent string literals into a single string.
people = [
    "John Philip Sousa",
    "Johann Sebastian Bach",
    "John Paul Jones",
    "George Herbert Walker Bush",
    "Madonna",
    "Franklin Delano Roosevelt",
    "Adam Smith",
    "Charles Foster Kane"
]

#Keep only the names that consist of three or more whitespace-separated words.
peopleWithMiddleNames = [person for person in people if len(person.split()) >= 3]

for person in peopleWithMiddleNames:
    print(person)
#Every name must be followed by a comma.  Without one, Python silently joins
#two adjacent string literals into a single string.
people = [
    "John Philip Sousa",
    "Johann Sebastian Bach",
    "John Paul Jones",
    "George Herbert Walker Bush",
    "Madonna",
    "Franklin Delano Roosevelt",
    "Adam Smith",
    "Charles Foster Kane"
]

#filter returns an iterator that yields only the names for which the lambda
#returns True, i.e., the names with three or more whitespace-separated words.
peopleWithMiddleNames = filter(lambda person: len(person.split()) >= 3, people)

for person in peopleWithMiddleNames:
    print(person)
John Philip Sousa
Johann Sebastian Bach
John Paul Jones
George Herbert Walker Bush
Franklin Delano Roosevelt
Charles Foster Kane
John Philip Sousa
Johann Sebastian Bach
John Paul Jones
George Herbert Walker Bush
Franklin Delano Roosevelt
Charles Foster Kane

functools.reduce

See The functools module.

Three ways to compute the same sum

numbers = [10, 20, 30, 40, 50]

total = sum(numbers)

print(f"total = {total}")
total = 150

The orange code (the statement total = 0 before the loop) puts an initial value into total. The yellow code (the expression total + number in the for loop) changes the value of total over and over again, using each item in the original list. Write the yellow code so that the orange code’s variable (in this case total) appears both in the yellow code and to the left of the = that precedes the yellow code.

numbers = [10, 20, 30, 40, 50]

total = 0
for number in numbers:
    total = total + number   #or total += number

print(f"total = {total}")
total = 150
import functools

numbers = [10, 20, 30, 40, 50]

total = functools.reduce(lambda total, number: total + number, numbers, 0)

print(f"total = {total}")
total = 150

Two shortcuts

1. If the list is not empty, you don’t have to write the 0.

import functools

numbers = [10, 20, 30, 40, 50]

total = functools.reduce(lambda total, number: total + number, numbers)

print(f"total = {total}")
total = 150

2. The above lambda function takes two arguments and returns their sum. The Python Standard Library has a function named operator.add that does the same thing as the lambda function.

import functools
import operator

numbers = [10, 20, 30, 40, 50]

total = functools.reduce(operator.add, numbers)

print(f"total = {total}")
total = 150

Three ways to compute the same product

numbers = [10, 20, 30, 40, 50]

product = 1
for number in numbers:
    product = product * number   #or product *= number

print(f"product = {product:,}")
product = 12,000,000
import functools

numbers = [10, 20, 30, 40, 50]

product = functools.reduce(lambda product, number: product * number, numbers, 1)

print(f"product = {product:,}")
product = 12,000,000
import functools
import operator

numbers = [10, 20, 30, 40, 50]

product = functools.reduce(operator.mul, numbers, 1)

print(f"product = {product:,}")
product = 12,000,000

Four ways to compute the same concatenation

syllables = ["in", "dig", "na", "tion"]   #a list of strings

word = "".join(syllables)

print(f'word = "{word}"')
word = "indignation"
syllables = ["in", "dig", "na", "tion"]

word = ""
for syllable in syllables:
    word = word + syllable   #or word += syllable

print(f'word = "{word}"')
word = "indignation"
import functools

syllables = ["in", "dig", "na", "tion"]

word = functools.reduce(lambda word, syllable: word + syllable, syllables, "")

print(f'word = "{word}"')
word = "indignation"
import functools
import operator

syllables = ["in", "dig", "na", "tion"]

word = functools.reduce(operator.concat, syllables, "")

print(f'word = "{word}"')
word = "indignation"

Three ways to compute the same duplication

The original list contains 5 items. The new list dupes contains 10 items.

[n, n] is a list containing two items. We append these two items to dupes, increasing the length of dupes by 2.

import functools

numbers = [10, 20, 30, 40, 50]

dupes = []
for n in numbers:
    dupes = dupes + [n, n]   #or dupes += [n, n] or dupes.extend([n, n])

print(dupes)   #a list of 10 ints
[10, 10, 20, 20, 30, 30, 40, 40, 50, 50]
import functools

numbers = [10, 20, 30, 40, 50]

dupes = functools.reduce(lambda dupes, n: dupes + [n, n], numbers, [])

print(dupes)   #a list of 10 ints
[10, 10, 20, 20, 30, 30, 40, 40, 50, 50]
numbers = [10, 20, 30, 40, 50]
dupes = [number for number in numbers for _ in range(2)]
print(dupes)   #a list of 10 ints
[10, 10, 20, 20, 30, 30, 40, 40, 50, 50]

The union and intersection of two sets

southernStates = {   #a set of strings
    "Alabama",
    "Mississippi",
    "Texas"
}

westernStates = {
    "Texas",
    "New Mexico",
    "Arizona",
    "Colorado",
    "Wyoming"
}

union        = southernStates | westernStates        #union is a set of strings
intersection = southernStates & westernStates        #intersection is a set of strings

print("The union of the 2 sets:")
for i, state in enumerate(sorted(union), start = 1): #sorted(union) is a list of strings
    print(f"{i} {state}")
print()

print("The intersection of the 2 sets:")
for i, state in enumerate(sorted(intersection), start = 1): #sorted(intersection) is a list of strings
    print(f"{i} {state}")
The union of the 2 sets:
1 Alabama
2 Arizona
3 Colorado
4 Mississippi
5 New Mexico
6 Texas
7 Wyoming

The intersection of the 2 sets:
1 Texas

The union of many sets

You can take the union of an empty list of sets. I therefore gave the third argument set() to the following call to functools.reduce, just in case the list was empty. The value of the expression set() is the empty set.

import functools

southernStates = {   #a set of strings
    "Alabama",
    "Mississippi",
    "Texas"
}

westernStates = {
    "Texas",
    "New Mexico",
    "Arizona",
    "Colorado",
    "Wyoming"
}

bigStates = {
    "Alaska",
    "Texas" ,
    "Montana"
}

tStates = {
    "Tennessee",
    "Texas"
}

categories = [   #a list of four sets
    southernStates,
    westernStates,
    bigStates,
    tStates
]

union = set()   #Start with an empty set.
for category in categories:
    union = union | category   #or union |= category

for i, state in enumerate(sorted(union), start = 1):
    print(f"{i:2} {state}")

print()

union = functools.reduce(lambda union, category: union | category, categories, set())

for i, state in enumerate(sorted(union), start = 1):
    print(f"{i:2} {state}")
 1 Alabama
 2 Alaska
 3 Arizona
 4 Colorado
 5 Mississippi
 6 Montana
 7 New Mexico
 8 Tennessee
 9 Texas
10 Wyoming

 1 Alabama
 2 Alaska
 3 Arizona
 4 Colorado
 5 Mississippi
 6 Montana
 7 New Mexico
 8 Tennessee
 9 Texas
10 Wyoming

The intersection of many sets

But there is no such thing as the intersection of an empty list of sets. (Similarly, there is no such thing as division by zero.) See Nullary intersection. That’s why I gave no third argument to the following call to functools.reduce, and why I wrote the assert. The assert verifies that it’s safe to use categories[0] as my orange thing (the initial value of intersection).

import functools

southernStates = {   #a set of strings
    "Alabama",
    "Mississippi",
    "Texas"
}

westernStates = {
    "Texas",
    "New Mexico",
    "Arizona",
    "Colorado",
    "Wyoming"
}

bigStates = {
    "Alaska",
    "Texas" ,
    "Montana"
}

tStates = {
    "Tennessee",
    "Texas"
}

categories = [   #a list of four sets
    southernStates,
    westernStates,
    bigStates,
    tStates
]

assert len(categories) > 0

intersection = categories[0]
for category in categories[1:]:
    intersection = intersection & category   #or intersection &= category

for i, state in enumerate(sorted(intersection), start = 1):
    print(f"{i:2} {state}")

print()

intersection = functools.reduce(lambda intersection, category: intersection & category, categories)

for i, state in enumerate(sorted(intersection), start = 1):
    print(f"{i:2} {state}")
 1 Texas

 1 Texas

Flatten a deeply nested list

"""
Three ways to flatten a deeply nested list.  All three functions do the same thing.
"""

import sys
import functools

mylist = [
    [
        [
            [10, 20],
            [30, 40],
        ],
        [
            [50, 60],
            [70, 80],
        ],
    ],
    [
        [
            [90, 100],
            [110, 120],
        ],
        [
            [130, 140],
            [150, 160],
        ],
    ],
]

def recursiveFlatten1(originalList):
    """
    Return a new list of every non-list item in originalList, at any depth
    of nesting, in left-to-right order.  originalList is not modified.

    Only items that are lists are descended into; any other item (a string,
    a tuple, etc.) is copied into the result unchanged.
    Asserts that originalList is a list.
    """
    assert isinstance(originalList, list)
    result = []
    for element in originalList:
        if not isinstance(element, list):
            result.append(element)                   #A leaf: keep it as is.
            continue
        result.extend(recursiveFlatten1(element))    #A sublist: flatten it first.
    return result


def recursiveFlatten2(originalList): #Like recursiveFlatten1, but with a conditional expression in place of the if/else.
    """
    Return a new list of every non-list item in originalList, at any depth
    of nesting, in left-to-right order.  originalList is not modified.

    Only items that are lists are descended into; any other item is copied
    into the result unchanged.  Asserts that originalList is a list.
    """
    assert isinstance(originalList, list)
    result = []
    for element in originalList:
        #A sublist contributes its flattened contents; a leaf contributes itself.
        pieces = recursiveFlatten2(element) if isinstance(element, list) else [element]
        result = result + pieces
    return result


def recursiveFlatten3(originalList): #Like recursiveFlatten2, but with functools.reduce in place of the for loop.
    """
    Return a new list of every non-list item in originalList, at any depth
    of nesting, in left-to-right order.  originalList is not modified.

    Only items that are lists are descended into; any other item is copied
    into the result unchanged.  Asserts that originalList is a list.
    """
    assert isinstance(originalList, list)

    def appendFlattened(soFar, element):
        #One step of the reduction: soFar is the flat list built up so far.
        pieces = recursiveFlatten3(element) if isinstance(element, list) else [element]
        return soFar + pieces

    #The third argument [] is the initial value of soFar.
    return functools.reduce(appendFlattened, originalList, [])

print(recursiveFlatten1(mylist))
print(recursiveFlatten2(mylist))
print(recursiveFlatten3(mylist))
sys.exit(0)
[10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160]
[10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160]
[10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160]

Things to try

  1. What does the following code put into newNumbers? What would be a simpler way to do the same thing?
    import functools
    
    numbers = [10, 20, 30, 40, 50]
    
    newNumbers = functools.reduce(lambda newNumbers, item: [item] + newNumbers, numbers, [])
    
    print(newNumbers)
    
  2. Use a list comprehension or the map function to change a list of sequences of bytes into a list of strings of characters.