api2cursor/routes/common.py
2026-03-15 13:52:09 +08:00

429 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""路由层公共辅助
收敛多个数据面路由都会用到的上下文解析、上游目标构造、日志输出和
SSE 消息拼装逻辑,避免 `chat.py` 和 `responses.py` 各自维护重复实现。
"""
from __future__ import annotations
from dataclasses import dataclass
import hashlib
import json
import logging
import threading
import time
from typing import Any
import settings
from utils.http import build_anthropic_headers, build_gemini_headers, build_openai_headers
logger = logging.getLogger(__name__)
# Guards every read/write of _RESPONSES_PREV_IDS below; a module-level lock
# suggests request handlers may touch the table concurrently (assumption).
_RESPONSES_PREV_ID_LOCK = threading.Lock()
# Lifetime of a cached response_id in seconds (it is compared against
# time.time() deltas): 86400 s = 24 h.
_RESPONSES_PREV_ID_TTL = 86400
# conversation key (from _responses_prev_id_key) ->
#     (latest upstream response_id, time.time() at which it was stored)
_RESPONSES_PREV_IDS: dict[str, tuple[str, float]] = {}
@dataclass(frozen=True)
class RouteContext:
    """Normalized request context used by the data-plane routes.

    The route layer resolves the client's model name into this context (see
    ``build_route_context``), so later handlers only deal with the upstream
    model, backend type, target URL, credentials, streaming flag and custom
    instructions instead of querying the settings layer again.

    Frozen: fields cannot be reassigned after construction (the dict-valued
    fields are still ordinary mutable dicts).
    """
    # Model name exactly as sent by the client.
    client_model: str
    # Model name forwarded upstream (also embedded in Gemini URLs).
    upstream_model: str
    # Backend type from the model mapping; the set of valid values lives in settings.
    backend: str
    # Upstream base URL; the build_*_target helpers strip trailing '/' and append the API path.
    target_url: str
    # Credential handed to the build_*_headers helpers.
    api_key: str
    # Whether the client asked for a streaming response.
    is_stream: bool
    # Text injected into the request; '' means nothing is injected.
    custom_instructions: str
    # 'append' places the instructions after existing text; any other value (default 'prepend') before it.
    instructions_position: str
    # Field -> value overrides for the upstream request body; a None value deletes the field.
    body_modifications: dict
    # Header -> value overrides for the upstream request; a None value deletes the header.
    header_modifications: dict
def build_route_context(client_model: str, is_stream: bool) -> RouteContext:
    """Resolve ``client_model`` through the model mapping and build a RouteContext.

    The mandatory mapping entries (``upstream_model``, ``backend``, ``target_url``,
    ``api_key``) are read by subscription, so a missing one raises (assuming the
    mapping is dict-like). The optional entries fall back to neutral defaults:
    no custom instructions, ``'prepend'`` position and no body/header changes.
    """
    resolved = settings.resolve_model(client_model)
    upstream_model = resolved['upstream_model']
    backend = resolved['backend']
    target_url = resolved['target_url']
    api_key = resolved['api_key']
    return RouteContext(
        client_model=client_model,
        upstream_model=upstream_model,
        backend=backend,
        target_url=target_url,
        api_key=api_key,
        is_stream=is_stream,
        custom_instructions=resolved.get('custom_instructions', ''),
        instructions_position=resolved.get('instructions_position', 'prepend'),
        body_modifications=resolved.get('body_modifications', {}),
        header_modifications=resolved.get('header_modifications', {}),
    )
def build_openai_target(ctx: RouteContext) -> tuple[str, dict[str, str]]:
    """Return the Chat Completions URL and headers for an OpenAI-compatible backend.

    Trailing slashes on ``ctx.target_url`` are stripped before
    ``/v1/chat/completions`` is appended.
    """
    base_url = ctx.target_url.rstrip('/')
    return base_url + '/v1/chat/completions', build_openai_headers(ctx.api_key)
def build_responses_target(ctx: RouteContext) -> tuple[str, dict[str, str]]:
    """Return the ``/v1/responses`` URL and headers for an OpenAI Responses backend.

    Uses the same OpenAI-style headers as the Chat Completions target; trailing
    slashes on ``ctx.target_url`` are stripped first.
    """
    base_url = ctx.target_url.rstrip('/')
    return base_url + '/v1/responses', build_openai_headers(ctx.api_key)
def build_anthropic_target(ctx: RouteContext) -> tuple[str, dict[str, str]]:
    """Return the ``/v1/messages`` URL and headers for an Anthropic backend.

    Trailing slashes on ``ctx.target_url`` are stripped before the path is appended.
    """
    base_url = ctx.target_url.rstrip('/')
    return base_url + '/v1/messages', build_anthropic_headers(ctx.api_key)
def build_gemini_target(ctx: RouteContext, stream: bool = False) -> tuple[str, dict[str, str]]:
    """Return the Gemini URL and headers for ``ctx``.

    URL shapes:
        non-streaming: {base}/v1/models/{model}:generateContent
        streaming:     {base}/v1/models/{model}:streamGenerateContent?alt=sse
    where ``base`` is ``ctx.target_url`` without trailing slashes and ``model``
    is ``ctx.upstream_model``.
    """
    method = 'streamGenerateContent?alt=sse' if stream else 'generateContent'
    url = f'{ctx.target_url.rstrip("/")}/v1/models/{ctx.upstream_model}:{method}'
    return url, build_gemini_headers(ctx.api_key)
def log_route_context(route_name: str, ctx: RouteContext, *, extra: str = '') -> None:
    """Emit one INFO line describing how a request is being routed.

    Keeping the format in one place stops the log lines of different entry
    points from drifting apart. A non-empty ``extra`` is appended after a space.
    """
    summary = ' '.join((
        f'[{route_name}]',
        f'模型={ctx.client_model}',
        f'上游模型={ctx.upstream_model}',
        f'后端={ctx.backend}',
        f'流式={ctx.is_stream}',
    ))
    if extra:
        summary = ' '.join((summary, extra))
    logger.info(summary)
def log_usage(
    route_name: str,
    usage: dict[str, Any],
    *,
    input_key: str,
    output_key: str,
) -> None:
    """Log the token usage of a finished request in one uniform format.

    Protocols name their usage fields differently, so the caller passes the
    field names (``input_key``/``output_key``); a missing field is logged as 0.
    """
    tokens_in = usage.get(input_key, 0)
    tokens_out = usage.get(output_key, 0)
    logger.info('[%s] 请求完成 输入令牌=%s 输出令牌=%s', route_name, tokens_in, tokens_out)
def sse_data_message(data: Any) -> str:
    """Build an SSE message that carries only a ``data`` field.

    Strings are sent verbatim; anything else is serialized with
    ``json.dumps(..., ensure_ascii=False)``.

    An SSE field value cannot contain a line break (CR, LF or CRLF terminate the
    line). A multi-line payload is therefore emitted as one ``data:`` line per
    line, and conforming clients rejoin those lines with newlines. Single-line
    payloads, which include every ``json.dumps`` result, are framed exactly as
    ``data: <payload>`` followed by a blank line.
    """
    payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
    # Normalize all three SSE line terminators before splitting into data lines.
    lines = payload.replace('\r\n', '\n').replace('\r', '\n').split('\n')
    return ''.join(f'data: {line}\n' for line in lines) + '\n'
def sse_event_message(event_type: str, data: Any) -> str:
    """Build an SSE message with an ``event`` name and a ``data`` field.

    Strings are sent verbatim; anything else is serialized with
    ``json.dumps(..., ensure_ascii=False)``.

    An SSE field value cannot contain a line break. A multi-line payload is
    therefore split into one ``data:`` line per line, which clients rejoin with
    newlines. Single-line payloads are framed exactly as
    ``event: <type>`` / ``data: <payload>`` followed by a blank line.
    """
    payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
    # Normalize all three SSE line terminators before splitting into data lines.
    lines = payload.replace('\r\n', '\n').replace('\r', '\n').split('\n')
    return f'event: {event_type}\n' + ''.join(f'data: {line}\n' for line in lines) + '\n'
def chat_error_chunk(message: str, error_type: str = 'upstream_error') -> str:
    """Build the SSE error message used by the streaming Chat Completions route.

    The data is ``{"error": {"message": <message>, "type": <error_type>}}``.
    """
    error_body = {'message': message, 'type': error_type}
    return sse_data_message({'error': error_body})
def responses_error_event(message: str) -> str:
    """Build the SSE ``error`` event used by the streaming Responses route.

    The data is ``{"error": <message>}``, with the message as a plain string.
    """
    body = {'error': message}
    return sse_event_message('error', body)
# ─── 自定义指令注入 ──────────────────────────────
def _merge_text(custom: str, existing: str, position: str) -> str:
    """Join injected text ``custom`` with ``existing``, separated by a blank line.

    An empty ``existing`` yields ``custom`` unchanged. ``position == 'append'``
    puts ``custom`` after ``existing``; any other value puts it first.
    """
    if existing:
        ordered = (existing, custom) if position == 'append' else (custom, existing)
        return '\n\n'.join(ordered)
    return custom
def inject_instructions_cc(payload: dict[str, Any], instructions: str, position: str = 'prepend') -> dict[str, Any]:
    """Inject custom instructions into a Chat Completions request.

    Behavior by case:

    - First message is a ``system`` message: ``instructions`` is merged into it.
      ``position='prepend'`` puts it first and ``'append'`` puts it last.
    - Otherwise: a new system message holding only ``instructions`` is inserted
      at index 0.

    Chat Completions accepts ``content`` as either a string or a list of content
    parts. A non-empty list stays a list, and the instructions are added as an
    extra ``{'type': 'text', ...}`` part at the front or back. Existing parts are
    therefore preserved, and the request no longer fails with a ``str + list``
    ``TypeError``. String, empty or missing content is merged as text with
    ``_merge_text``.

    Empty ``instructions`` leave the payload untouched. The payload is modified in
    place and also returned.
    """
    if not instructions:
        return payload
    messages = payload.get('messages')
    if messages is None:
        # Covers both a missing key and an explicit ``"messages": null``.
        messages = []
    first = messages[0] if messages else None
    if isinstance(first, dict) and first.get('role') == 'system':
        content = first.get('content')
        if isinstance(content, list) and content:
            part = {'type': 'text', 'text': instructions}
            first['content'] = content + [part] if position == 'append' else [part] + content
        else:
            first['content'] = _merge_text(instructions, content or '', position)
    else:
        messages.insert(0, {'role': 'system', 'content': instructions})
    payload['messages'] = messages
    logger.info('已注入自定义指令到 CC system 消息 (%d 字符, %s)', len(instructions), position)
    return payload
def inject_instructions_responses(payload: dict[str, Any], instructions: str, position: str = 'prepend') -> dict[str, Any]:
    """Inject custom instructions into a Responses request's ``instructions`` field.

    ``position='prepend'`` places them before any existing instructions and
    ``'append'`` after them (blank-line separated). Empty ``instructions`` leave
    the payload untouched. The payload is modified in place and also returned.
    """
    if instructions:
        current = payload.get('instructions') or ''
        payload['instructions'] = _merge_text(instructions, current, position)
        logger.info('已注入自定义指令到 Responses instructions (%d 字符, %s)', len(instructions), position)
    return payload
def ensure_responses_cache_control(payload: dict[str, Any]) -> dict[str, Any]:
    """Add a default top-level ``cache_control`` to a Responses request.

    Some upstreams that serve ``/v1/responses`` look at a top-level
    ``cache_control`` to place prompt-cache breakpoints automatically. The client
    normally omits the field, so ``{'type': 'ephemeral'}`` is set unless the
    payload already holds a dict with a truthy ``type``. That lets callers
    override the default, for example through body_modifications.

    NOTE(review): any other value is replaced, including ``None``, ``False`` and
    ``{}``. An explicit falsy value therefore cannot switch the default off here;
    confirm this matches the intended "override/disable" behaviour.

    Non-dict payloads are returned unchanged. Dict payloads are modified in place
    and also returned.
    """
    if isinstance(payload, dict):
        current = payload.get('cache_control')
        already_set = isinstance(current, dict) and bool(current.get('type'))
        if not already_set:
            payload['cache_control'] = {'type': 'ephemeral'}
            logger.info('已为 Responses 请求自动启用 cache_control=ephemeral')
    return payload
def attach_previous_response_id(payload: dict[str, Any]) -> dict[str, Any]:
    """Fill in ``previous_response_id`` for a multi-turn Responses request.

    Some upstreams only reuse their server-side response chain and cache across
    ``/v1/responses`` turns when ``previous_response_id`` is sent. The client
    usually resends the full history but does not send this field itself. The
    proxy therefore looks up the id remembered under the same conversation key
    (``_responses_prev_id_key``) and adds it.

    The payload is left untouched when it is not a dict, already carries a truthy
    ``previous_response_id``, yields an empty key, or has no live cached id.
    Otherwise it is modified in place. It is returned in every case.

    NOTE(review): the key only covers ``instructions`` plus the first user and
    first assistant message. Two separate conversations with identical openings
    therefore share one slot and may receive each other's id.

    NOTE(review): on OpenAI's own Responses API, the referenced response's context
    is prepended to ``input``. Because the full history is still sent, confirm the
    target upstream does not count that history twice.
    """
    if isinstance(payload, dict) and not payload.get('previous_response_id'):
        conversation_key = _responses_prev_id_key(payload)
        cached_id = _get_previous_response_id(conversation_key) if conversation_key else ''
        if cached_id:
            payload['previous_response_id'] = cached_id
            logger.info('已为 Responses 请求补齐 previous_response_id')
    return payload
def remember_response_id(payload: dict[str, Any], response_data: dict[str, Any]) -> None:
    """Remember the latest upstream Responses ``response_id`` for this conversation.

    ``payload`` is the request sent upstream and ``response_data`` the response
    object it produced. Nothing is stored unless both are dicts and the response
    carries a non-blank string ``id``. The id is stored stripped and timestamped
    with ``time.time()``.

    Key alignment: ``_responses_prev_id_key`` seeds on the first user message plus
    the first assistant message.

    - On the first turn the request has no assistant message yet, so its own key
      is "user only".
    - The next turn echoes this response's assistant reply in ``input``, so it
      computes a "user + assistant" key and would never find an entry stored only
      under the request key.
    - The id is therefore stored under the request's own key AND under the key
      obtained after appending this response's ``output`` items to the request
      ``input``.
    - Once the request already contains an assistant message, both keys coincide
      and a single entry is written.
    """
    if not isinstance(payload, dict) or not isinstance(response_data, dict):
        return
    response_id = response_data.get('id')
    if not isinstance(response_id, str) or not response_id.strip():
        return
    keys = {_responses_prev_id_key(payload)}
    extended = _payload_with_response_output(payload, response_data)
    if extended is not payload:
        keys.add(_responses_prev_id_key(extended))
    keys.discard('')
    if not keys:
        return
    stored = (response_id.strip(), time.time())
    with _RESPONSES_PREV_ID_LOCK:
        for key in keys:
            _RESPONSES_PREV_IDS[key] = stored
        _cleanup_previous_response_ids_locked()


def _payload_with_response_output(payload: dict[str, Any], response_data: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``payload`` whose ``input`` also lists ``response_data['output']``.

    A string ``input`` is first turned into a single user-message item. It is then
    seeded like the list form that the follow-up request is assumed to use
    (TODO confirm against the client's history format).

    ``payload`` itself is returned unchanged when:

    - the response has no non-empty ``output`` list, or
    - ``input`` is neither a string nor a list.
    """
    output = response_data.get('output')
    if not isinstance(output, list) or not output:
        return payload
    input_data = payload.get('input', [])
    if isinstance(input_data, str):
        items: list[Any] = [{'role': 'user', 'content': input_data}]
    elif isinstance(input_data, list):
        items = list(input_data)
    else:
        return payload
    return {**payload, 'input': items + output}
