1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5import {
6 Agent,
7 type AgentEvent,
8 type AgentTool,
9} from "@mariozechner/pi-agent-core";
10import { type AssistantMessage, getEnvApiKey } from "@mariozechner/pi-ai";
11import type { RumiloConfig } from "../config/schema.js";
12import { resolveConfigValue } from "../util/env.js";
13import { AgentError } from "../util/errors.js";
14import { resolveModel } from "./model-resolver.js";
15
/**
 * Inputs accepted by `runAgent`.
 */
export interface AgentRunOptions {
  /** Configuration used for model resolution and for API-key lookup. */
  config: RumiloConfig;
  /** Model identifier; passed to `resolveModel` together with `config`. */
  model: string;
  /** System prompt placed in the agent's initial state. */
  systemPrompt: string;
  /** Tools placed in the agent's initial state. */
  tools: AgentTool[];
  /** Optional listener; when present it is subscribed to the agent's events. */
  onEvent?: (event: AgentEvent) => void;
}
23
/**
 * Outcome of a successful `runAgent` call.
 */
export interface AgentRunResult {
  /** Trimmed, concatenated text of the last assistant message. */
  message: string;
  /** Number of assistant messages present in the agent state after the run. */
  requestCount: number;
  /** Usage value taken from the last assistant message, passed through as-is. */
  usage?: unknown;
}
29
/**
 * Create the `getApiKey` callback handed to the Agent.
 *
 * For a given provider name the callback tries, in order:
 * 1. Custom models from `config.custom_models` — the first entry whose
 *    `provider` matches and whose `api_key` is non-empty wins; its `api_key`
 *    value is passed through `resolveConfigValue` and the result is returned
 *    as-is (there is no further fallback once such an entry is found).
 * 2. pi-ai's `getEnvApiKey(provider)` lookup.
 *
 * `config.custom_models` is read on every call, not captured up front.
 *
 * @param config - Configuration that may define custom models.
 * @returns A function mapping a provider name to an API key, or `undefined`
 *   when no key could be determined.
 */
export function buildGetApiKey(
  config: RumiloConfig,
): (provider: string) => string | undefined {
  const getApiKey = (provider: string): string | undefined => {
    const customModels = config.custom_models
      ? Object.values(config.custom_models)
      : [];

    for (const { provider: owner, api_key: keySpec } of customModels) {
      // Skip entries for other providers and entries without a key spec.
      if (owner !== provider || !keySpec) {
        continue;
      }
      return resolveConfigValue(keySpec);
    }

    // No custom model supplied a key; defer to the built-in env lookup.
    return getEnvApiKey(provider);
  };

  return getApiKey;
}
54
/**
 * Run a single query through a freshly constructed Agent and return the text
 * of its final assistant message.
 *
 * The agent is created with the system prompt, resolved model and tools from
 * `options`, and with an API-key callback built from `options.config`. If
 * `options.onEvent` is set it is subscribed before the prompt is sent.
 *
 * @param query - Prompt passed to `agent.prompt`.
 * @param options - Model, prompt, tools, config and optional event listener.
 * @returns The trimmed text of the last assistant message, that message's
 *   usage value, and the number of assistant messages in the agent state.
 * @throws AgentError when the agent state carries an error, when the last
 *   assistant message has stop reason `"error"`, or when no non-empty text
 *   could be extracted from the last assistant message.
 */
export async function runAgent(
  query: string,
  options: AgentRunOptions,
): Promise<AgentRunResult> {
  const { config, onEvent } = options;

  const agent = new Agent({
    initialState: {
      systemPrompt: options.systemPrompt,
      model: resolveModel(options.model, config),
      tools: options.tools,
    },
    getApiKey: buildGetApiKey(config),
  });

  if (onEvent) {
    agent.subscribe(onEvent);
  }

  await agent.prompt(query);

  // An error recorded on the agent state takes precedence over everything else.
  const stateError = agent.state.error;
  if (stateError) {
    throw new AgentError(stateError);
  }

  // One pass yields both the final assistant message and the request count.
  const assistantMessages = agent.state.messages.filter(
    (msg): msg is AssistantMessage => msg.role === "assistant",
  );
  const finalMessage: AssistantMessage | undefined =
    assistantMessages[assistantMessages.length - 1];

  // The final assistant message may itself report a failure.
  if (finalMessage?.stopReason === "error") {
    throw new AgentError(
      finalMessage.errorMessage ?? "Agent stopped with an unknown error",
    );
  }

  // Gather only the text parts; a missing message or content yields "".
  const textParts: string[] = [];
  for (const part of finalMessage?.content ?? []) {
    if (part.type === "text") {
      textParts.push(part.text);
    }
  }
  const reply = textParts.join("").trim();

  if (reply === "") {
    throw new AgentError("Agent returned no text response");
  }

  return {
    message: reply,
    usage: finalMessage?.usage,
    requestCount: assistantMessages.length,
  };
}