Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions openkb/agent/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from openkb import frontmatter
from openkb.config import (
DEFAULT_ENTITY_TYPES,
get_api_base,
get_extra_headers,
get_timeout,
resolve_entity_types,
Expand Down Expand Up @@ -334,6 +335,9 @@ def _llm_call(model: str, messages: list[dict], step_name: str, **kwargs) -> str
timeout = get_timeout()
if timeout is not None:
kwargs.setdefault("timeout", timeout)
api_base = get_api_base()
if api_base is not None:
kwargs.setdefault("api_base", api_base)
logger.debug("LLM request [%s]:\n%s", step_name, _fmt_messages(messages))
if kwargs:
logger.debug("LLM kwargs [%s]: %s", step_name, kwargs)
Expand All @@ -359,6 +363,9 @@ async def _llm_call_async(model: str, messages: list[dict], step_name: str, **kw
timeout = get_timeout()
if timeout is not None:
kwargs.setdefault("timeout", timeout)
api_base = get_api_base()
if api_base is not None:
kwargs.setdefault("api_base", api_base)
logger.debug("LLM request [%s]:\n%s", step_name, _fmt_messages(messages))
if kwargs:
logger.debug("LLM kwargs [%s]: %s", step_name, kwargs)
Expand Down
12 changes: 10 additions & 2 deletions openkb/agent/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from agents import Agent, Runner, function_tool

from agents import ToolOutputImage, ToolOutputText
from openkb.config import get_extra_headers, get_timeout_extra_args
from openkb.config import get_api_base, get_extra_headers, get_timeout_extra_args
from openkb.agent.tools import (
get_wiki_page_content,
read_wiki_file,
Expand Down Expand Up @@ -90,6 +90,14 @@ def get_image(image_path: str) -> ToolOutputImage | ToolOutputText:

from agents.model_settings import ModelSettings

extra_args: dict = {}
timeout_args = get_timeout_extra_args()
if timeout_args:
extra_args.update(timeout_args)
api_base = get_api_base()
if api_base is not None:
extra_args["api_base"] = api_base

return Agent(
name="wiki-query",
instructions=instructions,
Expand All @@ -98,7 +106,7 @@ def get_image(image_path: str) -> ToolOutputImage | ToolOutputText:
model_settings=ModelSettings(
parallel_tool_calls=False,
extra_headers=get_extra_headers() or None,
extra_args=get_timeout_extra_args(),
extra_args=extra_args or None,
),
)

Expand Down
20 changes: 17 additions & 3 deletions openkb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def filter(self, record: logging.LogRecord) -> bool:
from openkb.config import (
DEFAULT_CONFIG, load_config, save_config, load_global_config, register_kb,
resolve_extra_headers, set_extra_headers, resolve_timeout, set_timeout,
resolve_litellm_settings,
resolve_litellm_settings, resolve_api_base, set_api_base,
)
from openkb.converter import _registry_path, convert_document
from openkb.indexer import import_cloud_document
Expand Down Expand Up @@ -138,11 +138,12 @@ def _setup_llm_key(kb_dir: Path | None = None) -> None:

api_key = os.environ.get("LLM_API_KEY", "")

# Try to resolve the active provider, extra headers, and request timeout
# from the KB config
# Try to resolve the active provider, extra headers, request timeout, and
# custom API base URL from the KB config
provider: str | None = None
extra_headers: dict[str, str] = {}
timeout: float | None = None
api_base: str | None = None
litellm_settings: dict = {}
if kb_dir is not None:
config_path = kb_dir / ".openkb" / "config.yaml"
Expand All @@ -152,6 +153,7 @@ def _setup_llm_key(kb_dir: Path | None = None) -> None:
provider = _extract_provider(str(model))
extra_headers = resolve_extra_headers(config)
timeout = resolve_timeout(config)
api_base = resolve_api_base(config)
litellm_settings = resolve_litellm_settings(config)
# `timeout` / `extra_headers` in the block route to the per-call
# stashes (replacing the legacy top-level keys); the rest are globals.
Expand All @@ -163,8 +165,20 @@ def _setup_llm_key(kb_dir: Path | None = None) -> None:
timeout = resolve_timeout(
{"timeout": litellm_settings.pop("timeout")}
)

# Fall back to LLM_API_BASE env var when api_base is not set in config.
if api_base is None:
env_api_base = os.environ.get("LLM_API_BASE", "").strip()
if env_api_base:
api_base = env_api_base

set_extra_headers(extra_headers)
set_timeout(timeout)
set_api_base(api_base)
# Also set litellm.api_base globally so third-party libraries that call
# litellm directly (e.g. PageIndex) also route through the custom endpoint.
if api_base is not None:
litellm.api_base = api_base
_apply_litellm_settings(litellm_settings)

if not api_key:
Expand Down
38 changes: 38 additions & 0 deletions openkb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,44 @@ def get_timeout_extra_args() -> dict[str, float] | None:
return {"timeout": _runtime_timeout} if _runtime_timeout is not None else None


def resolve_api_base(config: dict) -> str | None:
"""Resolve the optional ``api_base:`` config key to a non-empty URL string.

Returns ``None`` (use the provider default) when absent or blank; warns and
returns ``None`` when present but not a string.
"""
raw = config.get("api_base")
if raw is None:
return None
if not isinstance(raw, str):
logger.warning(
"config: 'api_base' must be a URL string, got %s — ignoring it.",
type(raw).__name__,
)
return None
stripped = raw.strip()
if not stripped:
return None
return stripped


# Process-wide custom API base URL for LLM requests, set from config / env by
# the CLI entry points and read at call sites via get_api_base(). None means
# use the provider's default endpoint — behaviour is identical to not setting it.
_runtime_api_base: str | None = None


def set_api_base(api_base: str | None) -> None:
"""Set the process-wide custom API base URL; ``None`` clears it."""
global _runtime_api_base
_runtime_api_base = api_base or None


def get_api_base() -> str | None:
"""Return the process-wide custom API base URL, or ``None``."""
return _runtime_api_base


def load_config(config_path: Path) -> dict[str, Any]:
"""Load YAML config from config_path, merged with DEFAULT_CONFIG.

Expand Down