Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of _pluck_uniq_cols #216

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
7 changes: 4 additions & 3 deletions src/formpack/utils/expand_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .array_to_xpath import EXPANDABLE_FIELD_TYPES
from .future import iteritems, OrderedDict
from .iterator import get_first_occurrence
from .ordered_set import OrderedSet
from .replace_aliases import META_TYPES
from .string import str_types
from ..constants import (UNTRANSLATED, OR_OTHER_COLUMN,
Expand Down Expand Up @@ -158,7 +159,7 @@ def _get_special_survey_cols(content):
'hint::English',
For more examples, see tests.
"""
uniq_cols = OrderedDict()
uniq_cols = OrderedSet()
special = OrderedDict()

known_translated_cols = content.get('translated', [])
Expand All @@ -169,7 +170,7 @@ def _pluck_uniq_cols(sheet_name):
# to be parsed and translated in a previous iteration
_cols = [r for r in row.keys() if r not in known_translated_cols]

uniq_cols.update(OrderedDict.fromkeys(_cols))
uniq_cols.update(_cols)

def _mark_special(**kwargs):
column_name = kwargs.pop('column_name')
Expand All @@ -178,7 +179,7 @@ def _mark_special(**kwargs):
_pluck_uniq_cols('survey')
_pluck_uniq_cols('choices')

for column_name in uniq_cols.keys():
for column_name in uniq_cols:
if column_name in ['label', 'hint']:
_mark_special(column_name=column_name,
column=column_name,
Expand Down
45 changes: 45 additions & 0 deletions src/formpack/utils/ordered_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import collections

try:
    from collections.abc import MutableSet
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import MutableSet


class OrderedSet(MutableSet):
    """A mutable set that iterates in first-insertion order.

    Two containers are kept in sync: ``self.set`` answers membership
    queries and ``self.list`` records the order in which keys were first
    added. Adding a key that is already present leaves its position
    unchanged.
    """

    # Equality is value-based and instances are mutable, so they must not
    # be usable as dict keys or set members.
    __hash__ = None

    def __init__(self, iterable=None):
        """Create the set, optionally pre-filled from ``iterable``."""
        self.set = set()
        self.list = []
        if iterable is not None:
            self.update(iterable)

    def __len__(self):
        """Return the number of distinct keys held."""
        return len(self.list)

    def __contains__(self, key):
        """Return True if ``key`` has been added and not discarded."""
        return key in self.set

    def add(self, key):
        """Append ``key`` to the end of the order unless already present."""
        if key not in self.set:
            self.set.add(key)
            self.list.append(key)

    def update(self, keys):
        """Add every item of the iterable ``keys``, preserving its order."""
        for key in keys:
            self.add(key)

    def discard(self, key):
        """Remove ``key`` if present; do nothing otherwise.

        Removal from the backing list is a linear scan.
        """
        if key in self.set:
            self.set.discard(key)
            self.list.remove(key)

    def __iter__(self):
        """Iterate over the keys in insertion order."""
        return iter(self.list)

    def __repr__(self):
        """Return e.g. ``OrderedSet([1, 2])``, or ``OrderedSet()`` if empty."""
        if not self:
            return '%s()' % (self.__class__.__name__,)
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """Compare with another collection.

        Against another ``OrderedSet`` the comparison is order-sensitive.
        Against any other iterable only the membership is compared.
        Objects that cannot be turned into a set (not iterable, or holding
        unhashable items) yield ``NotImplemented`` so that ``==`` evaluates
        to False instead of raising ``TypeError``.
        """
        if isinstance(other, OrderedSet):
            return self.list == other.list
        try:
            return self.set == set(other)
        except TypeError:
            return NotImplemented