This repo is a modified version of the gflownet implementation. It contains GFlowNet-related training and environment code on graphs for investigating the generalization capabilities of GFlowNets.
It is hypothesized that GFlowNets leverage the generalization potential of deep neural networks to assign probability mass to unvisited areas of the state space. This repo contains a graph generation benchmark environment (with several rewards of varying difficulty) whose state space is small enough to enumerate exhaustively, making it possible to study how well trained GFlowNets generalize to unvisited states.
GFlowNet, short for Generative Flow Network (sometimes also abbreviated as GFN), is a novel generative modelling framework for learning unnormalized probability mass functions over discrete spaces, particularly suited for discrete/combinatorial objects. Here, the focus is on graph generation.
The idea behind GFlowNets is to estimate flows in a (graph-theoretic) directed acyclic network. The network represents all possible ways of constructing an object, so knowing the flow gives us a policy which we can follow to sequentially construct objects. Such a sequence of partially constructed objects is a trajectory. Perhaps confusingly, the network in a GFlowNet refers to the state space, not a neural network architecture. Here the objects we construct are themselves graphs, built node by node. To make policy predictions, we use a graph neural network to parameterize the forward policy.
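To make this concrete, the trajectory balance objective (one of the algorithms listed below) asks that, along a sampled trajectory, the partition function times the forward-policy probabilities match the terminal reward times the backward-policy probabilities. The snippet below is a generic PyTorch sketch of that loss, not this repo's implementation; all tensor names and shapes are assumptions.

```python
import torch

def trajectory_balance_loss(log_z, log_pf, log_pb, log_reward):
    """Generic trajectory balance loss over a batch of trajectories (a sketch, not the repo's API).

    log_z:      scalar tensor, log of the learned partition function Z
    log_pf:     (batch, max_len) per-step log forward-policy probabilities (0 at padding steps)
    log_pb:     (batch, max_len) per-step log backward-policy probabilities (0 at padding steps)
    log_reward: (batch,) log reward of each terminal graph
    """
    forward = log_z + log_pf.sum(dim=1)        # log Z + sum_t log P_F(s_{t+1} | s_t)
    backward = log_reward + log_pb.sum(dim=1)  # log R(x) + sum_t log P_B(s_t | s_{t+1})
    return ((forward - backward) ** 2).mean()
```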
If you find this code useful in your research, please cite the following paper:
L. Atanackovic, E. Bengio. Investigating Generalization Behaviours of Generative Flow Networks, 2024.
@article{atanackovic2024,
title={Investigating Generalization Behaviours of Generative Flow Networks},
author={Atanackovic, Lazar and Bengio, Emmanuel},
journal={arXiv preprint arXiv:2402.05309},
year={2024}
}
Structure of repo:
- algo, contains GFlowNet algorithm implementations (Trajectory Balance, SubTB, Flow Matching), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories.
- data, contains dataset definitions, data loading and data sampling utilities.
- envs, contains environment classes; a graph-building environment base, and a molecular graph context class. The base environment is agnostic to what kind of graph is being made, and the context class specifies mappings from graphs to objects (e.g. molecules) and torch geometric Data.
- examples, contains simple example implementations of GFlowNet.
- models, contains model definitions.
- tasks, contains training code; basic_graph_task.py implements the graph generation environment for the counting, neighbors, and cliques tasks.
- utils, contains utilities (multiprocessing, metrics, conditioning).
- trainer.py, defines a general harness for training GFlowNet models.
- online_trainer.py, defines a typical online-GFN training loop (a schematic sketch follows this list).
See the implementation notes for more.
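As a rough picture of how these pieces fit together, a typical online-GFN loop alternates between rolling out trajectories with the current forward policy and minimizing the chosen objective (e.g. trajectory balance) on them. The sketch below is schematic; sample_trajectories and compute_loss are placeholder callables, not this repo's actual API.

```python
import torch

def online_gfn_loop(model, sample_trajectories, compute_loss, num_steps=1000, lr=1e-4):
    """Schematic online-GFN training loop (placeholder callables, not the repo's API).

    sample_trajectories(model) -> batch of trajectories rolled out with the current policy
    compute_loss(model, batch) -> scalar loss (e.g. trajectory balance) on that batch
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(num_steps):
        batch = sample_trajectories(model)  # on-policy rollouts in the graph-building environment
        loss = compute_loss(model, batch)   # reward evaluation happens inside the loss computation
        opt.zero_grad()
        loss.backward()
        opt.step()
```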
First, generate a cache of all states/graphs with up to 7 nodes. To do this, run:
python basic_graph_task.py --recompute-all ./data/basic_graphs 7
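For a rough sense of the size of this state space (purely illustrative; the cache produced above may use a different state definition and therefore a different count), the unlabeled simple graphs on at most 7 nodes can be enumerated with networkx's graph atlas:

```python
from networkx.generators.atlas import graph_atlas_g

# The Graph Atlas enumerates every simple graph on up to 7 nodes.
graphs = graph_atlas_g()
print(len(graphs))                               # 1253 graphs
print(max(g.number_of_nodes() for g in graphs))  # 7
```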
To train a single model on the graph benchmark tasks, run:
python expts/task_single_run_gfn.py
To run an experiment, e.g. training the distilled flow models, run task_distilled_flows.py.
This package is installable as a pip package, but since it depends on some torch-geometric package wheels, the --find-links argument must be specified as well:
pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cu117.html
Or for CPU use:
pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html
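After installing, a quick sanity check (just an illustration, not part of the repo) is to confirm that torch and torch_geometric import correctly and that the expected compute backend is visible:

```python
import torch
import torch_geometric

print("torch:", torch.__version__)
print("torch_geometric:", torch_geometric.__version__)
print("CUDA available:", torch.cuda.is_available())
```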
To install or depend on a specific tag, for example v0.0.10, use the following scheme:
pip install git+https://github.com/recursionpharma/gflownet@v0.0.10 --find-links ...
If package dependencies seem not to work, you may need to install the exact frozen versions listed in requirements/, i.e. pip install -r requirements/main_3.9.txt.