def _responses_prev_id_key(payload: dict[str, Any]) -> str:
    """Derive a stable conversation key from the "root" of a Responses request.

    The full ``input`` is deliberately not hashed. The client appends history on
    every turn, so a full-history hash would change each turn and the previous
    ``response_id`` could never be found again. Only ``instructions`` plus a
    seed from the conversation root are used:

    - str ``input``: the string itself;
    - list ``input``: first user / first assistant message
      (``_responses_root_seed_from_items``);
    - anything else: its ``json.dumps`` (``default=str``).

    Returns the first 24 hex characters of SHA-256 over ``instructions|seed``.
    Returns '' when that text consists only of '|' characters, i.e. both parts
    are empty.
    """
    input_data = payload.get('input', [])
    if isinstance(input_data, list):
        seed = _responses_root_seed_from_items(input_data)
    elif isinstance(input_data, str):
        seed = input_data
    else:
        seed = json.dumps(input_data, ensure_ascii=False, default=str)
    material = (payload.get('instructions') or '') + '|' + seed
    if material.strip('|') == '':
        return ''
    digest = hashlib.sha256(material.encode('utf-8')).hexdigest()
    return digest[:24]
def _responses_root_seed_from_items(items: list[Any]) -> str:
    """Build a compact JSON seed from the root of a Responses ``input`` item list.

    The aim is not to reconstruct the conversation but to get a value that stays
    constant across the turns of one conversation. That value is the first user
    message and, once it exists, the first assistant message.

    Recognised items:

    - a bare string counts as a user message with that exact (unstripped) text;
    - a dict whose ``role`` is ``user``/``assistant`` and whose ``type`` is
      ``message`` or empty/missing. Its ``content`` is normalised by
      ``_responses_normalize_content``.

    Everything else (other roles, other item types, non-dict values) is skipped.
    Scanning stops after a dict item once both roots are known.

    Returns compact ``json.dumps`` of ``[user, assistant]``, omitting whichever is
    missing, e.g. ``'[]'`` when none is found.
    """
    roots: dict[str, dict[str, Any]] = {}
    for entry in items:
        if isinstance(entry, str):
            roots.setdefault('user', {'role': 'user', 'content': entry})
            continue
        if not isinstance(entry, dict):
            continue
        entry_type = entry.get('type', '')
        role = entry.get('role', '')
        is_message = entry_type == 'message' or not entry_type
        if is_message and role in ('user', 'assistant'):
            # Content is normalised before the first-occurrence check; only the
            # first message per role is kept.
            candidate = {'role': role, 'content': _responses_normalize_content(entry.get('content', ''))}
            roots.setdefault(role, candidate)
        if 'user' in roots and 'assistant' in roots:
            break
    ordered = [roots[name] for name in ('user', 'assistant') if name in roots]
    return json.dumps(ordered, ensure_ascii=False, separators=(',', ':'))
