Commit 87b75b4

Merge pull request #9 from csinva/brl-discretization-fixes

fixed incompatibility with fpgrowth input format + brl integration tests

csinva authored Nov 14, 2020
2 parents 3180fdb + 7ba68d9
Showing 3 changed files with 70 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -26,7 +26,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install git+https://github.com/csinva/imodels
+          pip install .
      - name: Test with pytest
        run: |
          pytest
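
With this change, CI installs the package from the checked-out branch (pip install .) instead of the repository's default branch on GitHub, so each pull request is tested against its own code.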
43 changes: 25 additions & 18 deletions imodels/rule_list/bayesian_rule_list/bayesian_rule_list.py
@@ -152,34 +152,36 @@ def fit(self, X, y, feature_labels=[], undiscretized_features=[], verbose=False)
         y = y.values

         X, y = self._setdata(X, y, feature_labels, undiscretized_features)

         permsdic = defaultdict(default_permsdic)  # We will store here the MCMC results

         data = list(X[:])

         # Now find frequent itemsets
         # Mine separately for each class
         data_pos = [x for i, x in enumerate(data) if y[i] == 0]
         data_neg = [x for i, x in enumerate(data) if y[i] == 1]
         assert len(data_pos) + len(data_neg) == len(data)

-        X_df = pd.DataFrame(X, columns=feature_labels)
-        itemsets_df = fpgrowth(X_df, min_support=(self.minsupport / len(X)), max_len=self.maxcardinality)
+        X_colname_removed = data.copy()
+        for i in range(len(data)):
+            X_colname_removed[i] = list(map(lambda s: s.split(' : ')[1], X_colname_removed[i]))
+
+        X_df_categorical = pd.DataFrame(X_colname_removed, columns=feature_labels)
+        X_df_onehot = pd.get_dummies(X_df_categorical)
+        onehot_features = X_df_onehot.columns
+
+        itemsets_df = fpgrowth(X_df_onehot, min_support=(self.minsupport / len(X)), max_len=self.maxcardinality)
         itemsets_indices = [tuple(s[1]) for s in itemsets_df.values]
-        itemsets = [np.array(feature_labels)[list(inds)] for inds in itemsets_indices]
+        itemsets = [np.array(onehot_features)[list(inds)] for inds in itemsets_indices]
         itemsets = list(map(tuple, itemsets))
         if self.verbose:
             print(len(itemsets), 'rules mined')

         # Now form the data-vs.-lhs set
         # X[j] is the set of data points that contain itemset j (that is, satisfy rule j)
-        for c in X_df.columns:
-            X_df[c] = [c if x == 1 else '' for x in list(X_df[c])]
+        for c in X_df_onehot.columns:
+            X_df_onehot[c] = [c if x == 1 else '' for x in list(X_df_onehot[c])]
         X = [{}] * (len(itemsets) + 1)
         X[0] = set(range(len(data)))  # the default rule satisfies all data
         for (j, lhs) in enumerate(itemsets):
-            X[j + 1] = set([i for (i, xi) in enumerate(X_df.values) if set(lhs).issubset(xi)])
+            X[j + 1] = set([i for (i, xi) in enumerate(X_df_onehot.values) if set(lhs).issubset(xi)])

         # now form lhs_len
@@ -264,15 +266,20 @@ def __str__(self, decimals=1):
             return "(Untrained RuleListClassifier)"

     def _to_itemset_indices(self, data):
+        X_colname_removed = data.copy()
+        for i in range(len(data)):
+            X_colname_removed[i] = list(map(lambda s: s.split(' : ')[1], X_colname_removed[i]))
+        X_df_categorical = pd.DataFrame(X_colname_removed, columns=self.feature_labels)
+        X_df_onehot = pd.get_dummies(X_df_categorical)
+
         # X[j] is the set of data points that contain itemset j (that is, satisfy rule j)
-        X_df = pd.DataFrame(data, columns=self.feature_labels)
-        for c in X_df.columns:
-            X_df[c] = [c if x == 1 else '' for x in list(X_df[c])]
+        for c in X_df_onehot.columns:
+            X_df_onehot[c] = [c if x == 1 else '' for x in list(X_df_onehot[c])]
         X = [set() for j in range(len(self.itemsets))]
         X[0] = set(range(len(data)))  # the default rule satisfies all data
         for (j, lhs) in enumerate(self.itemsets):
             if j > 0:
-                X[j] = set([i for (i, xi) in enumerate(X_df.values) if set(lhs).issubset(xi)])
+                X[j] = set([i for (i, xi) in enumerate(X_df_onehot.values) if set(lhs).issubset(xi)])
         return X

@@ -296,7 +303,7 @@ def predict_proba(self, X):
         if self.discretizer:
             self.discretizer._data = pd.DataFrame(X, columns=self.feature_labels)
             self.discretizer.apply_cutpoints()
-            D = self._prepend_feature_labels(np.array(self.discretizer._data)[:, :-1])
+            D = self._prepend_feature_labels(np.array(self.discretizer._data))
         else:
             D = X
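
Why the fpgrowth change was needed: mlxtend's fpgrowth mines frequent itemsets from a one-hot DataFrame of 0/1 indicator columns, but fit() had been passing it the discretized 'feature : value' strings directly. The patched code strips the feature-name prefix, one-hot encodes the categorical values with pd.get_dummies, and mines the indicator columns; the companion predict_proba change stops slicing off the last discretized column, so every feature reaches the rule list. A minimal sketch of the corrected flow (toy data and column names are illustrative, not from the repo):

import pandas as pd
from mlxtend.frequent_patterns import fpgrowth

# discretized rows in the 'feature : bin' style BRL produces
data = [['age : <25', 'bmi : high'],
        ['age : >=25', 'bmi : high'],
        ['age : <25', 'bmi : low']]
# strip the 'feature : ' prefix, as the patched fit() does
values = [[s.split(' : ')[1] for s in row] for row in data]
X_df_categorical = pd.DataFrame(values, columns=['age', 'bmi'])
X_df_onehot = pd.get_dummies(X_df_categorical)  # indicator columns such as 'age_<25'

# min_support is a fraction of rows; max_len caps itemset cardinality
itemsets_df = fpgrowth(X_df_onehot, min_support=2 / len(values), max_len=2)
print(itemsets_df)  # 'itemsets' holds frozensets of one-hot column indices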
44 changes: 44 additions & 0 deletions tests/brl_test.py
@@ -0,0 +1,44 @@
import unittest

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

from imodels.rule_list.bayesian_rule_list.bayesian_rule_list import BayesianRuleListClassifier


class TestBRL(unittest.TestCase):

    def test_integration_stability(self):
        X = [[0, 0, 1, 1, 0],
             [1, 0, 0, 0, 0],
             [0, 0, 1, 0, 0],
             [1, 0, 0, 0, 0],
             [1, 1, 0, 1, 1],
             [1, 1, 1, 1, 1],
             [0, 1, 1, 1, 1],
             [1, 0, 1, 1, 1]]
        y = [0, 0, 0, 0, 1, 1, 1, 1]
        M = BayesianRuleListClassifier(minsupport=2)
        feat = ['ft1', 'ft2', 'ft3', 'ft4', 'ft5']
        M.fit(X, y, feature_labels=feat)
        assert [M.predict([row], threshold=0.5) for row in X] == y

    def test_integration_fitting(self):
        np.random.seed(13)
        feature_labels = ["#Pregnant", "Glucose concentration test", "Blood pressure(mmHg)",
                          "Triceps skin fold thickness(mm)",
                          "2-Hour serum insulin (mu U/ml)", "Body mass index", "Diabetes pedigree function",
                          "Age (years)"]
        data = fetch_openml("diabetes")  # get dataset
        X = data.data
        y = (data.target == 'tested_positive').astype(np.int)  # labels 0-1

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.75)  # split

        # train classifier (allow more iterations for better accuracy; use BigDataRuleListClassifier for large datasets)
        print('training...')
        model = BayesianRuleListClassifier(max_iter=1000, listlengthprior=5, class1label="diabetes", verbose=False)
        model.fit(X_train, y_train, feature_labels=feature_labels)
        preds = model.predict(X_test, threshold=0.1)
        print("RuleListClassifier Accuracy:", np.mean(y_test == preds), "Learned interpretable model:\n", model)
