1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5import { Agent, type AgentEvent, type AgentTool } from "@mariozechner/pi-agent-core";
6import { getEnvApiKey, type AssistantMessage } from "@mariozechner/pi-ai";
7import type { RumiloConfig } from "../config/schema.js";
8import { resolveModel } from "./model-resolver.js";
9import { AgentError } from "../util/errors.js";
10import { resolveConfigValue } from "../util/env.js";
11
/** Inputs accepted by `runAgent`. */
export interface AgentRunOptions {
  /** Model identifier; handed to `resolveModel` together with `config`. */
  model: string;
  /** System prompt placed in the agent's initial state. */
  systemPrompt: string;
  /** Tools placed in the agent's initial state. */
  tools: AgentTool[];
  /** Configuration used for model resolution and for building the API-key lookup. */
  config: RumiloConfig;
  /** Optional listener; when provided it is subscribed to the agent before prompting. */
  onEvent?: (event: AgentEvent) => void;
}
19
/** Outcome returned by `runAgent` after a successful run. */
export interface AgentRunResult {
  /** Concatenated, trimmed text of the last assistant message. */
  message: string;
  /** Number of assistant-role messages present in the agent state. */
  requestCount: number;
  /** The `usage` value of the last assistant message, passed through untyped. */
  usage?: unknown;
}
25
/**
 * Create the `getApiKey` callback handed to the Agent.
 *
 * For a given provider the callback tries, in order:
 * 1. Custom model config — the first entry in `config.custom_models` whose
 *    `provider` matches and whose `api_key` is set. That `api_key` is passed
 *    through `resolveConfigValue` (supports env var names, `$VAR` references,
 *    and `!shell` commands) and the result is returned as-is.
 * 2. pi-ai’s built-in env-var lookup (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, etc.)
 *    via `getEnvApiKey`.
 *
 * @param config - Configuration whose `custom_models` are consulted on every call.
 * @returns A function mapping a provider name to its API key, if one is found.
 */
export function buildGetApiKey(config: RumiloConfig): (provider: string) => string | undefined {
  return (provider: string) => {
    // Read `custom_models` on each call so later changes to the config are seen.
    const customModels = config.custom_models ? Object.values(config.custom_models) : [];
    const keyedModel = customModels.find(
      (candidate) => candidate.provider === provider && candidate.api_key,
    );

    return keyedModel?.api_key
      ? resolveConfigValue(keyedModel.api_key)
      : getEnvApiKey(provider);
  };
}
48
/**
 * Send one prompt through a freshly constructed agent and return its final text answer.
 *
 * @param query - Prompt text passed to `agent.prompt`.
 * @param options - Model identifier, system prompt, tools, config, and optional event listener.
 * @returns The trimmed text of the last assistant message, that message's `usage`,
 *   and the number of assistant messages in the agent state.
 * @throws AgentError when the agent state carries an error, when the last assistant
 *   message has `stopReason === "error"`, or when no non-empty text was produced.
 */
export async function runAgent(query: string, options: AgentRunOptions): Promise<AgentRunResult> {
  const { model, systemPrompt, tools, config, onEvent } = options;

  const agent = new Agent({
    initialState: {
      systemPrompt,
      model: resolveModel(model, config),
      tools,
    },
    getApiKey: buildGetApiKey(config),
  });

  if (onEvent) {
    agent.subscribe(onEvent);
  }

  await agent.prompt(query);

  // An error recorded on the agent state takes precedence over anything in the messages.
  const stateError = agent.state.error;
  if (stateError) {
    throw new AgentError(stateError);
  }

  // One pass gives both the request count and the last assistant message.
  const assistantMessages = agent.state.messages.filter(
    (msg): msg is AssistantMessage => msg.role === "assistant",
  );
  const finalMessage: AssistantMessage | undefined =
    assistantMessages[assistantMessages.length - 1];

  // No assistant message at all means there is no text to return.
  if (finalMessage === undefined) {
    throw new AgentError("Agent completed without producing a text response");
  }

  // The last assistant message may itself report a failure.
  if (finalMessage.stopReason === "error") {
    throw new AgentError(finalMessage.errorMessage ?? "Agent stopped with an unknown error");
  }

  // Only text parts contribute to the answer; other content kinds are skipped.
  const textParts: string[] = [];
  for (const part of finalMessage.content ?? []) {
    if (part.type === "text") {
      textParts.push(part.text);
    }
  }

  const answer = textParts.join("").trim();
  if (answer === "") {
    throw new AgentError("Agent completed without producing a text response");
  }

  return {
    message: answer,
    usage: finalMessage.usage,
    requestCount: assistantMessages.length,
  };
}