diff --git a/Makefile b/Makefile index 9c549527..cc915cd3 100644 --- a/Makefile +++ b/Makefile @@ -25,13 +25,14 @@ install-tools: install-woke ## Install required utilities/tools # install all dependencies, including devel ones pdm install --dev # check that correct mypy version is installed - mypy --version + # mypy --version + pdm run mypy --version # check that correct Black version is installed - black --version + pdm run black --version # check that correct Ruff version is installed - ruff --version + pdm run ruff --version # check that correct Pydocstyle version is installed - pydocstyle --version + pdm run pydocstyle --version install-woke: ## Install woke, required for Inclusive Naming scan @@ -93,18 +94,18 @@ integration-tests-coverage-report: test-integration ## Export integration test c coverage html --data-file="${ARTIFACT_DIR}/.coverage.integration" -d htmlcov-integration check-types: ## Checks type hints in sources - mypy --explicit-package-bases --disallow-untyped-calls --disallow-untyped-defs --disallow-incomplete-defs ols/ + pdm run mypy --explicit-package-bases --disallow-untyped-calls --disallow-untyped-defs --disallow-incomplete-defs ols/ security-check: ## Check the project for security issues bandit -c pyproject.toml -r . format: ## Format the code into unified format - black . - ruff check . --fix --per-file-ignores=tests/*:S101 --per-file-ignores=scripts/*:S101 + pdm run black . + pdm run ruff check . --fix --per-file-ignores=tests/*:S101 --per-file-ignores=scripts/*:S101 verify: install-woke install-deps-test ## Verify the code using various linters - black . --check - ruff check . --per-file-ignores=tests/*:S101 --per-file-ignores=scripts/*:S101 + pdm run black . --check + pdm run ruff check . --per-file-ignores=tests/*:S101 --per-file-ignores=scripts/*:S101 ./woke . 
--exit-1-on-failure schema: ## Generate OpenAPI schema file diff --git a/README.md b/README.md index 32386d16..12a20af0 100644 --- a/README.md +++ b/README.md @@ -433,10 +433,17 @@ Depends on configuration, but usually it is not needed to generate or use API ke > This action may be required for self-hosted LLMs. -## 8. Registering a new LLM provider +## 8. (Optional) Configure the number of workers + By default the number of workers is set to 1. You can increase the number of workers to scale up the REST API by modifying the `max_workers` config option in `olsconfig.yaml`. + ```yaml + ols_config: + max_workers: 4 + ``` + +## 9. Registering a new LLM provider Please look [here](https://github.com/openshift/lightspeed-service/blob/main/CONTRIBUTING.md#adding-a-new-providermodel) for more info. -## 9. Fine tuning +## 10. Fine tuning The service uses the, so called, system prompt to put the question into context before the question is sent to the selected LLM. The default system prompt is fine tuned for questions about OpenShift and Kubernetes. It is possible to use a different system prompt via the configuration option `system_prompt_path` in the `ols_config` section. That option must contain the path to the text file with the actual system prompt (can contain multiple lines). 
An example of such configuration: ```yaml diff --git a/docs/config.puml b/docs/config.puml index 26ce52b4..ba60ef01 100644 --- a/docs/config.puml +++ b/docs/config.puml @@ -70,6 +70,7 @@ class "ModelParameters" as ols.app.models.config.ModelParameters { class "OLSConfig" as ols.app.models.config.OLSConfig { authentication_config conversation_cache : Optional[ConversationCacheConfig] + max_workers: Optional[int] default_model : Optional[str] default_provider : Optional[str] extra_ca : list[FilePath] diff --git a/examples/olsconfig-local-ollama.yaml b/examples/olsconfig-local-ollama.yaml index 0afae620..8f9104e2 100644 --- a/examples/olsconfig-local-ollama.yaml +++ b/examples/olsconfig-local-ollama.yaml @@ -17,6 +17,7 @@ llm_providers: models: - name: 'llama3.1:latest' ols_config: + # max_workers: 1 reference_content: # product_docs_index_path: "./vector_db/ocp_product_docs/4.15" # product_docs_index_id: ocp-product-docs-4_15 diff --git a/examples/olsconfig.yaml b/examples/olsconfig.yaml index cfac557a..cfcab87e 100644 --- a/examples/olsconfig.yaml +++ b/examples/olsconfig.yaml @@ -8,6 +8,12 @@ llm_providers: context_window_size: 8000 parameters: max_tokens_for_response: 500 + tlsSecurityProfile: + type: Custom + ciphers: + - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + minTLSVersion: VersionTLS13 - name: my_openai type: openai url: "https://api.openai.com/v1" @@ -19,6 +25,7 @@ llm_providers: type: azure_openai url: "https://myendpoint.openai.azure.com/" credentials_path: azure_openai_api_key.txt + api_version: "2024-02-15-preview" deployment_name: my_azure_openai_deployment_name models: - name: gpt-3.5-turbo @@ -48,6 +55,7 @@ llm_providers: models: - name: merlinite-7b-lab-Q4_K_M ols_config: + # max_workers: 1 reference_content: # product_docs_index_path: "./vector_db/ocp_product_docs/4.15" # product_docs_index_id: ocp-product-docs-4_15 @@ -83,6 +91,12 @@ ols_config: tls_certificate_path: /app-root/certs/certificate.crt 
tls_key_path: /app-root/certs/private.key tls_key_password_path: /app-root/certs/password.txt + tlsSecurityProfile: + type: Custom + ciphers: + - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + minTLSVersion: VersionTLS13 dev_config: # config options specific to dev environment - launching OLS in local enable_dev_ui: true diff --git a/examples/openshift-lightspeed-tls.yaml b/examples/openshift-lightspeed-tls.yaml index 2b5016b6..764e31bc 100644 --- a/examples/openshift-lightspeed-tls.yaml +++ b/examples/openshift-lightspeed-tls.yaml @@ -209,6 +209,7 @@ data: - name: gpt-4-1106-preview - name: gpt-3.5-turbo ols_config: + max_workers: 1 reference_content: product_docs_index_path: "./vector_db/ocp_product_docs/4.15" product_docs_index_id: ocp-product-docs-4_15 diff --git a/ols/app/models/config.py b/ols/app/models/config.py index 5c0f6e55..c774a79f 100644 --- a/ols/app/models/config.py +++ b/ols/app/models/config.py @@ -218,6 +218,51 @@ class AuthenticationConfig(BaseModel): k8s_ca_cert_path: Optional[FilePath] = None +class TLSSecurityProfile(BaseModel): + """TLS security profile structure.""" + + profile_type: Optional[str] = None + min_tls_version: Optional[str] = None + ciphers: Optional[list[str]] = None + + def __init__(self, data: Optional[dict] = None) -> None: + """Initialize configuration and perform basic validation.""" + super().__init__() + if data is not None: + self.profile_type = data.get("type") + self.min_tls_version = data.get("minTLSVersion") + self.ciphers = data.get("ciphers") + + def validate_yaml(self) -> None: + """Validate structure content.""" + # check the TLS profile type + if self.profile_type is not None: + try: + tls.TLSProfiles(self.profile_type) + except ValueError: + raise InvalidConfigurationError( + f"Invalid TLS profile type '{self.profile_type}'" + ) + # check the TLS protocol version + if self.min_tls_version is not None: + try: + tls.TLSProtocolVersion(self.min_tls_version) + except ValueError: 
+ raise InvalidConfigurationError( + f"Invalid minimal TLS version '{self.min_tls_version}'" + ) + # check ciphers + if self.ciphers is not None: + # just perform the check for non-custom TLS profile type + if self.profile_type is not None and self.profile_type != "Custom": + supported_ciphers = tls.TLS_CIPHERS[tls.TLSProfiles(self.profile_type)] + for cipher in self.ciphers: + if cipher not in supported_ciphers: + raise InvalidConfigurationError( + f"Unsupported cipher '{cipher}' found in configuration" + ) + + class ProviderSpecificConfig(BaseModel, extra="forbid"): """Base class with common provider specific configurations.""" @@ -286,6 +331,7 @@ class ProviderConfig(BaseModel): rhoai_vllm_config: Optional[RHOAIVLLMConfig] = None rhelai_vllm_config: Optional[RHELAIVLLMConfig] = None certificates_store: Optional[str] = None + tls_security_profile: Optional[TLSSecurityProfile] = None def __init__( self, @@ -333,6 +379,9 @@ def __init__( self.certificates_store = os.path.join( certificate_directory, constants.CERTIFICATE_STORAGE_FILENAME ) + self.tls_security_profile = TLSSecurityProfile( + data.get("tlsSecurityProfile", None) + ) def set_provider_type(self, data: dict) -> None: """Set the provider type.""" @@ -477,6 +526,7 @@ def __eq__(self, other: object) -> bool: and self.rhelai_vllm_config == other.rhelai_vllm_config and self.watsonx_config == other.watsonx_config and self.bam_config == other.bam_config + and self.tls_security_profile == other.tls_security_profile ) return False @@ -873,51 +923,6 @@ def check_storage_location_is_set_when_needed(self) -> Self: return self -class TLSSecurityProfile(BaseModel): - """TLS security profile structure.""" - - profile_type: Optional[str] = None - min_tls_version: Optional[str] = None - ciphers: Optional[list[str]] = None - - def __init__(self, data: Optional[dict] = None) -> None: - """Initialize configuration and perform basic validation.""" - super().__init__() - if data is not None: - self.profile_type = 
data.get("type") - self.min_tls_version = data.get("minTLSVersion") - self.ciphers = data.get("ciphers") - - def validate_yaml(self) -> None: - """Validate structure content.""" - # check the TLS profile type - if self.profile_type is not None: - try: - tls.TLSProfiles(self.profile_type) - except ValueError: - raise InvalidConfigurationError( - f"Invalid TLS profile type '{self.profile_type}'" - ) - # check the TLS protocol version - if self.min_tls_version is not None: - try: - tls.TLSProtocolVersion(self.min_tls_version) - except ValueError: - raise InvalidConfigurationError( - f"Invalid minimal TLS version '{self.min_tls_version}'" - ) - # check ciphers - if self.ciphers is not None: - # just perform the check for non-custom TLS profile type - if self.profile_type is not None and self.profile_type != "Custom": - supported_ciphers = tls.TLS_CIPHERS[tls.TLSProfiles(self.profile_type)] - for cipher in self.ciphers: - if cipher not in supported_ciphers: - raise InvalidConfigurationError( - f"Unsupported cipher '{cipher}' found in configuration" - ) - - class OLSConfig(BaseModel): """OLS configuration.""" @@ -931,6 +936,7 @@ class OLSConfig(BaseModel): default_provider: Optional[str] = None default_model: Optional[str] = None + max_workers: Optional[int] = None query_filters: Optional[list[QueryFilter]] = None query_validation_method: Optional[str] = constants.QueryValidationMethod.DISABLED @@ -956,6 +962,7 @@ def __init__( self.reference_content = ReferenceContent(data.get("reference_content")) self.default_provider = data.get("default_provider", None) self.default_model = data.get("default_model", None) + self.max_workers = data.get("max_workers", None) self.authentication_config = AuthenticationConfig( **data.get("authentication_config", {}) ) @@ -992,6 +999,7 @@ def __eq__(self, other: object) -> bool: and self.reference_content == other.reference_content and self.default_provider == other.default_provider and self.default_model == other.default_model + and 
self.max_workers == other.max_workers and self.query_filters == other.query_filters and self.query_validation_method == other.query_validation_method and self.tls_config == other.tls_config diff --git a/pyproject.toml b/pyproject.toml index cac61fc7..15082670 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dev = [ "pydocstyle==6.3.0", "fastparquet==2024.5.0", # Required for model evaluation (runtime, if parquet qna file is used) "httpx==0.27.0", - "mypy==1.12.0", # Usually needs to be set to latest stable version available + "mypy==1.12.1", # Usually needs to be set to latest stable version available "pytest==8.3.2", "pytest-cov==5.0.0", "pytest-asyncio==0.24.0", diff --git a/runner.py b/runner.py index f4581b47..f7c8e936 100644 --- a/runner.py +++ b/runner.py @@ -125,7 +125,7 @@ def start_uvicorn(): "ols.app.main:app", host=host, port=port, - workers=1, + workers=config.ols_config.max_workers, log_level=log_level, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, diff --git a/tests/benchmarks/test_config_loader.py b/tests/benchmarks/test_config_loader.py index a6ddad23..2a0248f3 100644 --- a/tests/benchmarks/test_config_loader.py +++ b/tests/benchmarks/test_config_loader.py @@ -72,6 +72,7 @@ def read_valid_config_stream(): - name: m3 url: 'https://murl3' ols_config: + max_workers: 2 conversation_cache: type: memory memory: diff --git a/tests/config/config_for_integration_tests.yaml b/tests/config/config_for_integration_tests.yaml index 792319ba..80041ef5 100644 --- a/tests/config/config_for_integration_tests.yaml +++ b/tests/config/config_for_integration_tests.yaml @@ -16,6 +16,12 @@ llm_providers: max_tokens_for_response: 100 - name: m2 url: "https://murl2" + tlsSecurityProfile: + type: Custom + ciphers: + - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + minTLSVersion: VersionTLS13 - name: p2 type: openai url: "https://url2" @@ -93,6 +99,12 @@ ols_config: tls_certificate_path: tests/config/empty_cert.crt 
tls_key_path: tests/config/key tls_key_password_path: tests/config/password + tlsSecurityProfile: + type: Custom + ciphers: + - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + minTLSVersion: VersionTLS13 dev_config: enable_dev_ui: true disable_auth: true diff --git a/tests/config/valid_config.yaml b/tests/config/valid_config.yaml index a2266ebe..9b032dcd 100644 --- a/tests/config/valid_config.yaml +++ b/tests/config/valid_config.yaml @@ -22,6 +22,7 @@ llm_providers: - name: m2 url: "https://murl2" ols_config: + max_workers: 1 reference_content: product_docs_index_path: "tests/config" product_docs_index_id: product diff --git a/tests/unit/app/models/test_config.py b/tests/unit/app/models/test_config.py index b7039bf9..dea5b5b2 100644 --- a/tests/unit/app/models/test_config.py +++ b/tests/unit/app/models/test_config.py @@ -231,6 +231,7 @@ def test_provider_config(): assert provider_config.azure_config is None assert provider_config.watsonx_config is None assert provider_config.bam_config is None + assert provider_config.tls_security_profile is None with pytest.raises(InvalidConfigurationError) as excinfo: ProviderConfig( @@ -260,6 +261,45 @@ def test_provider_config(): assert "model name is missing" in str(excinfo.value) +def test_provider_config_with_tls_security_profile(): + """Test the ProviderConfig model.""" + provider_config = ProviderConfig( + { + "name": "test_name", + "type": "bam", + "url": "test_url", + "credentials_path": "tests/config/secret/apitoken", + "project_id": "test_project_id", + "models": [ + { + "name": "test_model_name", + "url": "http://test.url/", + "credentials_path": "tests/config/secret/apitoken", + } + ], + "tlsSecurityProfile": { + "type": "Custom", + "minTLSVersion": "VersionTLS13", + "ciphers": [ + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + ], + }, + } + ) + assert provider_config.tls_security_profile is not None + assert 
provider_config.tls_security_profile.profile_type == "Custom" + assert provider_config.tls_security_profile.min_tls_version == "VersionTLS13" + assert ( + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" + in provider_config.tls_security_profile.ciphers + ) + assert ( + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" + in provider_config.tls_security_profile.ciphers + ) + + def test_that_url_is_required_provider_parameter(): """Test that provider-specific URL is required attribute.""" # provider type is set to "azure_openai" @@ -2010,6 +2050,60 @@ def test_ols_config(tmpdir): assert ols_config.certificate_directory == constants.DEFAULT_CERTIFICATE_DIRECTORY assert ols_config.system_prompt_path is None assert ols_config.system_prompt is None + assert ols_config.tls_security_profile == TLSSecurityProfile() + + +def test_ols_config_with_tls_security_profile(tmpdir): + """Test the OLSConfig model.""" + ols_config = OLSConfig( + { + "default_provider": "test_default_provider", + "default_model": "test_default_model", + "conversation_cache": { + "type": "memory", + "memory": { + "max_entries": 100, + }, + }, + "logging_config": { + "logging_level": "INFO", + }, + "tlsSecurityProfile": { + "type": "Custom", + "minTLSVersion": "VersionTLS13", + "ciphers": [ + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + ], + }, + } + ) + assert ols_config.default_provider == "test_default_provider" + assert ols_config.default_model == "test_default_model" + assert ols_config.conversation_cache.type == "memory" + assert ols_config.conversation_cache.memory.max_entries == 100 + assert ols_config.logging_config.app_log_level == logging.INFO + assert ( + ols_config.query_validation_method == constants.QueryValidationMethod.DISABLED + ) + assert ols_config.user_data_collection == UserDataCollection() + assert ols_config.reference_content is None + assert ols_config.authentication_config == AuthenticationConfig() + assert ols_config.extra_ca == [] + assert 
ols_config.certificate_directory == constants.DEFAULT_CERTIFICATE_DIRECTORY + assert ols_config.system_prompt_path is None + assert ols_config.system_prompt is None + assert ols_config.tls_security_profile is not None + assert ols_config.tls_security_profile.profile_type == "Custom" + assert ols_config.tls_security_profile.min_tls_version == "VersionTLS13" + assert ( + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" + in ols_config.tls_security_profile.ciphers + ) + assert ( + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" + in ols_config.tls_security_profile.ciphers + ) def test_config(): diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py index f16641ac..d3fd71df 100644 --- a/tests/unit/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -740,6 +740,12 @@ def test_valid_config_stream(): credentials_path: tests/config/secret/apitoken - name: m2 url: 'https://murl2' + tlsSecurityProfile: + type: Custom + ciphers: + - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + minTLSVersion: VersionTLS13 - name: p2 type: bam url: 'https://url2' @@ -759,6 +765,12 @@ def test_valid_config_stream(): default_model: m1 certificate_directory: '/foo/bar/baz' system_prompt_path: 'tests/config/system_prompt.txt' + tlsSecurityProfile: + type: Custom + ciphers: + - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + minTLSVersion: VersionTLS13 dev_config: enable_dev_ui: true disable_auth: false @@ -817,6 +829,7 @@ def test_valid_config_file(): }, ], "ols_config": { + "max_workers": 1, "reference_content": { "product_docs_index_path": "tests/config", "product_docs_index_id": "product",