Similar to the PyTorch front-end, FlexFlow also supports training existing ONNX models. Since both ONNX and FlexFlow use Protocol Buffers, make sure both are linked against the same Protocol Buffers version.
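As a quick first check, the versions of the relevant Python packages can be printed before building or running anything. This is a minimal sketch: `installed_versions` is a hypothetical helper, not part of FlexFlow or ONNX, and the distribution names `protobuf` and `onnx` are assumed to be the ones installed from PyPI. It reports Python packages only. The Protocol Buffers library that a source build of FlexFlow links against has to be checked separately, for example with `protoc --version`.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Assumed PyPI distribution names; adjust them for your environment.
for name, version in installed_versions(["protobuf", "onnx"]).items():
    print(f"{name}: {version or 'not installed'}")
```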
A PyTorch model can be exported to the ONNX model format and saved into an external file:
    import onnx
    import torch
    import torch.nn as nn
    from torch.onnx import TrainingMode

    # create a PyTorch model
    class MyPyTorchModule(nn.Module):
        ...

    # export the PyTorch model to an ONNX model
    # (`input` is an example input tensor for the model)
    model = MyPyTorchModule()
    torch.onnx.export(model, (input,), "filename", export_params=False, training=TrainingMode.TRAINING)
A FlexFlow program can directly import a previously saved ONNX model and autotune the parallelization performance for a given parallel machine:
    from flexflow.core import *
    from flexflow.onnx.model import ONNXModel

    ffconfig = FFConfig()
    ffmodel = FFModel(ffconfig)

    # create input tensors
    dims_input = [ffconfig.get_batch_size(), 3, 32, 32]
    input_tensor = ffmodel.create_tensor(dims_input, DataType.DT_FLOAT)

    # create a FlexFlow model from the file
    onnx_model = ONNXModel("cifar10_cnn.onnx")
    output_tensor = onnx_model.apply(ffmodel, {"input.1": input_tensor})

    # use the Python API to train the model
    ffoptimizer = SGDOptimizer(ffmodel, 0.01)
    ffmodel.set_sgd_optimizer(ffoptimizer)
    ffmodel.compile(loss_type=LossType.LOSS_SPARSE_CATEGORICAL_CROSSENTROPY,
                    metrics=[MetricsType.METRICS_ACCURACY,
                             MetricsType.METRICS_SPARSE_CATEGORICAL_CROSSENTROPY])
    ...
    ffmodel.fit(x=dataloader_input, y=dataloader_label, epochs=epochs)
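The key `"input.1"` passed to `apply` appears to be the name of the input in the ONNX graph, so it helps to confirm the names stored in the exported file. The sketch below assumes that `apply` looks tensors up by ONNX graph input name. `graph_input_names` and `load_input_names` are hypothetical helpers, not part of FlexFlow or ONNX. They rely only on `onnx.load`, `onnx.checker.check_model`, and the `graph.input` and `graph.initializer` fields of the loaded model.

```python
def graph_input_names(model):
    """Return graph input names that are not also stored as initializers (weights)."""
    initializer_names = {init.name for init in model.graph.initializer}
    return [inp.name for inp in model.graph.input if inp.name not in initializer_names]


def load_input_names(path):
    """Load an ONNX file, validate it, and return its input names."""
    import onnx  # imported lazily so the helper above has no hard dependency

    model = onnx.load(path)
    onnx.checker.check_model(model)
    return graph_input_names(model)


# Example (requires the onnx package and the exported file):
# print(load_input_names("cifar10_cnn.onnx"))
```

When a model is exported with `export_params=False`, its parameters are not stored as initializers and may also show up as graph inputs, so the returned list can contain more names than the data input alone.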
More FlexFlow ONNX examples are available on GitHub.