Merge branch 'main' of github.com:salesforce/CodeTF
bdqnghi committed Jun 6, 2023
2 parents 2b352e1 + c794017 commit 096af34
Showing 1 changed file with 6 additions and 6 deletions.
README.md (6 additions, 6 deletions)
@@ -10,10 +10,10 @@
<img alt="license" src="https://img.shields.io/badge/License-Apache%202.0-green.svg"/>
</a>
<a href="https://www.python.org/downloads/release/python-380/">
<img alt="license" src="https://img.shields.io/badge/python-3.8+-yellow.svg"/>
<img alt="python" src="https://img.shields.io/badge/python-3.8+-yellow.svg"/>
</a>
<a href="https://pypi.org/project/salesforce-codetf/">
<img alt="license" src="https://static.pepy.tech/badge/salesforce-codetf"/>
<img alt="downloads" src="https://static.pepy.tech/badge/salesforce-codetf"/>
</a>

<a href="https://arxiv.org/pdf/2306.00029.pdf">Technical Report</a>,
@@ -87,7 +87,7 @@ conda activate codetf

2. Install from [PyPI](https://pypi.org/project/salesforce-codetf/):
```bash
- pip install salesforce-codetf==1.0.1
+ pip install salesforce-codetf==1.0.1.1
```

3. Alternatively, build CodeTF from source:
@@ -120,16 +120,16 @@ from codetf.models import load_model_pipeline

code_generation_model = load_model_pipeline(model_name="codet5", task="pretrained",
model_type="plus-220M", is_eval=True,
- load_in_8bit=True, weight_sharding=False)
+ load_in_8bit=True, load_in_4bit=False, weight_sharding=False)

result = code_generation_model.predict(["def print_hello_world():"])
print(result)
```
There are a few notable arguments that need to be considered (a usage sketch follows this list):
- ``model_name``: the name of the model; ``codet5`` and ``causal-lm`` are currently supported.
- ``model_type``: type of model for each model name, e.g. ``base``, ``codegen-350M-mono``, ``j-6B``, etc.
- - ``load_in_8bit``: inherit the ``load_in_8bit" feature from [Huggingface Quantization](https://huggingface.co/docs/transformers/main/main_classes/quantization).
- - ``weight_sharding``: our advance feature that leverate [HuggingFace Sharded Checkpoint](https://huggingface.co/docs/accelerate/v0.19.0/en/package_reference/big_modeling#accelerate.load_checkpoint_and_dispatch) to split a large model in several smaller shards in different GPUs. Please consider using this if you are dealing with large models.
+ - ``load_in_8bit`` and ``load_in_4bit``: inherit the dynamic quantization feature from [Huggingface Quantization](https://huggingface.co/docs/transformers/main/main_classes/quantization).
+ - ``weight_sharding``: our advanced feature that leverages [HuggingFace Sharded Checkpoint](https://huggingface.co/docs/accelerate/v0.19.0/en/package_reference/big_modeling#accelerate.load_checkpoint_and_dispatch) to split a large model into several smaller shards across different GPUs. Please consider using this if you are dealing with large models.
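
As a minimal sketch combining these arguments, the snippet below loads a causal LM with 4-bit quantization and weight sharding enabled. The ``codegen-350M-mono`` checkpoint is taken from the ``model_type`` examples above; the exact flag combination is an illustrative assumption, not a tested configuration.
```python
from codetf.models import load_model_pipeline

# Hypothetical configuration: a causal LM with 4-bit quantization and
# weight sharding across the available GPUs. "codegen-350M-mono" is one
# of the model_type examples listed above; substitute any supported type.
model = load_model_pipeline(model_name="causal-lm", task="pretrained",
                            model_type="codegen-350M-mono", is_eval=True,
                            load_in_8bit=False, load_in_4bit=True,
                            weight_sharding=True)

result = model.predict(["def fibonacci(n):"])
print(result)
```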

### Model Zoo
To view all of the supported models, use ``model_zoo()``:
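A minimal sketch, assuming ``model_zoo`` is exposed from ``codetf.models`` alongside ``load_model_pipeline``:
```python
from codetf.models import model_zoo

# Prints the table of supported architectures, model types, and tasks.
print(model_zoo)
```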
