Skip to content

Commit

Permalink
add enter parameter to research to allow traversing custom data types
Browse files Browse the repository at this point in the history
  • Loading branch information
mahmoud committed Jun 30, 2024
1 parent d9a927b commit 1558848
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 34 deletions.
109 changes: 75 additions & 34 deletions boltons/iterutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,15 @@ def split_iter(src, sep=None, maxsplit=None):
sep_func = sep
elif not is_scalar(sep):
sep = frozenset(sep)
sep_func = lambda x: x in sep
def sep_func(x): return x in sep
else:
sep_func = lambda x: x == sep
def sep_func(x): return x == sep

cur_group = []
split_count = 0
for s in src:
if maxsplit is not None and split_count >= maxsplit:
sep_func = lambda x: False
def sep_func(x): return False
if sep_func(s):
if sep is None and not cur_group:
# If sep is none, str.split() "groups" separators
Expand Down Expand Up @@ -229,7 +229,7 @@ def rstrip(iterable, strip_value=None):
['Foo', 'Bar']
"""
return list(rstrip_iter(iterable,strip_value))
return list(rstrip_iter(iterable, strip_value))


def rstrip_iter(iterable, strip_value=None):
Expand All @@ -253,7 +253,7 @@ def rstrip_iter(iterable, strip_value=None):
else:
broken = True
break
if not broken: # Return to caller here because the end of the
if not broken: # Return to caller here because the end of the
return # iterator has been reached
yield from cache
yield i
Expand All @@ -268,10 +268,10 @@ def strip(iterable, strip_value=None):
['Foo', 'Bar', 'Bam']
"""
return list(strip_iter(iterable,strip_value))
return list(strip_iter(iterable, strip_value))


def strip_iter(iterable,strip_value=None):
def strip_iter(iterable, strip_value=None):
"""Strips values from the beginning and end of an iterable. Stripped items
will match the value of the argument strip_value. Functionality is
analogous to that of the method str.strip. Returns a generator.
Expand All @@ -280,7 +280,7 @@ def strip_iter(iterable,strip_value=None):
['Foo', 'Bar', 'Bam']
"""
return rstrip_iter(lstrip_iter(iterable,strip_value),strip_value)
return rstrip_iter(lstrip_iter(iterable, strip_value), strip_value)


def chunked(src, size, count=None, **kw):
Expand Down Expand Up @@ -340,11 +340,12 @@ def chunked_iter(src, size, **kw):
raise ValueError('got unexpected keyword arguments: %r' % kw.keys())
if not src:
return
postprocess = lambda chk: chk

