# mypy: allow-untyped-defs import collections import operator from collections.abc import Mapping from functools import reduce __all__ = [ "merge", "merge_with", "valmap", "keymap", "itemmap", "valfilter", "keyfilter", "itemfilter", "assoc", "dissoc", "assoc_in", "update_in", "get_in", ] def _get_factory(f, kwargs): factory = kwargs.pop("factory", dict) if kwargs: raise TypeError( f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'" ) return factory def merge(*dicts, **kwargs): """Merge a collection of dictionaries >>> merge({1: "one"}, {2: "two"}) {1: 'one', 2: 'two'} Later dictionaries have precedence >>> merge({1: 2, 3: 4}, {3: 3, 4: 4}) {1: 2, 3: 3, 4: 4} See Also: merge_with """ if len(dicts) == 1 and not isinstance(dicts[0], Mapping): dicts = dicts[0] factory = _get_factory(merge, kwargs) rv = factory() for d in dicts: rv.update(d) return rv def merge_with(func, *dicts, **kwargs): """Merge dictionaries and apply function to combined values A key may occur in more than one dict, and all values mapped from the key will be passed to the function as a list, such as func([val1, val2, ...]). >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20}) {1: 11, 2: 22} >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP {1: 1, 2: 2, 3: 30} See Also: merge """ if len(dicts) == 1 and not isinstance(dicts[0], Mapping): dicts = dicts[0] factory = _get_factory(merge_with, kwargs) result = factory() for d in dicts: for k, v in d.items(): if k not in result: result[k] = [v] else: result[k].append(v) return valmap(func, result, factory) def valmap(func, d, factory=dict): """Apply function to values of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} >>> valmap(sum, bills) # doctest: +SKIP {'Alice': 65, 'Bob': 45} See Also: keymap itemmap """ rv = factory() rv.update(zip(d.keys(), map(func, d.values()))) return rv def keymap(func, d, factory=dict): """Apply function to keys of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} >>> keymap(str.lower, bills) # doctest: +SKIP {'alice': [20, 15, 30], 'bob': [10, 35]} See Also: valmap itemmap """ rv = factory() rv.update(zip(map(func, d.keys()), d.values())) return rv def itemmap(func, d, factory=dict): """Apply function to items of dictionary >>> accountids = {"Alice": 10, "Bob": 20} >>> itemmap(reversed, accountids) # doctest: +SKIP {10: "Alice", 20: "Bob"} See Also: keymap valmap """ rv = factory() rv.update(map(func, d.items())) return rv def valfilter(predicate, d, factory=dict): """Filter items in dictionary by value >>> iseven = lambda x: x % 2 == 0 >>> d = {1: 2, 2: 3, 3: 4, 4: 5} >>> valfilter(iseven, d) {1: 2, 3: 4} See Also: keyfilter itemfilter valmap """ rv = factory() for k, v in d.items(): if predicate(v): rv[k] = v return rv def keyfilter(predicate, d, factory=dict): """Filter items in dictionary by key >>> iseven = lambda x: x % 2 == 0 >>> d = {1: 2, 2: 3, 3: 4, 4: 5} >>> keyfilter(iseven, d) {2: 3, 4: 5} See Also: valfilter itemfilter keymap """ rv = factory() for k, v in d.items(): if predicate(k): rv[k] = v return rv def itemfilter(predicate, d, factory=dict): """Filter items in dictionary by item >>> def isvalid(item): ... k, v = item ... return k % 2 == 0 and v < 4 >>> d = {1: 2, 2: 3, 3: 4, 4: 5} >>> itemfilter(isvalid, d) {2: 3} See Also: keyfilter valfilter itemmap """ rv = factory() for item in d.items(): if predicate(item): k, v = item rv[k] = v return rv def assoc(d, key, value, factory=dict): """Return a new dict with new key value pair New dict has d[key] set to value. Does not modify the initial dictionary. >>> assoc({"x": 1}, "x", 2) {'x': 2} >>> assoc({"x": 1}, "y", 3) # doctest: +SKIP {'x': 1, 'y': 3} """ d2 = factory() d2.update(d) d2[key] = value return d2 def dissoc(d, *keys, **kwargs): """Return a new dict with the given key(s) removed. New dict has d[key] deleted for each supplied key. Does not modify the initial dictionary. >>> dissoc({"x": 1, "y": 2}, "y") {'x': 1} >>> dissoc({"x": 1, "y": 2}, "y", "x") {} >>> dissoc({"x": 1}, "y") # Ignores missing keys {'x': 1} """ factory = _get_factory(dissoc, kwargs) d2 = factory() if len(keys) < len(d) * 0.6: d2.update(d) for key in keys: if key in d2: del d2[key] else: remaining = set(d) remaining.difference_update(keys) for k in remaining: d2[k] = d[k] return d2 def assoc_in(d, keys, value, factory=dict): """Return a new dict with new, potentially nested, key value pair >>> purchase = { ... "name": "Alice", ... "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, ... "credit card": "5555-1234-1234-1234", ... } >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00]) # doctest: +SKIP {'credit card': '5555-1234-1234-1234', 'name': 'Alice', 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} """ return update_in(d, keys, lambda x: value, value, factory) def update_in(d, keys, func, default=None, factory=dict): """Update value in a (potentially) nested dictionary inputs: d - dictionary on which to operate keys - list or tuple giving the location of the value to be changed in d func - function to operate on that value If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the original dictionary with v replaced by func(v), but does not mutate the original dictionary. If k0 is not a key in d, update_in creates nested dictionaries to the depth specified by the keys, with the innermost value set to func(default). >>> inc = lambda x: x + 1 >>> update_in({"a": 0}, ["a"], inc) {'a': 1} >>> transaction = { ... "name": "Alice", ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, ... "credit card": "5555-1234-1234-1234", ... } >>> update_in(transaction, ["purchase", "costs"], sum) # doctest: +SKIP {'credit card': '5555-1234-1234-1234', 'name': 'Alice', 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} >>> # updating a value when k0 is not in d >>> update_in({}, [1, 2, 3], str, default="bar") {1: {2: {3: 'bar'}}} >>> update_in({1: "foo"}, [2, 3, 4], inc, 0) {1: 'foo', 2: {3: {4: 1}}} """ ks = iter(keys) k = next(ks) rv = inner = factory() rv.update(d) # pyrefly: ignore [not-iterable] for key in ks: if k in d: d = d[k] dtemp = factory() dtemp.update(d) else: d = dtemp = factory() inner[k] = inner = dtemp k = key if k in d: inner[k] = func(d[k]) else: inner[k] = func(default) return rv def get_in(keys, coll, default=None, no_default=False): """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless ``no_default`` is specified, then it raises KeyError or IndexError. ``get_in`` is a generalization of ``operator.getitem`` for nested data structures such as dictionaries and lists. >>> transaction = { ... "name": "Alice", ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, ... "credit card": "5555-1234-1234-1234", ... } >>> get_in(["purchase", "items", 0], transaction) 'Apple' >>> get_in(["name"], transaction) 'Alice' >>> get_in(["purchase", "total"], transaction) >>> get_in(["purchase", "items", "apple"], transaction) >>> get_in(["purchase", "items", 10], transaction) >>> get_in(["purchase", "total"], transaction, 0) 0 >>> get_in(["y"], {}, no_default=True) Traceback (most recent call last): ... KeyError: 'y' See Also: itertoolz.get operator.getitem """ try: return reduce(operator.getitem, keys, coll) except (KeyError, IndexError, TypeError): if no_default: raise return default def getter(index): if isinstance(index, list): if len(index) == 1: index = index[0] return lambda x: (x[index],) elif index: return operator.itemgetter(*index) else: return lambda x: () else: return operator.itemgetter(index) def groupby(key, seq): """Group a collection by a key function >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] >>> groupby(len, names) # doctest: +SKIP {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} >>> iseven = lambda x: x % 2 == 0 >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} Non-callable keys imply grouping on a member. >>> groupby( ... "gender", ... [ ... {"name": "Alice", "gender": "F"}, ... {"name": "Bob", "gender": "M"}, ... {"name": "Charlie", "gender": "M"}, ... ], ... ) # doctest:+SKIP {'F': [{'gender': 'F', 'name': 'Alice'}], 'M': [{'gender': 'M', 'name': 'Bob'}, {'gender': 'M', 'name': 'Charlie'}]} Not to be confused with ``itertools.groupby`` See Also: countby """ if not callable(key): key = getter(key) d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated] for item in seq: d[key(item)](item) rv = {} for k, v in d.items(): rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] return rv def first(seq): """The first element in a sequence >>> first("ABC") 'A' """ return next(iter(seq))