Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mishig25 committed Oct 8, 2024
1 parent ac82746 commit 4cd1f83
Showing 1 changed file with 7 additions and 35 deletions.
42 changes: 7 additions & 35 deletions .github/workflows/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ jobs:
- name: Execute Python script
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }} # Make sure to set this secret in your repository settings
run: python -c '
run: |
python -c '
import os
import ast
import json
import requests
from huggingface_hub import HfApi


def extract_models_sub_dict(parsed_code, sub_dict_name):
class MODELS_SUB_LIST_VISITOR(ast.NodeVisitor):
def __init__(self):
Expand All @@ -50,15 +50,12 @@ def extract_models_sub_dict(parsed_code, sub_dict_name):
visitor.visit(parsed_code)
return visitor.value


def extract_models_dict(source_code):
parsed_code = ast.parse(source_code)

class MODELS_LIST_VISITOR(ast.NodeVisitor):
def __init__(self):
self.key = "_MODELS"
self.value = {}

def visit_Assign(self, node):
for target in node.targets:
if not isinstance(target, ast.Name):
Expand All @@ -67,42 +64,17 @@ def extract_models_dict(source_code):
for value in node.value.values:
dict = extract_models_sub_dict(parsed_code, value.id)
self.value.update(dict)

visitor = MODELS_LIST_VISITOR()
visitor.visit(parsed_code)
return visitor.value


# Fetch the content of the file
url = "https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/vllm/model_executor/models/registry.py"
response = requests.get(url)
response.raise_for_status() # Raise an exception for bad status codes
source_code = response.text

if __name__ == '__main__':
# extract models dict that consists of sub dicts
# _MODELS = {
# **_TEXT_GENERATION_MODELS,
# **_EMBEDDING_MODELS,
# **_MULTIMODAL_MODELS,
# **_SPECULATIVE_DECODING_MODELS,
# }
# _TEXT_GENERATION_MODELS = {
# "AquilaModel": ("llama", "LlamaForCausalLM"),
# "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
# "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
# ...
models_dict = extract_models_dict(source_code)
architectures = [item for tup in models_dict.values() for item in tup]
architectures_json_str = json.dumps(architectures, indent=4)
json_bytes = architectures_json_str.encode('utf-8')
print(architectures_json_str)

# api = HfApi(token=os.environ["HF_TOKEN"])
# api.upload_file(
# path_or_fileobj=json_bytes,
# path_in_repo="archtiectures.json",
# repo_id="mishig/test-vllm",
# repo_type="dataset",
# )'

models_dict = extract_models_dict(source_code)
architectures = [item for tup in models_dict.values() for item in tup]
architectures_json_str = json.dumps(architectures, indent=4)
json_bytes = architectures_json_str.encode("utf-8")
print(architectures_json_str)'

0 comments on commit 4cd1f83

Please sign in to comment.