This repository provides a Flax implementation of factorization machine models and common datasets in CTR prediction. The code on this repository was converted from a pytorch implementation of factorization machine models to a flax implementation code.
-
no_batch_norm_dropout_exists_training.ipynb
- batch_normalization 을 하지 않고, dropout을 해야하는 모델에 대해서 training할 때 사용하는 주피터 노트북입니다.
- 모델
- AttentionalFactorizationMachineModelFlax
-
batch_norm_dropout_exist_training.ipynb
- batch_normalization과 dropout을 포함한 모델에 대해서 training할 때 사용하는 주피터 노트북입니다.
- 모델
- WideAndDeepModelFlax
- FactorizationSupportedNeuralNetworkModelFlax
- NeuralFactorizationMachineModelFlax
- NeuralCollaborativeFilteringFlax
- FieldAwareNeuralFactorizationMachineModelFlax
- DeepFactorizationMachineModelFlax
- ExtremeDeepFactorizationMachineModelFlax
-
no_batch_norm_dropout_training.ipynb
- batch_normaliation과 dropout이 포함되지 않은 모델에 대해서 training할 때 사용하는 주피터 노트북입니다.
- 모델
- LogisticRegressionModelFlax
- FactorizationMachineModelFlax
- FieldAwareFactorizationMachineModelFlax
-
compare_pytorch_flax_train_speed.ipynb
- 동일한 데이터셋(MovieLens20MDataset)에 대해 pytorch로 구현한 FactorizationMachineModel과 Flax로 구현한 FactorizationMachineModel에 대해 각각 모델 트레이닝을 하고 트레이닝 속도 및 loss function 값의 수렴도를 비교한 주피터 노트북 파일입니다.
-
compare_pytorch_flax_model_architecture.ipynb
- 동일한 데이터셋(MovieLens20MDataset)에 대해 pytorch로 구현한 FactorizationMachineModel과 Flax로 구현한 FactorizationMachineModel를 각각 onnx 파일과 tflite 파일로 export하고, netron 라이브러리를 이용하여 모델 구조를 시각화하여 비교해봅니다.
- scalene(a high-performance CPU, GPU and memory profiler for Python)를 사용하여 모델에 대해 profiling한 결과들을 html파일 형태로 업로드하였습니다.
- htmlviewer을 통해 profile_results로 나온 결과를 볼 수 있습니다.
- 모델별 특징과 상관없이 통일된 training 코드 작성
- Inference 코드 작성
- html 파일을 더 쉽게 볼 수 있는 법에 대해 고민해보기
- Dockerfile 설치 과정 효율화
- 현재는 설치해야하는 용량이 매우 크고 오래 걸리는 편
https://github.com/rixwew/pytorch-fm
MIT