def _responses_normalize_content(content: Any) -> str:
    """Collapse the various Responses ``content`` shapes into stable plain text.

    Used only to derive conversation keys, not for display. It keeps the text
    that carries the conversation's meaning and ignores structural details, so
    formatting differences do not change the key.

    - str: stripped;
    - None: '';
    - list: text is taken from string parts and from dict parts whose ``type`` is
      ``input_text``, ``output_text``, ``text`` or ``summary_text``. The pieces are
      joined with newlines and the result is stripped. A missing, None or
      non-string ``text`` counts as ''. Other parts are ignored;
    - anything else: ``str(content).strip()``.
    """
    if isinstance(content, str):
        return content.strip()
    if content is None:
        return ''
    if not isinstance(content, list):
        return str(content).strip()
    texts: list[str] = []
    for part in content:
        if isinstance(part, str):
            texts.append(part)
        elif isinstance(part, dict) and part.get('type') in ('input_text', 'output_text', 'text', 'summary_text'):
            text = part.get('text')
            # A malformed part contributes '' instead of making the join raise
            # and abort the whole request over a best-effort cache key.
            texts.append(text if isinstance(text, str) else '')
    return '\n'.join(texts).strip()
def _get_previous_response_id(key: str) -> str:
    """Return the cached response_id for ``key``, or '' when absent or expired.

    An entry aged ``_RESPONSES_PREV_ID_TTL`` seconds or more is evicted during
    the lookup. The whole operation runs under ``_RESPONSES_PREV_ID_LOCK``.
    """
    with _RESPONSES_PREV_ID_LOCK:
        cached = _RESPONSES_PREV_IDS.get(key)
        if not cached:
            return ''
        response_id, stored_at = cached
        if time.time() - stored_at < _RESPONSES_PREV_ID_TTL:
            return response_id
        _RESPONSES_PREV_IDS.pop(key, None)
        return ''
