Skip to content

Commit

Permalink
Merge pull request #25 from SonyCSLParis/roll_activations
Browse files Browse the repository at this point in the history
Roll activations instead of substracting shift, so that activations and predictions are correct
  • Loading branch information
aRI0U authored Jan 17, 2024
2 parents 2d11051 + 834ffa8 commit 8dc3bad
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 6 deletions.
3 changes: 3 additions & 0 deletions pesto/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .core import load_model, predict, predict_from_files


__version__ = '1.1.0'
1 change: 0 additions & 1 deletion pesto/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,3 @@ def predict_from_files(
predictions = [p.cpu().numpy() for p in predictions]
for fmt in export_format:
export(fmt, output_file, *predictions)

5 changes: 2 additions & 3 deletions pesto/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,9 @@ def forward(self,
if batch_size:
activations = activations.view(batch_size, -1, activations.size(-1))

preds = reduce_activations(activations, reduction=self.reduction)
activations = activations.roll(-round(self.shift.cpu().item() * self.bins_per_semitone), -1)

# decrease by shift to get absolute pitch
preds.sub_(self.shift)
preds = reduce_activations(activations, reduction=self.reduction)

if convert_to_freq:
preds = 440 * 2 ** ((preds - 69) / 12)
Expand Down
4 changes: 3 additions & 1 deletion pesto/utils/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def export_png(output_file: str, timesteps, confidence, activations, lims=(21, 1
activations = activations * confidence[:, None]
plt.imshow(activations.T,
aspect='auto', origin='lower', cmap='inferno',
extent=(timesteps[0], timesteps[-1]) + lims)
extent=(timesteps[0] / 1000, timesteps[-1] / 1000) + lims)

plt.xlabel("Time (s)")
plt.ylabel("Pitch (semitones)")
plt.title(output_file.rsplit('.', 2)[0])
plt.tight_layout()
plt.savefig(output_file)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pesto-pitch"
version = "1.0.0"
dynamic = ["version"]
authors = [
{name = "Alain Riou", email = "[email protected]"}
]
Expand Down Expand Up @@ -41,5 +41,8 @@ source = "https://github.com/SonyCSLParis/pesto"
[tool.pytest.ini_options]
testpaths = "tests/"

[tool.setuptools.dynamic]
version = {attr = "pesto.__version__"}

[tool.setuptools.package-data]
pesto = ["weights/*.ckpt"]

0 comments on commit 8dc3bad

Please sign in to comment.