diff --git a/openkb/agent/compiler.py b/openkb/agent/compiler.py
index af3ac1bd..f3261515 100644
--- a/openkb/agent/compiler.py
+++ b/openkb/agent/compiler.py
@@ -31,6 +31,7 @@
 from openkb import frontmatter
 from openkb.config import (
     DEFAULT_ENTITY_TYPES,
+    get_api_base,
     get_extra_headers,
     get_timeout,
     resolve_entity_types,
@@ -334,6 +335,9 @@ def _llm_call(model: str, messages: list[dict], step_name: str, **kwargs) -> str
     timeout = get_timeout()
     if timeout is not None:
         kwargs.setdefault("timeout", timeout)
+    api_base = get_api_base()
+    if api_base is not None:
+        kwargs.setdefault("api_base", api_base)
     logger.debug("LLM request [%s]:\n%s", step_name, _fmt_messages(messages))
     if kwargs:
         logger.debug("LLM kwargs [%s]: %s", step_name, kwargs)
@@ -359,6 +363,9 @@ async def _llm_call_async(model: str, messages: list[dict], step_name: str, **kw
     timeout = get_timeout()
     if timeout is not None:
         kwargs.setdefault("timeout", timeout)
+    api_base = get_api_base()
+    if api_base is not None:
+        kwargs.setdefault("api_base", api_base)
     logger.debug("LLM request [%s]:\n%s", step_name, _fmt_messages(messages))
     if kwargs:
         logger.debug("LLM kwargs [%s]: %s", step_name, kwargs)
diff --git a/openkb/agent/query.py b/openkb/agent/query.py
index 5a755d76..2ae64989 100644
--- a/openkb/agent/query.py
+++ b/openkb/agent/query.py
@@ -6,7 +6,7 @@
 from agents import Agent, Runner, function_tool
 from agents import ToolOutputImage, ToolOutputText
 
-from openkb.config import get_extra_headers, get_timeout_extra_args
+from openkb.config import get_api_base, get_extra_headers, get_timeout_extra_args
 from openkb.agent.tools import (
     get_wiki_page_content,
     read_wiki_file,
@@ -90,6 +90,14 @@ def get_image(image_path: str) -> ToolOutputImage | ToolOutputText:
 
     from agents.model_settings import ModelSettings
 
+    extra_args: dict = {}
+    timeout_args = get_timeout_extra_args()
+    if timeout_args:
+        extra_args.update(timeout_args)
+    api_base = get_api_base()
+    if api_base is not None:
+        extra_args["api_base"] = api_base
+
     return Agent(
         name="wiki-query",
         instructions=instructions,
@@ -98,7 +106,7 @@ def get_image(image_path: str) -> ToolOutputImage | ToolOutputText:
         model_settings=ModelSettings(
             parallel_tool_calls=False,
             extra_headers=get_extra_headers() or None,
-            extra_args=get_timeout_extra_args(),
+            extra_args=extra_args or None,
         ),
     )
 
diff --git a/openkb/cli.py b/openkb/cli.py
index 28694987..545fc391 100644
--- a/openkb/cli.py
+++ b/openkb/cli.py
@@ -45,7 +45,7 @@ def filter(self, record: logging.LogRecord) -> bool:
 from openkb.config import (
     DEFAULT_CONFIG, load_config, save_config, load_global_config, register_kb,
     resolve_extra_headers, set_extra_headers, resolve_timeout, set_timeout,
-    resolve_litellm_settings,
+    resolve_litellm_settings, resolve_api_base, set_api_base,
 )
 from openkb.converter import _registry_path, convert_document
 from openkb.indexer import import_cloud_document
@@ -138,11 +138,12 @@ def _setup_llm_key(kb_dir: Path | None = None) -> None:
 
     api_key = os.environ.get("LLM_API_KEY", "")
 
-    # Try to resolve the active provider, extra headers, and request timeout
-    # from the KB config
+    # Try to resolve the active provider, extra headers, request timeout, and
+    # custom API base URL from the KB config
     provider: str | None = None
     extra_headers: dict[str, str] = {}
     timeout: float | None = None
+    api_base: str | None = None
     litellm_settings: dict = {}
     if kb_dir is not None:
         config_path = kb_dir / ".openkb" / "config.yaml"
@@ -152,6 +153,7 @@ def _setup_llm_key(kb_dir: Path | None = None) -> None:
             provider = _extract_provider(str(model))
             extra_headers = resolve_extra_headers(config)
             timeout = resolve_timeout(config)
+            api_base = resolve_api_base(config)
             litellm_settings = resolve_litellm_settings(config)
             # `timeout` / `extra_headers` in the block route to the per-call
             # stashes (replacing the legacy top-level keys); the rest are globals.
@@ -163,8 +165,20 @@ def _setup_llm_key(kb_dir: Path | None = None) -> None:
                 timeout = resolve_timeout(
                     {"timeout": litellm_settings.pop("timeout")}
                 )
+
+    # Fall back to LLM_API_BASE env var when api_base is not set in config.
+    if api_base is None:
+        env_api_base = os.environ.get("LLM_API_BASE", "").strip()
+        if env_api_base:
+            api_base = env_api_base
+
     set_extra_headers(extra_headers)
     set_timeout(timeout)
+    set_api_base(api_base)
+    # Also set litellm.api_base globally so third-party libraries that call
+    # litellm directly (e.g. PageIndex) also route through the custom endpoint.
+    if api_base is not None:
+        litellm.api_base = api_base
     _apply_litellm_settings(litellm_settings)
 
     if not api_key:
diff --git a/openkb/config.py b/openkb/config.py
index 52082c62..9938270f 100644
--- a/openkb/config.py
+++ b/openkb/config.py
@@ -242,6 +242,44 @@ def get_timeout_extra_args() -> dict[str, float] | None:
     return {"timeout": _runtime_timeout} if _runtime_timeout is not None else None
 
 
+def resolve_api_base(config: dict) -> str | None:
+    """Resolve the optional ``api_base:`` config key to a non-empty URL string.
+
+    Returns ``None`` (use the provider default) when absent or blank; warns and
+    returns ``None`` when present but not a string.
+    """
+    raw = config.get("api_base")
+    if raw is None:
+        return None
+    if not isinstance(raw, str):
+        logger.warning(
+            "config: 'api_base' must be a URL string, got %s — ignoring it.",
+            type(raw).__name__,
+        )
+        return None
+    stripped = raw.strip()
+    if not stripped:
+        return None
+    return stripped
+
+
+# Process-wide custom API base URL for LLM requests, set from config / env by
+# the CLI entry points and read at call sites via get_api_base(). None means
+# use the provider's default endpoint — behaviour is identical to not setting it.
+_runtime_api_base: str | None = None
+
+
+def set_api_base(api_base: str | None) -> None:
+    """Set the process-wide custom API base URL; ``None`` clears it."""
+    global _runtime_api_base
+    _runtime_api_base = api_base or None
+
+
+def get_api_base() -> str | None:
+    """Return the process-wide custom API base URL, or ``None``."""
+    return _runtime_api_base
+
+
 def load_config(config_path: Path) -> dict[str, Any]:
     """Load YAML config from config_path, merged with DEFAULT_CONFIG.
 