def postprocess(chk): return chk
if isinstance(src, (str, bytes)):
postprocess = lambda chk, _sep=type(src)(): _sep.join(chk)
def postprocess(chk, _sep=type(src)()): return _sep.join(chk)
if isinstance(src, bytes):
postprocess = lambda chk: bytes(chk)
def postprocess(chk): return bytes(chk)
src_iter = iter(src)
while True:
cur_chunk = list(itertools.islice(src_iter, size))
Expand Down Expand Up @@ -385,15 +386,19 @@ def chunk_ranges(input_size, chunk_size, input_offset=0, overlap_size=0, align=F
>>> list(chunk_ranges(input_offset=3, input_size=15, chunk_size=5, overlap_size=1, align=True))
[(3, 5), (4, 9), (8, 13), (12, 17), (16, 18)]
"""
input_size = _validate_positive_int(input_size, 'input_size', strictly_positive=False)
input_size = _validate_positive_int(
input_size, 'input_size', strictly_positive=False)
chunk_size = _validate_positive_int(chunk_size, 'chunk_size')
input_offset = _validate_positive_int(input_offset, 'input_offset', strictly_positive=False)
overlap_size = _validate_positive_int(overlap_size, 'overlap_size', strictly_positive=False)
input_offset = _validate_positive_int(
input_offset, 'input_offset', strictly_positive=False)
overlap_size = _validate_positive_int(
overlap_size, 'overlap_size', strictly_positive=False)

input_stop = input_offset + input_size

if align:
initial_chunk_len = chunk_size - input_offset % (chunk_size - overlap_size)
initial_chunk_len = chunk_size - \
input_offset % (chunk_size - overlap_size)
if initial_chunk_len != overlap_size:
yield (input_offset, min(input_offset + initial_chunk_len, input_stop))
if input_offset + initial_chunk_len >= input_stop:
Expand Down Expand Up @@ -479,7 +484,7 @@ def windowed_iter(src, size, fill=_UNSET):
With *fill* set, the iterator always yields a number of windows
equal to the length of the *src* iterable.
>>> windowed(range(4), 3, fill=None)
[(0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]
Expand All @@ -495,17 +500,16 @@ def windowed_iter(src, size, fill=_UNSET):
except StopIteration:
return zip([])
return zip(*tees)

for i, t in enumerate(tees):
for _ in range(i):
for _ in range(i):
try:
next(t)
except StopIteration:
continue
return zip_longest(*tees, fillvalue=fill)



def xfrange(stop, start=None, step=1.0):
"""Same as :func:`frange`, but generator-based instead of returning a
list.
Expand Down Expand Up @@ -726,21 +730,21 @@ def bucketize(src, key=bool, value_transform=None, key_filter=None):
src = zip(key, src)

if isinstance(key, str):
key_func = lambda x: getattr(x, key, x)
def key_func(x): return getattr(x, key, x)
elif callable(key):
key_func = key
elif isinstance(key, list):
key_func = lambda x: x[0]
def key_func(x): return x[0]
else:
raise TypeError('expected key to be callable or a string or a list')

if value_transform is None:
value_transform = lambda x: x
def value_transform(x): return x
if not callable(value_transform):
raise TypeError('expected callable value transform function')
if isinstance(key, list):
f = value_transform
value_transform=lambda x: f(x[1])
def value_transform(x): return f(x[1])

ret = {}
for val in src:
Expand Down Expand Up @@ -807,11 +811,11 @@ def unique_iter(src, key=None):
if not is_iterable(src):
raise TypeError('expected an iterable, not %r' % type(src))
if key is None:
key_func = lambda x: x
def key_func(x): return x
elif callable(key):
key_func = key
elif isinstance(key, str):
key_func = lambda x: getattr(x, key, x)
def key_func(x): return getattr(x, key, x)
else:
raise TypeError('"key" expected a string or callable, not %r' % key)
seen = set()
Expand Down Expand Up @@ -862,7 +866,7 @@ def redundant(src, key=None, groups=False):
elif callable(key):
key_func = key
elif isinstance(key, (str, bytes)):
key_func = lambda x: getattr(x, key, x)
def key_func(x): return getattr(x, key, x)
else:
raise TypeError('"key" expected a string or callable, not %r' % key)
seen = {} # key to first seen item
Expand Down Expand Up @@ -964,6 +968,7 @@ def flatten_iter(iterable):
else:
yield item


def flatten(iterable):
"""``flatten()`` returns a collapsed list of all the elements from
*iterable* while collapsing any nested iterables.
Expand Down Expand Up @@ -1006,6 +1011,7 @@ def default_visit(path, key, value):
# print('visit(%r, %r, %r)' % (path, key, value))
return key, value


# enable the extreme: monkeypatching iterutils with a different default_visit
_orig_default_visit = default_visit

Expand Down Expand Up @@ -1128,6 +1134,9 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
callable. When set to ``False``, remap ignores any errors
raised by the *visit* callback. Items causing exceptions
are kept. See examples for more details.
trace (bool): Pass ``trace=True`` to print out the entire
traversal. Or pass a tuple of ``'visit'``, ``'enter'``,
or ``'exit'`` to print only the selected events.
remap is designed to cover the majority of cases with just the
*visit* callable. While passing in multiple callables is very
Expand Down Expand Up @@ -1156,6 +1165,15 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
if not callable(exit):
raise TypeError('exit expected callable, not: %r' % exit)
reraise_visit = kwargs.pop('reraise_visit', True)
trace = kwargs.pop('trace', ())
if trace is True:
trace = ('visit', 'enter', 'exit')
elif isinstance(trace, str):
trace = (trace,)
if not isinstance(trace, (tuple, list, set)):
raise TypeError('trace expected tuple of event names, not: %r' % trace)
trace_enter, trace_exit, trace_visit = 'enter' in trace, 'exit' in trace, 'visit' in trace

if kwargs:
raise TypeError('unexpected keyword arguments: %r' % kwargs.keys())

Expand All @@ -1168,14 +1186,23 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
key, new_parent, old_parent = value
id_value = id(old_parent)
path, new_items = new_items_stack.pop()
if trace_exit:
print(' .. remap exit:', path, '-', key, '-',
old_parent, '-', new_parent, '-', new_items)
value = exit(path, key, old_parent, new_parent, new_items)
if trace_exit:
print(' .. remap exit result:', value)
registry[id_value] = value
if not new_items_stack:
continue
elif id_value in registry:
value = registry[id_value]
else:
if trace_enter:
print(' .. remap enter:', path, '-', key, '-', value)
res = enter(path, key, value)
if trace_enter:
print(' .. remap enter result:', res)
try:
new_parent, new_items = res
except TypeError:
Expand All @@ -1191,21 +1218,29 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
stack.append((_REMAP_EXIT, (key, new_parent, value)))
if new_items:
stack.extend(reversed(list(new_items)))
if trace_enter:
print(' .. remap stack size now:', len(stack))
continue
if visit is _orig_default_visit:
# avoid function call overhead by inlining identity operation
visited_item = (key, value)
else:
try:
if trace_visit:
print(' .. remap visit:', path, '-', key, '-', value)
visited_item = visit(path, key, value)
except Exception:
if reraise_visit:
raise
visited_item = True
if visited_item is False:
if trace_visit:
print(' .. remap visit result: <drop>')
continue # drop
elif visited_item is True:
visited_item = (key, value)
if trace_visit:
print(' .. remap visit result:', visited_item)
# TODO: typecheck?
# raise TypeError('expected (key, value) from visit(),'
# ' not: %r' % visited_item)
Expand All @@ -1221,6 +1256,7 @@ class PathAccessError(KeyError, IndexError, TypeError):
representing what can occur when looking up a path in a nested
object.
"""

def __init__(self, exc, seg, path):
self.exc = exc
self.seg = seg
Expand Down Expand Up @@ -1296,7 +1332,7 @@ def get_path(root, path, default=_UNSET):
return cur


def research(root, query=lambda p, k, v: True, reraise=False):
def research(root, query=lambda p, k, v: True, reraise=False, enter=default_enter):
"""The :func:`research` function uses :func:`remap` to recurse over
any data nested in *root*, and find values which match a given
criterion, specified by the *query* callable.
Expand Down Expand Up @@ -1343,16 +1379,16 @@ def research(root, query=lambda p, k, v: True, reraise=False):
if not callable(query):
raise TypeError('query expected callable, not: %r' % query)

def enter(path, key, value):
def _enter(path, key, value):
try:
if query(path, key, value):
ret.append((path + (key,), value))
except Exception:
if reraise:
raise
return default_enter(path, key, value)
return enter(path, key, value)

remap(root, enter=enter)
remap(root, enter=_enter)
return ret


Expand Down Expand Up @@ -1383,6 +1419,7 @@ class GUIDerator:
detect a fork on next iteration and reseed accordingly.
"""

def __init__(self, size=24):
self.size = size
if size < 20 or size > 36:
Expand Down Expand Up @@ -1495,13 +1532,16 @@ def soft_sorted(iterable, first=None, last=None, key=None, reverse=False):
last = last or []
key = key or (lambda x: x)
seq = list(iterable)
other = [x for x in seq if not ((first and key(x) in first) or (last and key(x) in last))]
other = [x for x in seq if not (
(first and key(x) in first) or (last and key(x) in last))]
other.sort(key=key, reverse=reverse)

if first:
first = sorted([x for x in seq if key(x) in first], key=lambda x: first.index(key(x)))
first = sorted([x for x in seq if key(x) in first],
key=lambda x: first.index(key(x)))
if last:
last = sorted([x for x in seq if key(x) in last], key=lambda x: last.index(key(x)))
last = sorted([x for x in seq if key(x) in last],
key=lambda x: last.index(key(x)))
return first + other + last


Expand Down Expand Up @@ -1536,7 +1576,7 @@ def __lt__(self, other):
ret = obj < other
except TypeError:
ret = ((type(obj).__name__, id(type(obj)), obj)
< (type(other).__name__, id(type(other)), other))
< (type(other).__name__, id(type(other)), other))
return ret

if key is not None and not callable(key):
Expand All @@ -1545,6 +1585,7 @@ def __lt__(self, other):

return sorted(iterable, key=_Wrapper, reverse=reverse)


"""
May actually be faster to do an isinstance check for a str path
Expand Down
22 changes: 22 additions & 0 deletions tests/test_iterutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,28 @@ def broken_query(p, k, v):
assert research(root, broken_query) == []


def test_research_custom_enter():
# see #368
from types import SimpleNamespace as NS
root = NS(
a='a',
b='b',
c=NS(aa='aa') )

def query(path, key, value):
return value.startswith('a')

def custom_enter(path, key, value):
if isinstance(value, NS):
return [], value.__dict__.items()
return default_enter(path, key, value)

with pytest.raises(TypeError):
research(root, query)
assert research(root, query, enter=custom_enter) == [(('a',), 'a'), (('c', 'aa'), 'aa')]



def test_backoff_basic():
from boltons.iterutils import backoff

Expand Down

0 comments on commit 1558848

Please sign in to comment.