1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5import { getModel, type Model } from "@mariozechner/pi-ai";
6import type { CustomModelConfig, RumiloConfig } from "../config/schema.js";
7import { resolveHeaders } from "../util/env.js";
8import { ConfigError } from "../util/errors.js";
9
/**
 * Resolve a `provider:model` string into a model definition.
 *
 * The string is split on the first colon only, so the model part may itself
 * contain colons. The `custom` provider is looked up in the supplied config;
 * every other provider is handed to `getModel`.
 *
 * @param modelString - Model reference in `provider:model` form.
 * @param config - Loaded configuration, consulted for `custom:` models.
 * @returns The resolved model definition.
 * @throws ConfigError when the string is not in `provider:model` form, when a
 *   `custom:` model cannot be resolved, or when the built-in lookup yields no
 *   model.
 */
export function resolveModel(
  modelString: string,
  config: RumiloConfig,
): Model<any> {
  const colonIndex = modelString.indexOf(":");

  // A missing colon is treated like an empty provider/model so that every
  // malformed input funnels into the single format error below.
  const provider = colonIndex === -1 ? "" : modelString.slice(0, colonIndex);
  const modelName = colonIndex === -1 ? "" : modelString.slice(colonIndex + 1);

  if (!provider || !modelName) {
    throw new ConfigError(
      `Invalid model format: "${modelString}". Expected provider:model (e.g. anthropic:claude-sonnet-4-20250514)`,
    );
  }

  // Handle custom models
  if (provider === "custom") {
    return resolveCustomModel(modelName, config);
  }

  // Handle built-in providers.
  // The provider is cast to `any`, so the compiler cannot verify that the
  // provider/model pair exists.
  // NOTE(review): assumes getModel can yield `undefined` at runtime for an
  // unknown pair — confirm against the pi-ai version in use. The guard is
  // harmless if it throws instead.
  const model: Model<any> | undefined = getModel(provider as any, modelName);
  if (!model) {
    throw new ConfigError(
      `Unknown model: "${modelString}". Check the provider and model name.`,
    );
  }

  return model;
}
38
/**
 * Look up a user-defined model by name in `config.custom_models` and build
 * the corresponding model definition.
 *
 * @param modelName - The part after `custom:` in the model string.
 * @param config - Loaded configuration holding the custom model table.
 * @returns The model built from the matching custom model entry.
 * @throws ConfigError when no custom models are configured or when
 *   `modelName` is not one of them.
 */
function resolveCustomModel(
  modelName: string,
  config: RumiloConfig,
): Model<any> {
  const customModels = config.custom_models;
  if (!customModels) {
    throw new ConfigError(
      `No custom models defined in config. Add a [custom_models.${modelName}] section to use custom:${modelName}`,
    );
  }

  // Only accept own properties: a bare index lookup would also match inherited
  // members such as "constructor" or "toString" and hand a function to
  // buildCustomModel.
  const customConfig = Object.prototype.hasOwnProperty.call(
    customModels,
    modelName,
  )
    ? customModels[modelName]
    : undefined;

  if (!customConfig) {
    // Fall back to a placeholder so an empty table does not produce a message
    // that ends in a dangling colon.
    const available = Object.keys(customModels).join(", ") || "(none)";
    throw new ConfigError(
      `Custom model '${modelName}' not found. Available custom models: ${available}`,
    );
  }

  return buildCustomModel(customConfig);
}
59
/**
 * Translate a custom model config entry (snake_case fields) into the
 * camelCase model object returned to callers.
 *
 * Optional pieces are attached only when present: cache pricing when the
 * corresponding cost field is defined, `headers` when `resolveHeaders` returns
 * a truthy value, and `compat` when the entry has a compat section.
 *
 * @param config - A single custom model entry from the configuration.
 * @returns The assembled model object.
 */
function buildCustomModel(config: CustomModelConfig): Model<any> {
  const pricing = config.cost;

  // Cache prices are optional; omit the keys entirely when they are unset.
  const cost: any = {
    input: pricing.input,
    output: pricing.output,
    ...(pricing.cache_read !== undefined
      ? { cacheRead: pricing.cache_read }
      : {}),
    ...(pricing.cache_write !== undefined
      ? { cacheWrite: pricing.cache_write }
      : {}),
  };

  const definition: any = {
    id: config.id,
    name: config.name,
    api: config.api,
    provider: config.provider,
    baseUrl: config.base_url,
    reasoning: config.reasoning,
    input: config.input,
    cost,
    contextWindow: config.context_window,
    maxTokens: config.max_tokens,
  };

  // Headers are resolved before compat, and each is added only when present.
  const headers = resolveHeaders(config.headers);
  const optionalParts: any = {
    ...(headers ? { headers } : {}),
    ...(config.compat ? { compat: convertCompatConfig(config.compat) } : {}),
  };

  return Object.assign(definition, optionalParts);
}
100
/**
 * Map the snake_case compat section of a custom model entry onto its
 * camelCase equivalent.
 *
 * Every target key is always present in the result; fields missing from the
 * input come through as `undefined`.
 *
 * @param compat - The compat section of a custom model entry.
 * @returns An object with the ten camelCase compat fields.
 * @throws Error when `compat` is falsy.
 */
function convertCompatConfig(compat: CustomModelConfig["compat"]): any {
  if (compat) {
    // Rename each field while pulling it out of the source object.
    const {
      supports_store: supportsStore,
      supports_developer_role: supportsDeveloperRole,
      supports_reasoning_effort: supportsReasoningEffort,
      supports_usage_in_streaming: supportsUsageInStreaming,
      max_tokens_field: maxTokensField,
      requires_tool_result_name: requiresToolResultName,
      requires_assistant_after_tool_result: requiresAssistantAfterToolResult,
      requires_thinking_as_text: requiresThinkingAsText,
      requires_mistral_tool_ids: requiresMistralToolIds,
      thinking_format: thinkingFormat,
    } = compat;

    return {
      supportsStore,
      supportsDeveloperRole,
      supportsReasoningEffort,
      supportsUsageInStreaming,
      maxTokensField,
      requiresToolResultName,
      requiresAssistantAfterToolResult,
      requiresThinkingAsText,
      requiresMistralToolIds,
      thinkingFormat,
    };
  }

  throw new Error("Compat config is expected to be defined");
}