1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
   | from dataclasses import dataclass, field from typing import Any, Dict, Callable, Optional, List import time, json, random import asyncio
 
  @dataclass
  class ToolSpec:
      """Describe one callable tool: its name, payload schema, and run policy."""

      # Identifier of the tool.
      name: str
      # Payload schema; ``validate`` reads its "required" and "properties" keys.
      schema: Dict[str, Any]
      # Synchronous callable that receives the validated payload dict.
      runner: Callable[[Dict[str, Any]], Any]
      # Per-attempt timeout, in seconds, applied by ``run_with_retry``.
      timeout_s: float = 8.0
      # Number of extra attempts ``run_with_retry`` makes after the first failure.
      retry: int = 2
  class SchemaError(ValueError):
      """Signal that a payload does not satisfy a tool schema."""
  def validate(schema: Dict[str, Any], payload: Dict[str, Any]) -> Dict[str, Any]:
      """Check required fields and fill in schema defaults.

      Only the "required" list and the "properties" mapping of ``schema`` are
      consulted; per-property "type" entries are not checked.

      Args:
          schema: Schema dict; "required" lists mandatory keys and
              "properties" maps keys to specs that may carry a "default".
          payload: Payload to validate. It is updated in place with any
              missing defaults.

      Returns:
          The same ``payload`` object, with defaults filled in.

      Raises:
          SchemaError: If a key listed under "required" is absent.
      """
      import copy  # local import keeps this function self-contained

      for key in schema.get("required", []):
          if key not in payload:
              raise SchemaError(f"missing field: {key}")

      for key, spec in schema.get("properties", {}).items():
          if key not in payload and "default" in spec:
              # Copy the default so a mutable one (list/dict) is never shared
              # between payloads or written back into the schema.
              payload[key] = copy.deepcopy(spec["default"])
      return payload
  async def run_with_retry(
      tool: ToolSpec,
      payload: Dict[str, Any],
      *,
      base_delay: float = 0.5,
      max_delay: float = 3.0,
      jitter: float = 0.2,
  ) -> Any:
      """Run a tool in a worker thread with a timeout and exponential backoff.

      The payload is copied and validated once, before the first attempt, so
      a validation failure is never retried. Each attempt runs ``tool.runner``
      via ``asyncio.to_thread`` under a ``tool.timeout_s`` deadline.

      Args:
          tool: Tool specification (runner, schema, timeout, retry count).
          payload: Raw payload; the caller's dict is not modified.
          base_delay: Seconds to wait before the first retry.
          max_delay: Upper bound for the doubled delay.
          jitter: Upper bound of the random extra wait added to each delay.

      Returns:
          Whatever ``tool.runner`` returns on the first successful attempt.

      Raises:
          SchemaError: If the payload is missing a required field.
          Exception: The error of the final attempt (including a timeout)
              once all attempts are used up.
      """
      payload = validate(tool.schema, dict(payload))
      # Clamp so a negative retry value still yields one attempt instead of
      # skipping the loop and silently returning None.
      attempts = max(tool.retry, 0) + 1
      delay = base_delay
      for attempt in range(attempts):
          try:
              # NOTE: on timeout only the await is abandoned; the worker thread
              # started by to_thread is not interrupted and may keep running.
              return await asyncio.wait_for(
                  asyncio.to_thread(tool.runner, payload),
                  timeout=tool.timeout_s,
              )
          except Exception:
              if attempt == attempts - 1:
                  raise
              await asyncio.sleep(delay + random.random() * jitter)
              delay = min(delay * 2, max_delay)
 
  @dataclass
  class Budget:
      """Token and time allowance, timed from the moment of construction."""

      # Token allowance; not read by any method of this class.
      tokens: int
      # Time allowance in milliseconds.
      ms: int
      # Start timestamp in seconds; defaults to the time of construction.
      # NOTE(review): this is wall-clock time, so clock adjustments shift the
      # remaining budget - confirm whether a monotonic clock is wanted.
      start: float = field(default_factory=lambda: time.time())

      def left_ms(self) -> int:
          """Return the milliseconds remaining; negative once overspent."""
          elapsed_ms = (time.time() - self.start) * 1000
          return int(self.ms - elapsed_ms)
  @dataclass
  class TraceEvent:
      """One recorded trace entry."""

      # Event label, e.g. "plan" or "tool.call".
      name: str
      # Timestamp of the event; ``Tracer.log`` fills it with ``time.time()``.
      at: float
      # Free-form keyword data attached to the event.
      meta: Dict[str, Any]
  class Tracer:
      """Append-only, in-memory recorder of trace events."""

      def __init__(self):
          self.events: List[TraceEvent] = []

      def log(self, name: str, **meta):
          """Record an event called ``name``, stamped with the current time."""
          event = TraceEvent(name, time.time(), meta)
          self.events.append(event)

      def dump(self) -> List[Dict[str, Any]]:
          """Return every recorded event as a plain dict, oldest first."""
          return [
              {"name": event.name, "at": event.at, "meta": event.meta}
              for event in self.events
          ]
 
  class Planner:
      """Keyword-driven planner that decides whether a query needs search."""

      def decide(self, query: str) -> Dict[str, Any]:
          """Build the plan for ``query``.

          Returns:
              A dict with "tools" (``["search"]`` when the query contains one
              of the trigger keywords, otherwise ``[]``) and "k_refs" (always 2).
          """
          triggers = ("规范", "流程", "价格", "说明")
          if any(word in query for word in triggers):
              tools = ["search"]
          else:
              tools = []
          return {"tools": tools, "k_refs": 2}
  class DummyLLM:
      """Stand-in language model that always gives the same reply."""

      def generate(self, prompt: str, max_tokens: int = 500) -> str:
          """Return the fixed reply; ``prompt`` and ``max_tokens`` are ignored."""
          canned_reply = "答复:请参考[1][2],并已创建日程。"
          return canned_reply
  class Agent:
      """Answer a query by planning, calling tools, and prompting the LLM.

      The tracer is created once per agent and never cleared, so the trace
      returned by ``run`` also contains events from earlier ``run`` calls on
      the same instance.
      """

      def __init__(self, tools: Dict[str, ToolSpec], llm: DummyLLM):
          self.tools = tools
          self.llm = llm
          self.planner = Planner()
          self.tracer = Tracer()

      async def run(self, query: str) -> Dict[str, Any]:
          """Execute the plan for ``query`` and return the answer with its trace.

          Args:
              query: The user question; also sent to every tool as "q".

          Returns:
              A dict with "answer" (the LLM text) and "trace" (the tracer dump).

          Raises:
              KeyError: If the plan names a tool that is not in ``self.tools``.
              Exception: Whatever ``run_with_retry`` raises once a tool has
                  used up its attempts.
          """
          # Imported locally so this method does not rely on a module-level
          # ``re`` import.
          import re

          plan = self.planner.decide(query)
          budget = Budget(tokens=3000, ms=3000)
          self.tracer.log("plan", plan=plan)

          evidences = []
          for tool_name in plan["tools"]:
              # Stop issuing tool calls once under 400 ms of budget remains.
              if budget.left_ms() < 400:
                  break
              self.tracer.log("tool.call", name=tool_name)
              res = await run_with_retry(self.tools[tool_name], {"q": query, "k": 4})
              evidences.extend(res)
              self.tracer.log("tool.ok", name=tool_name, size=len(res))

          # At most six evidence titles are numbered from 1 and put in the prompt.
          ctx = "\n".join(
              f"[{idx}] {item['title']}"
              for idx, item in enumerate(evidences[:6], start=1)
          )
          prompt = f"基于证据回答并在结尾引用:[示例]\n{ctx}\n问题:{query}"
          ans = self.llm.generate(prompt)

          # Only the number of "[n]" citations matters, so the matches are
          # counted without converting them to int.
          citations = re.findall(r"\[(\d+)\]", ans)
          if len(citations) < plan["k_refs"]:
              # Too few citations: ask once more with the prompt cut to 600 chars.
              ans = self.llm.generate(prompt[:600])
          return {"answer": ans, "trace": self.tracer.dump()}
 
  def search_runner(payload: Dict[str, Any]):
      """Produce placeholder search hits for the query in ``payload``.

      Reads "q" (required) and "k" (optional, defaults to 4) and returns ``k``
      dicts, each with a "title" and a "url" numbered from 1.
      """
      query = payload["q"]
      count = payload.get("k", 4)
      return [
          {"title": f"{query}-证据-{rank}", "url": f"https://kb/{rank}"}
          for rank in range(1, count + 1)
      ]
  # Search tool registration: "q" is mandatory, "k" falls back to 4 when omitted.
  search_tool = ToolSpec(
      name="search",
      runner=search_runner,
      timeout_s=2.5,
      retry=1,
      schema={
          "type": "object",
          "properties": {
              "q": {"type": "string"},
              "k": {"type": "integer", "default": 4},
          },
          "required": ["q"],
      },
  )
 
  async def demo():
      """Run one sample query through an Agent and print the result as JSON."""
      registry = {"search": search_tool}
      agent = Agent(tools=registry, llm=DummyLLM())
      result = await agent.run("发布流程规范与审批")
      rendered = json.dumps(result, ensure_ascii=False, indent=2)
      print(rendered)


  # Script entry point: run the demo only when executed directly.
  if __name__ == "__main__":
      asyncio.run(demo())
   |