diff --git a/README.md b/README.md
index 99b6f15..eb31c71 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,31 @@ You can configure the demo by specifying the following parameters:
 - `--num_proc`: Number of processes to run in parallel for faster execution.
 - `--multi_turn`: Boolean to toggle multi-turn interaction capability.
 
+## OpenAI-Compatible API Endpoint
+
+We provide an OpenAI-compatible API endpoint that lets you interact with the Mixture-of-Agents (MoA) system through the familiar OpenAI request/response format, so MoA can be dropped into existing applications that already speak OpenAI-style APIs.
+
+To start the API server, use the `api.py` script. It supports the same parameters as `bot.py`, plus an additional `--port` parameter that sets the port the server listens on.
+
+### Usage
+
+To run the API server, use the following command (pass `--reference-models` once per reference model):
+
+```bash
+python api.py --model <aggregator_model> --reference-models <reference_model_1> --reference-models <reference_model_2> ... --temperature <temperature> --max-tokens <max_tokens> --rounds <rounds> --port <port>
+```
+
+For example:
+
+```bash
+python api.py --model "Qwen/Qwen2-72B-Instruct" --reference-models "Qwen/Qwen2-72B-Instruct" --reference-models "Qwen/Qwen1.5-72B-Chat" --reference-models "mistralai/Mixtral-8x22B-Instruct-v0.1" --reference-models "databricks/dbrx-instruct" --temperature 0.7 --max-tokens 512 --rounds 1 --port 5001
+```
+
+This starts an OpenAI-compatible API server on `http://localhost:5001`. You can then point your applications at this endpoint just as you would at the OpenAI API.
+
+Note: This endpoint is a work in progress. It does not cover every feature of the official OpenAI API, it does not stream tokens as they are generated (streamed responses are sent only after the full answer is ready), and it may be unstable.
+
+
 ## Evaluation
 
 We provide scripts to quickly reproduce some of the results presented in our paper
diff --git a/api.py b/api.py
new file mode 100644
index 0000000..f6b6bff
--- /dev/null
+++ b/api.py
@@ -0,0 +1,183 @@
+import typer
+from flask import Flask, request, jsonify, Response, stream_with_context
+import json
+from functools import partial
+import datasets
+from utils import generate_together_stream, generate_with_references, DEBUG
+from loguru import logger
+from datasets.utils.logging import disable_progress_bar
+
+disable_progress_bar()
+
+app = Flask(__name__)
+
+default_model = None
+default_reference_models = [
+    "Qwen/Qwen2-72B-Instruct",
+    "Qwen/Qwen1.5-72B-Chat",
+    "mistralai/Mixtral-8x22B-Instruct-v0.1",
+    "databricks/dbrx-instruct",
+]
+_temperature = 0.7
+_max_tokens = 512
+_rounds = 1
+
+def process_fn(item, temperature=0.7, max_tokens=2048):
+    references = item.get("references", [])
+    model = item["model"]
+    messages = item["instruction"]
+
+    output = generate_with_references(
+        model=model,
+        messages=messages,
+        references=references,
+        temperature=temperature,
+        max_tokens=max_tokens,
+    )
+    if DEBUG:
+        logger.info(
+            f"model: {model}, instruction: {item['instruction']}, output: {output[:20]}"
+        )
+
+    return {"output": output}
+
+@app.route('/v1/chat/completions', methods=['POST'])
+def chat_completions():
+    data = request.json
+    messages = data.get('messages', [])
+    stream = data.get('stream', False)  # Check if the client requested streaming
+    temperature = data.get('temperature', _temperature)
+    max_tokens = data.get('max_tokens', _max_tokens)
+
+    # Prepare data for processing
+    data = {
+        "instruction": [messages] * len(default_reference_models),
+        "references": [""] * len(default_reference_models),
+        "model": [m for m in default_reference_models],
+    }
+
+    eval_set = datasets.Dataset.from_dict(data)
+
+    # Process with reference models
+    eval_set = eval_set.map(
+        partial(
+            process_fn,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        ),
+        batched=False,
+        num_proc=len(default_reference_models),
+    )
+    references = [item["output"] for item in eval_set]
+
+    # Generate final output
+    output = generate_with_references(
+        model=default_model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        messages=messages,
+        references=references,
+        generate_fn=generate_together_stream,
+    )
+
+    # Collect output
+    all_output = ""
+    for chunk in output:
+        out = chunk.choices[0].delta.content
+        if out is not None:
+            # print(out)
+            all_output += out
+
+    # Prepare response
+    print(all_output)
+    response = {
+        "id": "chatcmpl-123",  # TODO
+        "object": "chat.completion",
+        "created": 1720384636,  # TODO
+        "model": default_model,
+        "usage": {
+            "prompt_tokens": 42,  # TODO
+            "completion_tokens": len(all_output.split()),  # Rough estimate
+            "total_tokens": 42 + len(all_output.split()),  # Rough estimate
+        },
+        "choices": [
+            {
+                "message": {
+                    "role": "assistant",
+                    "content": all_output,
+                },
+                "finish_reason": "stop",
+                "index": 0,
+            }
+        ],
+    }
+
+    if DEBUG:
+        print(json.dumps(response, indent=2))
+
+    def generate():
+        if stream:
+            # Simulate streaming by yielding chunks
+            # chunks = [all_output[i:i+5] for i in range(0, len(all_output), 5)]  # Split into 5-character chunks
+            chunks = [all_output]
+            for chunk in chunks:
+                chunk_response = {
+                    "id": "chatcmpl-123",
+                    "object": "chat.completion.chunk",
+                    "created": 1720384636,
+                    "model": default_model,
+                    "choices": [
+                        {
+                            "delta": {
+                                "content": chunk,
+                            },
+                            "index": 0,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+                yield f"data: {json.dumps(chunk_response)}\n\n"
+
+            # Send the final chunk with finish_reason
+            final_chunk = {
+                "id": "chatcmpl-123",
+                "object": "chat.completion.chunk",
+                "created": 1720384636,
+                "model": default_model,
+                "choices": [
+                    {
+                        "delta": {},
+                        "index": 0,
+                        "finish_reason": "stop",
+                    }
+                ],
+            }
+            yield f"data: {json.dumps(final_chunk)}\n\n"
+            yield "data: [DONE]\n\n"
+        else:
+            # Non-streaming response
+            yield json.dumps(response)
+
+    if stream:
+        return Response(stream_with_context(generate()), content_type='text/event-stream')
+    else:
+        return jsonify(response)
+
+def main(
+    model: str = "Qwen/Qwen2-72B-Instruct",
+    reference_models: list[str] = default_reference_models,
+    temperature: float = 0.7,
+    max_tokens: int = 512,
+    rounds: int = 1,
+    port: int = 5001,
+):
+    global default_model, default_reference_models, _temperature, _max_tokens, _rounds
+    default_model = model
+    default_reference_models = reference_models
+    _temperature = temperature
+    _max_tokens = max_tokens
+    _rounds = rounds
+    app.run(port=port)
+
+if __name__ == "__main__":
+    typer.run(main)
diff --git a/bot.py b/bot.py
index 89f8f19..fd91825 100644
--- a/bot.py
+++ b/bot.py
@@ -118,7 +118,7 @@ def main(
 
     model = Prompt.ask(
         "\n1. What main model do you want to use?",
-        default="Qwen/Qwen2-72B-Instruct",
+        default=model,
     )
     console.print(f"Selected {model}.", style="yellow italic")
     temperature = float(
@@ -199,8 +199,9 @@ def main(
 
     for chunk in output:
         out = chunk.choices[0].delta.content
-        console.print(out, end="")
-        all_output += out
+        if out is not None:
+            console.print(out, end="")
+            all_output += out
     print()
 
     if DEBUG:
diff --git a/requirements.txt b/requirements.txt
index bf9f390..ee5ca63 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,5 @@ loguru
 datasets
 typer
 rich
+cffi
+flask
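
The sketches below are not part of the patch; they are illustrative client calls against the endpoint added in `api.py`. First, a minimal non-streaming request using `requests` (not listed in `requirements.txt`, so treat it as an assumed extra dependency). The body fields `messages`, `temperature`, `max_tokens`, and `stream` are the ones the Flask route actually reads; the prompt text and timeout are arbitrary.

```python
# Minimal sketch of a non-streaming request against the local MoA API server.
# Assumes the server was started with `python api.py ... --port 5001` and that
# the `requests` package is available (it is not part of this repo's requirements).
import requests

resp = requests.post(
    "http://localhost:5001/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Give me three facts about octopuses."}],
        "temperature": 0.7,  # optional; falls back to the server-side default
        "max_tokens": 512,   # optional; falls back to the server-side default
    },
    timeout=600,  # fanning out to the reference models plus aggregation can take a while
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```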
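Because the response mirrors OpenAI's schema, the official `openai` Python client (v1.x, also installed separately) can in principle be pointed at the server by overriding `base_url`. A sketch under those assumptions; note that `api.py` ignores both the API key and the `model` field, and its "streaming" mode replays the finished answer as a single SSE chunk rather than token by token.

```python
# Sketch: using the official OpenAI Python client against the local MoA server.
# Placeholder api_key/model values are fine because api.py does not check them.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5001/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="moa",  # ignored by api.py; the aggregator model is fixed at server start
    messages=[{"role": "user", "content": "Summarize the Mixture-of-Agents idea in two sentences."}],
    stream=True,  # served as one chunk followed by a finish_reason chunk and [DONE]
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()
```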