runner.ts

 1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
 2//
 3// SPDX-License-Identifier: GPL-3.0-or-later
 4
 5import { Agent, type AgentEvent, type AgentTool } from "@mariozechner/pi-agent-core";
 6import { getEnvApiKey, type AssistantMessage } from "@mariozechner/pi-ai";
 7import type { RumiloConfig } from "../config/schema.js";
 8import { resolveModel } from "./model-resolver.js";
 9import { AgentError } from "../util/errors.js";
10import { resolveConfigValue } from "../util/env.js";
11
/**
 * Inputs for a single {@link runAgent} invocation.
 */
export interface AgentRunOptions {
  /** Loaded rumilo configuration; used for model resolution and API-key lookup. */
  config: RumiloConfig;
  /** Model identifier, resolved against `config` by `resolveModel`. */
  model: string;
  /** System prompt installed as the agent's initial state. */
  systemPrompt: string;
  /** Tools made available to the agent. */
  tools: AgentTool[];
  /** Optional listener subscribed to every event the agent emits. */
  onEvent?: (event: AgentEvent) => void;
}
19
/**
 * Outcome of a successful {@link runAgent} call.
 */
export interface AgentRunResult {
  /** Number of assistant messages in the finished conversation. */
  requestCount: number;
  /** Final text answer of the agent (trimmed, never empty when produced by runAgent). */
  message: string;
  /** Usage data copied from the last assistant message; its shape is defined upstream. */
  usage?: unknown;
}
25
/**
 * Build a getApiKey callback for the Agent.
 *
 * Resolution order:
 * 1. Custom model config — for each custom model whose `provider` matches and
 *    which defines an `api_key` field, resolve it via `resolveConfigValue`
 *    (supports env var names, `$VAR` references, and `!shell` commands).
 *    The first non-empty resolved value wins.
 * 2. pi-ai’s built-in env-var lookup (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, etc.).
 *
 * A custom `api_key` that resolves to nothing (for example a referenced
 * environment variable that is unset) falls through to the next source
 * instead of ending the lookup with `undefined`.
 *
 * NOTE(review): the callback only receives the provider name, so if several
 * custom models share a provider but declare different keys, the first one in
 * `custom_models` iteration order is used for all of them.
 *
 * @param config - Loaded rumilo configuration.
 * @returns A function mapping a provider name to an API key, or `undefined`
 *   when no source yields one.
 */
export function buildGetApiKey(config: RumiloConfig): (provider: string) => string | undefined {
  return (provider: string) => {
    if (config.custom_models) {
      for (const model of Object.values(config.custom_models)) {
        if (model.provider !== provider || !model.api_key) {
          continue;
        }
        const resolved = resolveConfigValue(model.api_key);
        // Empty/unresolved values are not usable keys; keep looking.
        if (resolved) {
          return resolved;
        }
      }
    }

    return getEnvApiKey(provider);
  };
}
48
/**
 * Run one prompt through a freshly constructed agent and return its final text answer.
 *
 * The agent is built from `options`: the model is resolved via `resolveModel`,
 * and API keys are looked up through {@link buildGetApiKey}. When
 * `options.onEvent` is supplied it is subscribed to the agent's events. After
 * `query` is sent, the resulting agent state is inspected.
 *
 * @param query - The user prompt passed to `agent.prompt`.
 * @param options - Model name, system prompt, tools, optional event listener, and config.
 * @returns The trimmed concatenation of the text parts of the last assistant
 *   message, that message's `usage`, and the number of assistant messages in
 *   the conversation as `requestCount`.
 * @throws {AgentError} If the agent state records an error, if the last assistant
 *   message has `stopReason === "error"`, or if no non-empty text was produced.
 */
export async function runAgent(query: string, options: AgentRunOptions): Promise<AgentRunResult> {
  const { config, model, systemPrompt, tools, onEvent } = options;

  const agent = new Agent({
    initialState: {
      systemPrompt,
      model: resolveModel(model, config),
      tools,
    },
    getApiKey: buildGetApiKey(config),
  });

  if (onEvent) {
    agent.subscribe(onEvent);
  }

  await agent.prompt(query);

  const { error, messages } = agent.state;

  // An error recorded on the agent state takes precedence over everything else.
  if (error) {
    throw new AgentError(error);
  }

  const isAssistant = (msg: (typeof messages)[number]): msg is AssistantMessage => msg.role === "assistant";
  const assistantTurns = messages.filter(isAssistant);
  const last: AssistantMessage | undefined = assistantTurns[assistantTurns.length - 1];

  // The final assistant turn may itself report a failure.
  if (last?.stopReason === "error") {
    throw new AgentError(last.errorMessage ?? "Agent stopped with an unknown error");
  }

  let text: string | undefined;
  if (last?.content) {
    const textParts: string[] = [];
    for (const part of last.content) {
      if (part.type === "text") {
        textParts.push(part.text);
      }
    }
    text = textParts.join("").trim();
  }

  if (!text) {
    throw new AgentError("Agent completed without producing a text response");
  }

  return {
    message: text,
    usage: last?.usage,
    requestCount: assistantTurns.length,
  };
}