Skip to content

Commit

Permalink
fix performance issue in t-test
Browse files Browse the repository at this point in the history
  • Loading branch information
SamKarkache committed Apr 18, 2024
1 parent a07fe46 commit f6e16c0
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions WPI_SCA_LIBRARY/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def t_test_tvla(fixed_t: np.ndarray, random_t: np.ndarray, visualize: bool = Fal
new_sf_outer = []
new_sr_outer = []
t_max_outer = []
fixed_t = np.array(fixed_t, dtype=np.float64)
random_t = np.array(random_t, dtype=np.float64)

if len(fixed_t) != len(random_t):
raise ValueError("Length of fixed_t and random_t must be equal")
Expand All @@ -151,12 +153,13 @@ def t_test_intermediate(mf_old, mr_old, sf_old, sr_old, new_tf, new_tr, n):
new_sf = sf_old + (new_tf - mf_old) * (new_tf - new_mf)
new_sr = sr_old + (new_tr - mr_old) * (new_tr - new_mr)

new_stdf = np.sqrt(np.array(new_sf / n, dtype=np.float64))
new_stdr = np.sqrt(np.array(new_sr / n, dtype=np.float64))
new_stdf = np.sqrt(new_sf / n)
new_stdr = np.sqrt(new_sr / n)

try:
welsh_t = np.array(new_mr - new_mf) / np.sqrt(
np.array((new_stdr ** 2)) / (n + 1) + np.array((new_stdf ** 2)) / (n + 1))
with np.errstate(divide='ignore'):
welsh_t = np.array(new_mr - new_mf) / np.sqrt(
np.array((new_stdr ** 2)) / (n + 1) + np.array((new_stdf ** 2)) / (n + 1))
except ZeroDivisionError:
welsh_t = 0

Expand Down

0 comments on commit f6e16c0

Please sign in to comment.