def _cleanup_previous_response_ids_locked() -> None:
    """Drop every cached response_id aged ``_RESPONSES_PREV_ID_TTL`` seconds or more.

    The caller must already hold ``_RESPONSES_PREV_ID_LOCK`` (hence the
    ``_locked`` suffix), as the visible caller does. The table exists only for
    short-lived turn chaining. Pruning stale entries keeps a long-running process
    from accumulating dead state; the table still holds every entry younger than
    the TTL.
    """
    now = time.time()
    for key, (_, stored_at) in list(_RESPONSES_PREV_IDS.items()):
        if now - stored_at >= _RESPONSES_PREV_ID_TTL:
            del _RESPONSES_PREV_IDS[key]
def inject_instructions_anthropic(payload: dict[str, Any], instructions: str, position: str = 'prepend') -> dict[str, Any]:
    """Inject custom instructions into an Anthropic Messages request's ``system`` field.

    ``system`` may be a string or a list of content blocks:

    - string / missing / empty: merged as text with ``_merge_text``.
      ``'prepend'`` puts the instructions first, ``'append'`` last, separated by
      a blank line.
    - non-empty list: a new ``{'type': 'text', 'text': instructions}`` block is
      placed at the front (``'append'`` puts it at the back) of a new list.
      Existing blocks are kept as-is, so their ``cache_control`` prompt-caching
      markers and any non-text blocks survive instead of being flattened away.

    Empty ``instructions`` leave the payload untouched. The payload is modified in
    place and also returned.
    """
    if not instructions:
        return payload
    existing = payload.get('system') or ''
    if isinstance(existing, list):
        injected = {'type': 'text', 'text': instructions}
        payload['system'] = [*existing, injected] if position == 'append' else [injected, *existing]
    else:
        payload['system'] = _merge_text(instructions, existing, position)
    logger.info('已注入自定义指令到 Anthropic system (%d 字符, %s)', len(instructions), position)
    return payload
