Summary
Features passed to TF Serving sometimes need to be wrapped in brackets (i.e. sent as lists) to match the expected input dimensions, and the dimensions must be correct.
However, in some cases passing inputs with incorrect dimensions crashes TF Serving outright, when it would be expected to reject the request and return an error.
In particular:
Calling a RandomForestModel that has only numeric features with bracketed values crashes TF Serving (core dumped, Check failed: NDIMS == dims() (2 vs. 3))
Similarly, calling a RandomForestModel that has categorical features and lets the model 1-hot encode them crashes TF Serving when brackets are used (core dumped, Check failed: NDIMS == dims() (2 vs. 3))
Note, however, that brackets are sometimes necessary, e.g. when using a TensorFlow model
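The "2 vs. 3" in the crash message is about tensor rank. In TF Serving's REST row format (`"instances"`), the outer list is the batch, so each extra pair of brackets around a feature value adds one dimension to that input. A minimal pure-Python sketch of that arithmetic (the helper names are hypothetical, not part of TF Serving):

```python
from typing import Any, List


def value_shape(value: Any) -> List[int]:
    """Shape of a (possibly nested) JSON list; a scalar has shape []."""
    shape: List[int] = []
    while isinstance(value, list):
        shape.append(len(value))
        if not value:
            break
        value = value[0]
    return shape


def batched_shape(payload: dict, feature: str) -> List[int]:
    """Shape of one named input once the row-format batch axis is prepended."""
    instances = payload["instances"]
    return [len(instances)] + value_shape(instances[0][feature])


plain = {"instances": [{"x1": 300.0}, {"x1": 400.0}]}
bracketed = {"instances": [{"x1": [300.0]}]}
print(batched_shape(plain, "x1"))      # [2]
print(batched_shape(bracketed, "x1"))  # [1, 1]
```

Where the third dimension in the crash comes from is internal to the served model's graph and is not reproduced by this sketch.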
Details of each case:
(1) TF Serving with a TensorFlow model with numerical and categorical features, using tensorflow.feature_column to dummify the categorical features:
brackets around the floats are optional
brackets around the categorical feature are required; without them an error is returned
(2) TF Serving with a TFDF Random Forest model with only numeric features:
without brackets: works
with brackets: TF Serving crashes (core dumped) with a "Check failed: NDIMS == dims() (2 vs. 3)" error
(3) TF Serving with a TFDF Random Forest model with numeric and categorical features, and a preprocessing model:
works with or without brackets on any of the float or categorical features
(4) TF Serving with a TFDF Random Forest model with numeric and categorical features,
where the RandomForestModel parameters categorical_set_split_max_num_items and categorical_set_split_min_item_frequency are used to dummify the categorical field:
without brackets: works
with brackets: TF Serving crashes (core dumped) with a "Check failed: NDIMS == dims() (2 vs. 3)" error
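Until TF Serving rejects such requests instead of aborting, one client-side workaround for cases (2) and (4) is to strip redundant brackets before sending. A minimal sketch, assuming every feature is a single scalar per instance; multi-element lists are left untouched. Case (1) needs the opposite for its categorical feature, so this should only be applied to the models that crash. `unwrap` and `unbracket_instances` are hypothetical helpers:

```python
from typing import Any, Dict, List


def unwrap(value: Any) -> Any:
    """Strip redundant single-element brackets: [[300.0]] -> 300.0."""
    while isinstance(value, list) and len(value) == 1:
        value = value[0]
    return value


def unbracket_instances(instances: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return copies of the instances with single-element lists unwrapped."""
    return [{name: unwrap(v) for name, v in inst.items()} for inst in instances]


test_data = {"signature_name": "serving_default",
             "instances": [{"x1": [300.0], "x2": [220.0], "x3": ["cat11"]}]}
test_data["instances"] = unbracket_instances(test_data["instances"])
print(test_data["instances"])  # [{'x1': 300.0, 'x2': 220.0, 'x3': 'cat11'}]
```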
"""
Run model in TFServing
docker run -p 8501:8501 --mount type=bind,source=/home/username/model_assessment/test_model_rf1,target=/models/test_model_rf1 -e MODEL_NAME=test_model_rf1 -t tensorflow/serving
"""
# Without brackets, floats are fine
test_data = {"signature_name": "serving_default",
"instances": [{"x1": 300.0,
"x2": 220.0,
}]
}
response = requests.post('http://localhost:8501/v1/models/test_model_rf1:predict', data=json.dumps(test_data), headers=headers)
print(response.text)
# With brackets, floats crash TF Serving
test_data = {"signature_name": "serving_default",
"instances": [{"x1": [300.0],
"x2": [220.0],
}]
}
response = requests.post('http://localhost:8501/v1/models/test_model_rf1:predict', data=json.dumps(test_data), headers=headers)
print(response.text)
"""
Causes
2023-03-21 10:31:24.921448: F external/org_tensorflow/tensorflow/core/framework/tensor_shape.cc:45] Check failed: NDIMS == dims() (2 vs. 3)Asking for tensor of 2 dimensions from a tensor of 3 dimensions
/usr/bin/tf_serving_entrypoint.sh: line 3: 7 Aborted (core dumped) tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} "$@"
"""
"""(3) Now make a Random Forest with feature transformation on cat feature"""
feature_layer_input = {}
feature_columns = []
for feature in ['x1', 'x2']:
feature_columns.append(feature_column.numeric_column(feature))
feature_layer_input[feature] = Input(shape=(1,), name=feature)
for feature in ['x3']:
feature_column_emb = feature_column.categorical_column_with_hash_bucket(
feature, hash_bucket_size=10
)
feature_columns.append(feature_column.indicator_column(feature_column_emb))
feature_layer_input[feature] = Input(shape=(1,), dtype=string, name=feature)
feature_layer = layers.DenseFeatures(feature_columns)(feature_layer_input)
feat_x3_model = Model(feature_layer_input, feature_layer)
rf = tfdf.keras.RandomForestModel(
task=Task.REGRESSION,
preprocessing=feat_x3_model,
num_trees=100,
max_depth=20,
min_examples=5,
num_candidate_attributes_ratio=0.45,
split_axis="AXIS_ALIGNED", #"SPARSE_OBLIQUE",
growing_strategy="LOCAL", #"BEST_FIRST_GLOBAL",
)
features = ['x1', 'x2', 'x3']
train_ds = df_to_dataset(train[features], labels=train['y'], batch_size=1000)
valid_ds = df_to_dataset(valid[features], labels=valid['y'], batch_size=1000)
test_ds = df_to_dataset(test[features], labels=test['y'], batch_size=1000)
rf.fit(train_ds)
print(r2_score(valid['y'], rf.predict(valid_ds)))
rf.save('/home/username/model_assessment/test_model_rf2/1/')
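In case (3) the string feature x3 is turned into numbers before it reaches the forest: `categorical_column_with_hash_bucket` maps each string to one of `hash_bucket_size=10` buckets and `indicator_column` one-hot encodes that bucket. A conceptual sketch of that step in pure Python. Note the loud assumption: it uses `zlib.crc32` as a stand-in hash, whereas TensorFlow uses its own fingerprint function, so the bucket indices will not match the real model:

```python
import zlib
from typing import List


def hashed_one_hot(value: str, hash_bucket_size: int = 10) -> List[float]:
    """Illustrative hash-bucket + one-hot encoding of a single string.

    Stand-in hash (crc32), NOT TensorFlow's: bucket indices differ from TF.
    """
    bucket = zlib.crc32(value.encode("utf-8")) % hash_bucket_size
    vec = [0.0] * hash_bucket_size
    vec[bucket] = 1.0
    return vec


encoded = hashed_one_hot("cat11")
print(len(encoded), sum(encoded))  # 10 1.0
```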
"""
Run in TF Serving
docker run -p 8501:8501 --mount type=bind,source=/home/username/model_assessment/test_model_rf2,target=/models/test_model_rf2 -e MODEL_NAME=test_model_rf2 -t tensorflow/serving
"""
# Without brackets. Fine
test_data = {"signature_name": "serving_default",
"instances": [{"x1": 300.0,
"x2": 220.0,
"x3": "cat11"
},
{"x1": 400.0,
"x2": 320.0,
"x3": "cat11"
}]
}
response = requests.post('http://localhost:8501/v1/models/test_model_rf2:predict', data=json.dumps(test_data), headers=headers)
print(response.text)
# With brackets. Fine
test_data = {"signature_name": "serving_default",
"instances": [{"x1": [300.0],
"x2": [220.0],
"x3": ["cat11"]
}]
}
headers = {"content-type": "application/json"}
response = requests.post('http://localhost:8501/v1/models/test_model_rf2:predict', data=json.dumps(test_data), headers=headers)
print(response.text)
# Brackets on floats, none on cat. Fine
test_data = {"signature_name": "serving_default",
"instances": [{"x1": [[300.0]],
"x2": [[[[220.0]]]],
"x3": "cat11"
}]
}
response = requests.post('http://localhost:8501/v1/models/test_model_rf2:predict', data=json.dumps(test_data), headers=headers)
print(response.text)
# Brackets on cat, none on floats. Fine
test_data = {"signature_name": "serving_default",
"instances": [{"x1": 300.0,
"x2": 220.0,
"x3": ["cat11"]
}]
}
response = requests.post('http://localhost:8501/v1/models/test_model_rf2:predict', data=json.dumps(test_data), headers=headers)
print(response.text)
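The same requests can also be expressed in TF Serving's columnar REST format, where `"inputs"` maps each input name to a list over the batch (predictions then come back under `"outputs"` rather than `"predictions"`). Whether the columnar format avoids the crash for test_model_rf1 or test_model_rf3 was not tested in this issue; the sketch below only makes the batch dimension explicit when comparing payloads. `rows_to_columns` is a hypothetical helper:

```python
from typing import Any, Dict, List


def rows_to_columns(instances: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Convert row-format instances into a columnar 'inputs' mapping."""
    if not instances:
        return {}
    names = list(instances[0])
    columns: Dict[str, List[Any]] = {name: [] for name in names}
    for i, inst in enumerate(instances):
        # Every instance must carry the same features, or columns misalign
        if set(inst) != set(names):
            raise ValueError(f"instance {i} has features {sorted(inst)}, "
                             f"expected {sorted(names)}")
        for name in names:
            columns[name].append(inst[name])
    return columns


rows = [{"x1": 300.0, "x2": 220.0, "x3": "cat11"},
        {"x1": 400.0, "x2": 320.0, "x3": "cat11"}]
columnar_request = {"signature_name": "serving_default",
                    "inputs": rows_to_columns(rows)}
print(columnar_request["inputs"])
# {'x1': [300.0, 400.0], 'x2': [220.0, 320.0], 'x3': ['cat11', 'cat11']}
```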
# (4) Build a model with no preprocessing model; categorical dummifying is handled by the model
rf = tfdf.keras.RandomForestModel(
task=Task.REGRESSION,
num_trees=100,
max_depth=20,
min_examples=5,
num_candidate_attributes_ratio=0.45,
split_axis="AXIS_ALIGNED", #"SPARSE_OBLIQUE",
growing_strategy="LOCAL", #"BEST_FIRST_GLOBAL",
categorical_set_split_max_num_items=100,
categorical_set_split_min_item_frequency=1,
)
features = ['x1', 'x2', 'x3']
train_ds = df_to_dataset(train[features], labels=train['y'], batch_size=1000)
valid_ds = df_to_dataset(valid[features], labels=valid['y'], batch_size=1000)
test_ds = df_to_dataset(test[features], labels=test['y'], batch_size=1000)
rf.fit(train_ds)
print(r2_score(valid['y'], rf.predict(valid_ds)))
rf.save('/home/username/username/model_assessment/test_model_rf3/1/')
"""
Run in TF Serving
docker run -p 8501:8501 --mount type=bind,source=/home/username/username/model_assessment/test_model_rf3,target=/models/test_model_rf3 -e MODEL_NAME=test_model_rf3 -t tensorflow/serving
"""
# No brackets works
test_data = {"signature_name": "serving_default",
"instances": [{"x1": 300.0,
"x2": 220.0,
"x3": "cat11"
},
{"x1": 330.0,
"x2": 230.0,
"x3": "cat13"
}]
}
response = requests.post('http://localhost:8501/v1/models/test_model_rf3:predict', data=json.dumps(test_data), headers=headers)
print(response.text)
# Brackets on any feature (float or categorical) crash TF Serving
test_data = {"signature_name": "serving_default",
"instances": [{"x1": [300.0],
"x2": [220.0],
"x3": ["cat11"]
}]
}
response = requests.post('http://localhost:8501/v1/models/test_model_rf3:predict', data=json.dumps(test_data), headers=headers)
print(response.text)
"""
Causes
2023-03-21 10:39:01.342750: F external/org_tensorflow/tensorflow/core/framework/tensor_shape.cc:45] Check failed: NDIMS == dims() (2 vs. 3)Asking for tensor of 2 dimensions from a tensor of 3 dimensions
/usr/bin/tf_serving_entrypoint.sh: line 3: 7 Aborted (core dumped) tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} "$@"
"""
Related issue
tensorflow/tensorflow#9505
System details
Python 3.10
TensorFlow 2.11.1
TensorFlow Decision Forests 1.2.0
tensorflow/serving Docker image, tag latest
Code to recreate
"""
Now run TF Serving
docker run -p 8501:8501 --mount type=bind,source=/home/username/model_assessment/test_model,target=/models/test_model -e MODEL_NAME=test_model -t tensorflow/serving
"""
"""
Run model in TFServing
docker run -p 8501:8501 --mount type=bind,source=/home/username/model_assessment/test_model_rf1,target=/models/test_model_rf1 -e MODEL_NAME=test_model_rf1 -t tensorflow/serving
"""
"""
Run in TF Serving
docker run -p 8501:8501 --mount type=bind,source=/home/username/model_assessment/test_model_rf2,target=/models/test_model_rf2 -e MODEL_NAME=test_model_rf2 -t tensorflow/serving
"""
"""
Run in TF Serving
docker run -p 8501:8501 --mount type=bind,source=/home/username/username/model_assessment/test_model_rf3,target=/models/test_model_rf3 -e MODEL_NAME=test_model_rf3 -t tensorflow/serving
"""
The text was updated successfully, but these errors were encountered: