Skip to content

Commit

Permalink
REF: share hashtable_func_helper code (pandas-dev#46090)
Browse files Browse the repository at this point in the history
* REF: hashtable_func_helper

* REF: de-duplicate using to_c_type pattern
  • Loading branch information
jbrockmendel authored Feb 24, 2022
1 parent 2fe0c70 commit eaefc5c
Showing 1 changed file with 24 additions and 31 deletions.
55 changes: 24 additions & 31 deletions pandas/_libs/hashtable_func_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dtypes = [('Complex128', 'complex128', 'complex128',
('UInt32', 'uint32', 'uint32', 'uint32_t', ''),
('UInt16', 'uint16', 'uint16', 'uint16_t', ''),
('UInt8', 'uint8', 'uint8', 'uint8_t', ''),
('Object', 'object', 'pymap', 'object', ''),
('Object', 'object', 'pymap', 'object', '<PyObject*>'),
('Int64', 'int64', 'int64', 'int64_t', ''),
('Int32', 'int32', 'int32', 'int32_t', ''),
('Int16', 'int16', 'int16', 'int16_t', ''),
Expand Down Expand Up @@ -61,11 +61,11 @@ cdef value_count_{{dtype}}(const {{dtype}}_t[:] values, bint dropna):
for i in range(n):
val = values[i]
if not dropna or not checknull(val):
k = kh_get_{{ttype}}(table, <PyObject*>val)
k = kh_get_{{ttype}}(table, {{to_c_type}}val)
if k != table.n_buckets:
table.vals[k] += 1
else:
k = kh_put_{{ttype}}(table, <PyObject*>val, &ret)
k = kh_put_{{ttype}}(table, {{to_c_type}}val, &ret)
table.vals[k] = 1
result_keys.append(val)
{{else}}
Expand Down Expand Up @@ -110,6 +110,8 @@ cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first'):
int ret = 0
{{if dtype != 'object'}}
{{c_type}} value
{{else}}
PyObject* value
{{endif}}
Py_ssize_t i, n = len(values)
khiter_t k
Expand All @@ -123,44 +125,33 @@ cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first'):

if keep == 'last':
{{if dtype == 'object'}}
for i in range(n - 1, -1, -1):
# equivalent: range(n)[::-1], which cython doesn't like in nogil
kh_put_{{ttype}}(table, <PyObject*>values[i], &ret)
out[i] = ret == 0
if True:
{{else}}
with nogil:
{{endif}}
for i in range(n - 1, -1, -1):
# equivalent: range(n)[::-1], which cython doesn't like in nogil
value = {{to_c_type}}(values[i])
kh_put_{{ttype}}(table, value, &ret)
out[i] = ret == 0
{{endif}}

elif keep == 'first':
{{if dtype == 'object'}}
for i in range(n):
kh_put_{{ttype}}(table, <PyObject*>values[i], &ret)
out[i] = ret == 0
if True:
{{else}}
with nogil:
{{endif}}
for i in range(n):
value = {{to_c_type}}(values[i])
kh_put_{{ttype}}(table, value, &ret)
out[i] = ret == 0
{{endif}}

else:
{{if dtype == 'object'}}
for i in range(n):
value = values[i]
k = kh_get_{{ttype}}(table, <PyObject*>value)
if k != table.n_buckets:
out[table.vals[k]] = 1
out[i] = 1
else:
k = kh_put_{{ttype}}(table, <PyObject*>value, &ret)
table.vals[k] = i
out[i] = 0
if True:
{{else}}
with nogil:
{{endif}}
for i in range(n):
value = {{to_c_type}}(values[i])
k = kh_get_{{ttype}}(table, value)
Expand All @@ -171,7 +162,7 @@ cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first'):
k = kh_put_{{ttype}}(table, value, &ret)
table.vals[k] = i
out[i] = 0
{{endif}}

kh_destroy_{{ttype}}(table)
return out

Expand Down Expand Up @@ -206,39 +197,41 @@ cdef ismember_{{dtype}}(const {{dtype}}_t[:] arr, const {{dtype}}_t[:] values):
khiter_t k
int ret = 0
ndarray[uint8_t] result

{{if dtype == "object"}}
PyObject* val
{{else}}
{{c_type}} val
{{endif}}

kh_{{ttype}}_t *table = kh_init_{{ttype}}()

# construct the table
n = len(values)
kh_resize_{{ttype}}(table, n)

{{if dtype == 'object'}}
for i in range(n):
kh_put_{{ttype}}(table, <PyObject*>values[i], &ret)
if True:
{{else}}
with nogil:
{{endif}}
for i in range(n):
val = {{to_c_type}}(values[i])
kh_put_{{ttype}}(table, val, &ret)
{{endif}}

# test membership
n = len(arr)
result = np.empty(n, dtype=np.uint8)

{{if dtype == 'object'}}
for i in range(n):
val = arr[i]
k = kh_get_{{ttype}}(table, <PyObject*>val)
result[i] = (k != table.n_buckets)
if True:
{{else}}
with nogil:
{{endif}}
for i in range(n):
val = {{to_c_type}}(arr[i])
k = kh_get_{{ttype}}(table, val)
result[i] = (k != table.n_buckets)
{{endif}}

kh_destroy_{{ttype}}(table)
return result.view(np.bool_)
Expand Down

0 comments on commit eaefc5c

Please sign in to comment.