Skip to content

Commit

Permalink
refact: formatando arquivos python
Browse files Browse the repository at this point in the history
  • Loading branch information
heltonricardo committed Dec 13, 2023
1 parent b1acb78 commit f614208
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 37 deletions.
24 changes: 14 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ def load_model():
"""
Loads a pre-trained model from an MLflow server.
This function connects to an MLflow server using the provided tracking URI, username,
and password.
It retrieves the latest version of the 'fetal_health' model registered on the server.
The function then loads the model using the specified run ID and returns the loaded model.
This function connects to an MLflow server using the provided tracking URI,
username, and password.
It retrieves the latest version of the 'fetal_health' model registered on
the server.
The function then loads the model using the specified run ID and returns
the loaded model.
Returns:
loaded_model: The loaded pre-trained model.
Expand All @@ -37,7 +39,8 @@ def load_model():
None
"""
print("reading model...")
MLFLOW_TRACKING_URI = "https://dagshub.com/renansantosmendes/mlops-ead.mlflow"
MLFLOW_TRACKING_URI = \
"https://dagshub.com/renansantosmendes/mlops-ead.mlflow"
MLFLOW_TRACKING_USERNAME = "renansantosmendes"
MLFLOW_TRACKING_PASSWORD = "b63baf8c662a23fa00deb74ba86600278769e5dd"
os.environ["MLFLOW_TRACKING_USERNAME"] = MLFLOW_TRACKING_USERNAME
Expand All @@ -59,8 +62,8 @@ def load_model():
@app.on_event(event_type="startup")
def startup_event():
"""
A function that is called when the application starts up. It loads a model into the
global variable `loaded_model`.
A function that is called when the application starts up. It loads a model
into the global variable `loaded_model`.
Parameters:
None
Expand All @@ -78,8 +81,8 @@ def api_health():
A function that represents the health endpoint of the API.
Returns:
dict: A dictionary containing the status of the API, with the key "status" and
the value "healthy".
dict: A dictionary containing the status of the API, with the key
"status" and the value "healthy".
"""
return {"status": "healthy"}

Expand All @@ -90,7 +93,8 @@ def predict(request: FetalHealthData):
Predicts the fetal health based on the given request data.
Args:
request (FetalHealthData): The request data containing the fetal health parameters.
request (FetalHealthData): The request data containing the fetal
health parameters.
Returns:
dict: A dictionary containing the prediction of the fetal health.
Expand Down
20 changes: 10 additions & 10 deletions test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def sample_data():
A fixture function that returns a sample dataset.
Returns:
pandas.DataFrame: A DataFrame containing sample data with three columns: 'feature1',
'feature2', and 'fetal_health'.
pandas.DataFrame: A DataFrame containing sample data with three
columns: 'feature1', 'feature2', and 'fetal_health'.
"""
data = pd.DataFrame(
{
Expand All @@ -26,8 +26,8 @@ def sample_data():

def test_read_data():
"""
This function tests the `read_data` function. It checks whether the returned data is not
empty for both features (X) and labels (y).
This function tests the `read_data` function. It checks whether the
returned data is not empty for both features (X) and labels (y).
Parameters:
None
Expand All @@ -43,8 +43,8 @@ def test_read_data():

def test_create_model():
"""
Generate the function comment for the given function body in a markdown code block with
the correct language syntax.
Generate the function comment for the given function body in a markdown
code block with the correct language syntax.
"""
X, _ = read_data()
model = create_model(X)
Expand All @@ -56,12 +56,12 @@ def test_create_model():

def test_train_model(sample_data):
"""
Generate a function comment for the given function body in a markdown code block with
the correct language syntax.
Generate a function comment for the given function body in a markdown code
block with the correct language syntax.
Parameters:
sample_data (pandas.DataFrame): The input data containing features and target
variable.
sample_data (pandas.DataFrame): The input data containing features and
target variable.
Returns:
None
Expand Down
45 changes: 28 additions & 17 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def reset_seeds():
"""
Reset the seeds for random number generators.
This function sets the seeds for the `os`, `tf.random`, `np.random`, and `random`
modules to ensure reproducibility in random number generations.
This function sets the seeds for the `os`, `tf.random`, `np.random`, and
`random` modules to ensure reproducibility in random number generations.
Parameters:
None
Expand All @@ -32,14 +32,16 @@ def reset_seeds():

def read_data():
"""
Reads the data from a CSV file and returns the feature matrix X and target vector y.
Reads the data from a CSV file and returns the feature matrix X and target
vector y.
Returns:
X (pandas.DataFrame): The feature matrix of shape (n_samples, n_features).
X (pandas.DataFrame): The feature matrix of shape (n_samp, n_feat).
y (pandas.Series): The target vector of shape (n_samples,).
"""
data = pd.read_csv(
"https://raw.githubusercontent.com/heltonricardo/fetal-health-classifier/main/fetal_health_reduced.csv"
data = pd.read_csv("""
https://raw.githubusercontent.com/heltonricardo/
fetal-health-classifier/main/fetal_health_reduced.csv"""
)
X = data.drop(["fetal_health"], axis=1)
y = data["fetal_health"]
Expand Down Expand Up @@ -77,11 +79,12 @@ def process_data(X, y):

def create_model(X):
"""
Creates a neural network model for classification based on the given input data.
Creates a neural network model for classification based on the given input
data.
Parameters:
X (numpy.ndarray): The input data array. It should have a shape of (num_samples,
num_features).
X (numpy.ndarray): The input data array. It should have a shape of
(num_samples, num_features).
Returns:
tensorflow.keras.models.Sequential: The created neural network model.
Expand All @@ -94,7 +97,9 @@ def create_model(X):
model.add(Dense(3, activation="softmax"))

model.compile(
loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"]
)
return model

Expand All @@ -104,14 +109,17 @@ def config_mlflow():
Configures the MLflow settings for tracking experiments.
Sets the MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD environment
variables to provide authentication for accessing the MLflow tracking server.
variables to provide authentication for accessing the MLflow tracking
server.
Sets the MLflow tracking URI to 'https://dagshub.com/renansantosmendes/mlops-ead.mlflow'
Sets the MLflow tracking URI to
'https://dagshub.com/renansantosmendes/mlops-ead.mlflow'
to specify the location where the experiment data will be logged.
Enables autologging of TensorFlow models by calling `mlflow.tensorflow.autolog()`.
This will automatically log the TensorFlow models, input examples, and model signatures
during training.
Enables autologging of TensorFlow models by calling
`mlflow.tensorflow.autolog()`.
This will automatically log the TensorFlow models, input examples, and
model signatures during training.
Parameters:
None
Expand All @@ -120,8 +128,11 @@ def config_mlflow():
None
"""
os.environ["MLFLOW_TRACKING_USERNAME"] = "renansantosmendes"
os.environ["MLFLOW_TRACKING_PASSWORD"] = "6d730ef4a90b1caf28fbb01e5748f0874fda6077"
mlflow.set_tracking_uri("https://dagshub.com/renansantosmendes/mlops-ead.mlflow")
os.environ["MLFLOW_TRACKING_PASSWORD"] = \
"6d730ef4a90b1caf28fbb01e5748f0874fda6077"
mlflow.set_tracking_uri(
"https://dagshub.com/renansantosmendes/mlops-ead.mlflow"
)

mlflow.tensorflow.autolog(
log_models=True, log_input_examples=True, log_model_signatures=True
Expand Down

0 comments on commit f614208

Please sign in to comment.