Skip to content

Commit

Permalink
fix joblib model loading (#542)
Browse files Browse the repository at this point in the history
  • Loading branch information
amfonelic authored Sep 15, 2022
1 parent 3d26aa7 commit b0f76ef
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
11 changes: 9 additions & 2 deletions m2cgen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Model can also be piped:
# cat <path_to_file> | m2cgen --language java
"""
import pickle
import sys
from argparse import ArgumentParser, FileType
from inspect import signature
Expand Down Expand Up @@ -99,6 +98,13 @@
"--version", "-v",
action="version",
version=f"%(prog)s {m2cgen.__version__}")
parser.add_argument(
"--pickle-lib", "-pl",
type=str,
dest="lib",
help="Sets the lib used to save the model",
choices=["pickle", "joblib"],
default="pickle")


def parse_args(args):
Expand All @@ -109,7 +115,8 @@ def generate_code(args):
sys.setrecursionlimit(args.recursion_limit)

with args.infile as f:
model = pickle.load(f)
pickle_lib = __import__(args.lib)
model = pickle_lib.load(f)

exporter, supported_args = LANGUAGE_TO_EXPORTER[args.language]

Expand Down
13 changes: 11 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def _get_mock_args(
package_name=None,
class_name=None,
infile=None,
language=None
language=None,
lib="pickle"
):
return mock.MagicMock(
indent=indent,
Expand All @@ -28,7 +29,8 @@ def _get_mock_args(
class_name=class_name,
infile=infile,
language=language,
recursion_limit=cli.MAX_RECURSION_DEPTH)
recursion_limit=cli.MAX_RECURSION_DEPTH,
lib=lib)


def test_file_as_input(tmp_path):
Expand Down Expand Up @@ -122,6 +124,13 @@ def test_namespace(pickled_model):
assert "namespace Tests.ML {" in generated_code


def test_joblib_loading(pickled_model):
mock_args = _get_mock_args(infile=pickled_model, language="go", lib="joblib")
generated_code = cli.generate_code(mock_args).strip()

assert generated_code.startswith("func score(input []float64) float64 {\n")


def test_indent(pickled_model):
mock_args = _get_mock_args(infile=pickled_model, indent=0, language="c_sharp")
generated_code = cli.generate_code(mock_args).strip()
Expand Down

0 comments on commit b0f76ef

Please sign in to comment.