Skip to content

Commit

Permalink
Merge pull request #234 from SABS-R3-Epidemiology/new_mers_figure
Browse files Browse the repository at this point in the history
new mers figure
  • Loading branch information
I-Bouros authored Nov 17, 2023
2 parents d14bd87 + 801fa99 commit 85cc68d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 13 deletions.
74 changes: 62 additions & 12 deletions branchpro/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,9 @@ def plot_regions_inference(first_day_data,
inset_region=[],
show=True,
mers=False,
hkhn=False):
hkhn=False,
bar=True,
separate_imported=False):
"""Make a figure showing R_t inference for different choices of epsilon and
regions.
Expand Down Expand Up @@ -421,13 +423,20 @@ def plot_regions_inference(first_day_data,
# Use 0.01 height ratio subplot rows to space out the panels
region_num = len(region_names)
fig = plt.figure()
gs = fig.add_gridspec(2, region_num, height_ratios=[1, 1])
if not separate_imported:
gs = fig.add_gridspec(2, region_num, height_ratios=[1, 1])
else:
gs = fig.add_gridspec(3, region_num, height_ratios=[1, 1, 1])

# Ax for case data
top_axs = [fig.add_subplot(gs[0, i]) for i in range(region_num)]

if separate_imported:
imported_axs = [fig.add_subplot(gs[1, i]) for i in range(region_num)]

# Axes for R_t inference
axs = [fig.add_subplot(gs[1, j]) for j in range(region_num)]
axs = [fig.add_subplot(gs[1 if not separate_imported else 2, j])
for j in range(region_num)]

# Make inference panel share x axis of its incidence data
for i in range(len(region_names)):
Expand All @@ -445,13 +454,21 @@ def plot_regions_inference(first_day_data,
first_day_data_r = first_day_data
data_times = [first_day_data_r + datetime.timedelta(days=int(i))
for i in range(len(local_cases[region]))]
top_axs[region].bar([x - width/2 for x in data_times],
local_cases[region],
width,
label='Local cases',
color='k',
alpha=0.8)
top_axs[region].bar([x + width/2 for x in data_times],

if not separate_imported:
imported_ax = top_axs[region]
else:
imported_ax = imported_axs[region]

if bar:
top_axs[region].bar([x - width/2 for x in data_times],
local_cases[region],
width,
label='Local cases',
color='k',
alpha=0.8)

imported_ax.bar([x + width/2 for x in data_times],
import_cases[region],
width,
hatch='/////',
Expand All @@ -460,8 +477,24 @@ def plot_regions_inference(first_day_data,
label='Imported cases',
color='deeppink',
zorder=10)
else:
top_axs[region].plot([x - width/2 for x in data_times],
local_cases[region],
lw=1.25,
label='Local cases',
color='k',
alpha=0.8,
zorder=9)
imported_ax.plot([x + width/2 for x in data_times],
import_cases[region],
lw=0.8,
label='Imported cases',
color='deeppink',
zorder=10)

top_axs[region].set_ylabel('Number of cases')
if separate_imported:
imported_ax.set_ylabel('Number of cases')

# Plot a zoomed in part of the graph as an inset
if not hkhn:
Expand Down Expand Up @@ -597,6 +630,9 @@ def plot_regions_inference(first_day_data,
# Add the legend for epsilons
top_axs[region].legend()

if separate_imported:
imported_axs[region].legend()

if hkhn:
if region == 1:
axs[region].legend([
Expand Down Expand Up @@ -637,6 +673,11 @@ def plot_regions_inference(first_day_data,
top_axs[i].set_xlabel('Date (2020)')
axs[i].set_xlabel('Date (2020)')

if separate_imported:
imported_axs[i].xaxis.set_major_formatter(
matplotlib.dates.DateFormatter('%b %d'))
imported_axs[i].set_xlabel('Date (2020)')

# Set ticks once per week
for j in range(region_num):
if hkhn:
Expand All @@ -663,12 +704,21 @@ def plot_regions_inference(first_day_data,
plt.sca(axs[i])
plt.xticks(rotation=45, ha='center')

if separate_imported:
plt.sca(imported_axs[i])
plt.xticks(rotation=45, ha='center')

for i in range(len(region_names)):
top_axs[i].set_title(region_names[i], fontsize=14)

# Add panel labels
fig.text(0.025, 0.965, '(a)', fontsize=14)
fig.text(0.025, 0.5, '(b)', fontsize=14)
if not separate_imported:
fig.text(0.025, 0.965, '(a)', fontsize=14)
fig.text(0.025, 0.5, '(b)', fontsize=14)
else:
fig.text(0.025, 0.965, '(a)', fontsize=14)
fig.text(0.025, 0.65, '(b)', fontsize=14)
fig.text(0.025, 0.35, '(c)', fontsize=14)

fig.set_size_inches(4 * region_num, 6)
fig.set_tight_layout(True)
Expand Down
35 changes: 35 additions & 0 deletions branchpro/tests/test_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,41 @@ def test_plot(self):
show=True,
hkhn=False)

fig = branchpro.figures.plot_regions_inference(
datetime.datetime(2020, 3, 1),
['Ontario', 'Ontario'],
[self.locally_infected_cases, self.locally_infected_cases],
[self.imported_cases, self.imported_cases],
datetime.datetime(2020, 3, 7),
self.epsilon_range,
[[self.all_intervals.loc[self.all_intervals['Epsilon'] == e]
for e in self.epsilon_range],
[self.all_intervals.loc[self.all_intervals['Epsilon'] == e]
for e in self.epsilon_range]],
default_epsilon=1,
inset_region=['Ontario'],
show=True,
hkhn=False,
bar=False)

branchpro.figures.plot_regions_inference(
datetime.datetime(2020, 3, 1),
['Ontario', 'Ontario'],
[self.locally_infected_cases, self.locally_infected_cases],
[self.imported_cases, self.imported_cases],
datetime.datetime(2020, 3, 7),
self.epsilon_range,
[[self.all_intervals.loc[self.all_intervals['Epsilon'] == e]
for e in self.epsilon_range],
[self.all_intervals.loc[self.all_intervals['Epsilon'] == e]
for e in self.epsilon_range]],
default_epsilon=1,
inset_region=['Ontario'],
show=True,
hkhn=False,
bar=False,
separate_imported=True)

# Check that all plots are present
assert len(fig.axes) == 4

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_readme():
"""
Load README.md text for use as description.
"""
with open('README.md') as f:
with open('README.md', encoding='utf-8') as f:
return f.read()


Expand Down

0 comments on commit 85cc68d

Please sign in to comment.