added text2sql demo code #32

Status: Open. Wants to merge 23 commits into base `steven-pr`.

**Commits**

- `132801a` added text2sql demo code (SichengStevenLi, Jul 5, 2024)
- `358d72f` modified README file (SichengStevenLi, Jul 5, 2024)
- `c955727` Update README.md (Jianhui-Li, Jul 5, 2024)
- `6849e5b` Merge pull request #1 from Jianhui-Li/Jianhui-Li-patch-1 (Jianhui-Li, Jul 5, 2024)
- `654ce29` Merge pull request #1 from Jianhui-Li/main (SichengStevenLi, Jul 5, 2024)
- `3a47b97` Revert "update reademe" (SichengStevenLi, Jul 5, 2024)
- `d111f3f` Merge pull request #2 from SichengStevenLi/revert-1-main (SichengStevenLi, Jul 5, 2024)
- `61f543b` modified README file (SichengStevenLi, Jul 5, 2024)
- `cb3ef94` Merge branch 'main' of https://github.com/SichengStevenLi/llama_index… (SichengStevenLi, Jul 10, 2024)
- `3340c1f` modified README and sql file (SichengStevenLi, Jul 10, 2024)
- `5ad42c7` modified README (SichengStevenLi, Jul 10, 2024)
- `94d7c9d` added device argument and modified README accordingly (SichengStevenLi, Jul 10, 2024)
- `04a1ae5` modified wording in README (SichengStevenLi, Jul 10, 2024)
- `2c18d0b` modified README (SichengStevenLi, Jul 10, 2024)
- `ee931d4` fixed typo in text2sql.py (SichengStevenLi, Jul 15, 2024)
- `e125d7c` fixed syntax issue (SichengStevenLi, Jul 30, 2024)
- `fa11ff3` Update README.md (SichengStevenLi, Jul 30, 2024)
- `bb47ef5` Update README.md after linting (SichengStevenLi, Aug 6, 2024)
- `327d3c4` Update text2sql.py according to linter (SichengStevenLi, Aug 6, 2024)
- `1098331` Update README.md after linting (SichengStevenLi, Aug 15, 2024)
- `b081d16` Update text2sql.py according to linter (SichengStevenLi, Aug 15, 2024)
- `ab8bce9` Update README.md after linting (SichengStevenLi, Aug 15, 2024)
- `9b48958` Update text2sql.py according to linter (SichengStevenLi, Aug 15, 2024)
`README.md`, hunk `@@ -62,3 +62,55 @@` (context: `python more_data_type.py -m <path_to_model> -t <path_to_tokenizer> -l <low_bit_f`)

> Note: If you're using [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) model in this example, it is recommended to use transformers version
> <=4.34.

### Text2SQL Example

This example, [text2sql](./text2sql.py), demonstrates how to use LlamaIndex with `ipex-llm` to run a text-to-SQL model on Intel hardware. It shows how to create a database, define a schema, and run SQL queries using a low-bit model optimized with `ipex-llm`.

#### Setup

This example requires the `llama-index-embeddings-ipex-llm` package, as it uses `ipex-llm` embeddings:

> ```bash
> pip install llama-index-embeddings-ipex-llm
> ```

#### Runtime Configurations

For optimal performance, it is recommended to set several environment variables based on your device:

- For Windows users with an Intel Core Ultra integrated GPU, in Anaconda Prompt:

  > ```cmd
  > set SYCL_CACHE_PERSISTENT=1
  > set BIGDL_LLM_XMX_DISABLED=1
  > ```

- For Linux users with an Intel Arc A-Series GPU:

  > ```bash
  > # Configure oneAPI environment variables. Required step for APT or offline installed oneAPI.
  > # Skip this step for PIP-installed oneAPI since the environment has already been configured in LD_LIBRARY_PATH.
  > source /opt/intel/oneapi/setvars.sh
  >
  > # Recommended environment variables for optimal performance
  > export USE_XETLA=OFF
  > export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
  > export SYCL_CACHE_PERSISTENT=1
  > ```

> **NOTE:** The first time each model runs on an Intel iGPU, Intel Arc A300-Series GPU, or Pro A60, compilation may take several minutes.

---

#### Run the Example

Then run the example as follows:

```bash
python text2sql.py -m <path_to_model> -d <device> -e <path_to_embedding_model> -q <query_to_LLM> -n <num_predict>
```

> Please note that this example uses the [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) model for demonstration, together with [bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5) as the embedding model; these require up-to-date `transformers` and `tokenizers` packages. You are also welcome to use other models.

> If you use other LLMs and encounter output issues, please try adjusting the query or the generation settings.
`text2sql.py` (new file), hunk `@@ -0,0 +1,119 @@`
import torch
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import NLSQLRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.llms.ipex_llm import IpexLLM
from llama_index.embeddings.ipex_llm import IpexLLMEmbedding
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, insert
import argparse


def create_database_schema():
    # In-memory SQLite database for the demo
    engine = create_engine("sqlite:///:memory:")
    metadata_obj = MetaData()

    # create city SQL table
    table_name = "city_stats"
    city_stats_table = Table(
        table_name,
        metadata_obj,
        Column("city_name", String(16), primary_key=True),
        Column("population", Integer),
        Column("country", String(16), nullable=False),
    )
    metadata_obj.create_all(engine)
    return engine, city_stats_table
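The SQLAlchemy table definition above corresponds to a plain `CREATE TABLE` statement. As a rough illustration (a sketch using Python's built-in `sqlite3` instead of SQLAlchemy, with names taken from the example), the same schema can be created and checked against SQLite's catalog:

```python
import sqlite3

# Equivalent schema to the SQLAlchemy Table in create_database_schema()
conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE city_stats (
        city_name VARCHAR(16) PRIMARY KEY,
        population INTEGER,
        country VARCHAR(16) NOT NULL
    )
    """
)

# Confirm the table is registered in SQLite's schema catalog
row = conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' AND name='city_stats'"
).fetchone()
print(row[0])  # city_stats
```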


def define_sql_database(engine, city_stats_table):
    sql_database = SQLDatabase(engine, include_tables=["city_stats"])

    # Populate the table with a few sample rows
    rows = [
        {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
        {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
        {
            "city_name": "Chicago",
            "population": 2679000,
            "country": "United States",
        },
        {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
    ]
    for row in rows:
        stmt = insert(city_stats_table).values(**row)
        with engine.begin() as connection:
            connection.execute(stmt)

    return sql_database
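With these four rows in place, the kind of SQL a text-to-SQL model is expected to generate for the default question ("Which city has the highest population?") can be checked directly. The following is a sketch using the stdlib `sqlite3` module; the SQL actually generated by the LLM may differ in form:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name TEXT PRIMARY KEY, population INTEGER, country TEXT NOT NULL)"
)
# Same sample data as define_sql_database()
rows = [
    ("Toronto", 2930000, "Canada"),
    ("Tokyo", 13960000, "Japan"),
    ("Chicago", 2679000, "United States"),
    ("Seoul", 9776000, "South Korea"),
]
conn.executemany("INSERT INTO city_stats VALUES (?, ?, ?)", rows)

# One plausible translation of "Which city has the highest population?"
top = conn.execute(
    "SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1"
).fetchone()
print(top[0])  # Tokyo
```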


def main(args):
    engine, city_stats_table = create_database_schema()

    sql_database = define_sql_database(engine, city_stats_table)

    model_id = args.embedding_model_path
    device_map = args.device

    embed_model = IpexLLMEmbedding(model_id, device=device_map)

    llm = IpexLLM.from_model_id(
        model_name=args.model_path,
        tokenizer_name=args.model_path,
        context_window=512,
        max_new_tokens=args.n_predict,
        generate_kwargs={"temperature": 0.7, "do_sample": False},
        model_kwargs={},
        device_map=device_map,
    )

    # default retrieval (return_raw=True)
    nl_sql_retriever = NLSQLRetriever(
        sql_database,
        tables=["city_stats"],
        llm=llm,
        embed_model=embed_model,
        return_raw=True,
    )

    query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, llm=llm)
    query_str = args.question
    response = query_engine.query(query_str)
    print(str(response))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LlamaIndex IpexLLM Example")
    parser.add_argument(
        "-m",
        "--model-path",
        type=str,
        required=True,
        help="the path to transformers model",
    )
    parser.add_argument(
        "--device",
        "-d",
        type=str,
        default="cpu",
        choices=["cpu", "xpu"],
        help="The device (Intel CPU or Intel GPU) the LLM model runs on",
    )
    parser.add_argument(
        "-q",
        "--question",
        type=str,
        default="Which city has the highest population?",
        help="question you want to ask.",
    )
    parser.add_argument(
        "-e",
        "--embedding-model-path",
        default="BAAI/bge-small-en",
        help="the path to embedding model path",
    )
    parser.add_argument(
        "-n", "--n-predict", type=int, default=32, help="max number of predict tokens"
    )
    args = parser.parse_args()

    main(args)