# ─── Body / Header 修改 ──────────────────────────
def apply_body_modifications(payload: dict[str, Any], modifications: dict[str, Any]) -> dict[str, Any]:
    """Apply top-level field modifications to the forwarded request body.

    Same rule as CursorProxy: a ``None`` value removes the field, any other value
    sets or overrides it (shallowly; nested objects are replaced, not merged).
    Empty or missing ``modifications`` leave the payload untouched. The payload is
    modified in place and also returned.
    """
    if not modifications:
        return payload
    for field, new_value in modifications.items():
        if new_value is not None:
            payload[field] = new_value
            continue
        payload.pop(field, None)
    logger.info('已应用 body_modifications: %s', list(modifications.keys()))
    return payload
def apply_header_modifications(headers: dict[str, str], modifications: dict[str, Any]) -> dict[str, str]:
    """Apply per-header modifications to the forwarded request headers.

    Same rule as the body: a ``None`` value removes the header, and any other value
    is converted with ``str()`` and sets or overrides it.

    HTTP header names are case-insensitive (RFC 9110 section 5.1). Every existing
    header whose name matches a modification key case-insensitively is therefore
    removed first. Without this, e.g. ``{"authorization": null}`` would not remove
    an ``Authorization`` header, and a differently-cased override would be sent
    alongside the original.

    Empty or missing ``modifications`` leave the headers untouched. ``headers`` is
    modified in place and also returned.
    """
    if not modifications:
        return headers
    for key, value in modifications.items():
        wanted = str(key).lower()
        for existing in [name for name in headers if str(name).lower() == wanted]:
            del headers[existing]
        if value is not None:
            headers[key] = str(value)
    logger.info('已应用 header_modifications: %s', list(modifications.keys()))
    return headers