forked from h2oai/driverlessai-recipes
-
Notifications
You must be signed in to change notification settings - Fork 1
/
round_transformer.py
57 lines (47 loc) · 2.16 KB
/
round_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Rounds numbers to 1, 2 or 3 decimals"""
from h2oaicore.systemutils import dtype_global
from h2oaicore.transformer_utils import CustomTransformer
import datatable as dt
import numpy as np
class MyRoundTransformer(CustomTransformer):
    """Round every input column to a fixed number of decimal places.

    The transformer is unsupervised and stateless: the target is ignored and
    nothing is learned while fitting. The decimal count is one of the values
    returned by ``get_parameter_choices()`` (presumably the framework picks
    one and passes it to ``__init__`` -- confirm against the recipe docs).
    """

    _unsupervised = True

    _testing_can_skip_failure = False  # ensure tested as if shouldn't fail

    @staticmethod
    def get_parameter_choices():
        """Return the candidate values for each ``__init__`` hyperparameter.

        Returns:
            dict mapping ``"decimals"`` to the allowed decimal counts.
        """
        return {"decimals": [1, 2, 3]}

    @property
    def display_name(self):
        """Human-readable name, e.g. ``"MyRound2Decimals"`` for ``decimals=2``."""
        return "MyRound%dDecimals" % self.decimals

    def __init__(self, decimals, **kwargs):
        """Create the transformer.

        Args:
            decimals: number of decimal places to round to.
            **kwargs: forwarded unchanged to ``CustomTransformer.__init__``.
        """
        super().__init__(**kwargs)
        self.decimals = decimals

    def fit_transform(self, X: dt.Frame, y: np.ndarray = None):
        """Fit (a no-op: nothing is learned) and transform ``X``.

        Args:
            X: input frame.
            y: target; ignored because the transformer is unsupervised.

        Returns:
            The same result as ``transform(X)``.
        """
        return self.transform(X)

    def transform(self, X: dt.Frame):
        """Round all values of ``X`` to ``self.decimals`` decimal places.

        Args:
            X: input frame; assumes all columns are numeric (np.round cannot
                round strings) -- TODO confirm the column restriction upstream.

        Returns:
            numpy array with the same shape as ``X.to_numpy()``.
        """
        # np.round rounds exact halves to the nearest even value
        # (e.g. 0.25 -> 0.2 with decimals=1), not always away from zero.
        return np.round(X.to_numpy(), decimals=self.decimals)

    # MOJO export is disabled: the custom "round" op is not actually
    # implemented; to_mojo below is only an example of how one would wire it up.
    _mojo = False
    # Imported at class level so the to_mojo annotations below can resolve.
    from h2oaicore.mojo import MojoWriter, MojoFrame

    def to_mojo(self, mojo: MojoWriter, iframe: MojoFrame, group_uuid=None, group_name=None):
        """Emit MOJO ops that mirror ``transform`` (illustrative only).

        One custom ``"round"`` op is added per input column, and the resulting
        frame is then cast to ``dtype_global()``.

        Args:
            mojo: writer the ops are appended to.
            iframe: MOJO frame containing (at least) ``self.input_feature_names``.
            group_uuid: id used to group the emitted ops; a fresh UUID4 string
                is generated when None.
            group_name: name of the op group; defaults to ``"RoundTransformer"``
                when None.

        Returns:
            The output MojoFrame with one column per input column.
        """
        import uuid
        from h2oaicore.mojo import MojoColumn, MojoFrame, MojoType
        from h2oaicore.mojo_transformers import MjT_CustomOp
        from h2oaicore.mojo_transformers_utils import AsType

        # Honor caller-supplied grouping; only fall back to defaults when
        # nothing was passed (previously both were overwritten unconditionally).
        if group_uuid is None:
            group_uuid = str(uuid.uuid4())
        if group_name is None:
            group_name = "RoundTransformer"

        # Every column uses the same custom op with the same parameters.
        kws = dict(
            op_name="round",
            op_params={"decimals": (MojoType.INT32, self.decimals)},
        )

        xnew = iframe[self.input_feature_names]
        oframe = MojoFrame()
        for col in xnew:
            # Output column keeps the input column's name and type.
            ocol = MojoColumn(name=col.name, dtype=col.type)
            ocol_frame = MojoFrame(columns=[ocol])
            mojo += MjT_CustomOp(iframe=MojoFrame(columns=[col]), oframe=ocol_frame,
                                 group_uuid=group_uuid, group_name=group_name, **kws)
            oframe += ocol
        oframe = AsType(dtype_global()).write_to_mojo(mojo, oframe,
                                                      group_uuid=group_uuid,
                                                      group_name=group_name)
        return oframe