1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5import { getModel, type Model } from "@mariozechner/pi-ai";
6import {
7 type CustomModelConfig,
8 type RumiloConfig,
9} from "../config/schema.js";
10import { ConfigError } from "../util/errors.js";
11import { resolveHeaders } from "../util/env.js";
12
/**
 * Resolve a `provider:model` string into a pi-ai model definition.
 *
 * The string is split on the FIRST colon only, so everything after it
 * (including further colons) is treated as the model name. The special
 * provider `custom` is looked up in `config.custom_models`; any other
 * provider is looked up in pi-ai's built-in model registry.
 *
 * @param modelString - Model reference such as `anthropic:claude-sonnet-4-20250514`
 *   or `custom:my-model`.
 * @param config - Loaded rumilo configuration (used for custom models).
 * @returns The resolved model definition.
 * @throws ConfigError when the string is not `provider:model`, or when the
 *   model cannot be found.
 */
export function resolveModel(
	modelString: string,
	config: RumiloConfig,
): Model<any> {
	const formatError = () =>
		new ConfigError(
			`Invalid model format: "${modelString}". Expected provider:model (e.g. anthropic:claude-sonnet-4-20250514)`,
		);

	const colonIndex = modelString.indexOf(":");
	if (colonIndex === -1) {
		throw formatError();
	}

	const provider = modelString.slice(0, colonIndex);
	const modelName = modelString.slice(colonIndex + 1);

	// Reject ":model" and "provider:" — both halves are required.
	if (!provider || !modelName) {
		throw formatError();
	}

	// Handle custom models
	if (provider === "custom") {
		return resolveCustomModel(modelName, config);
	}

	// Handle built-in providers. The registry lookup can come back empty for an
	// unknown provider/model id; surface that as a config error here rather
	// than returning `undefined` typed as a Model.
	const model = getModel(provider as any, modelName);
	if (!model) {
		throw new ConfigError(
			`Unknown model: "${modelString}". Check the provider and model id, or use custom:<name> for a model defined in [custom_models]`,
		);
	}
	return model;
}
37
/**
 * Look up a user-defined model in `config.custom_models` and build it.
 *
 * @param modelName - The part after `custom:` in the model string.
 * @param config - Loaded rumilo configuration.
 * @returns The model built from the matching `[custom_models.<name>]` entry.
 * @throws ConfigError when no custom models are configured or the name is not
 *   one of them.
 */
function resolveCustomModel(modelName: string, config: RumiloConfig): Model<any> {
	const customModels = config.custom_models;
	if (!customModels) {
		throw new ConfigError(
			`No custom models defined in config. Add a [custom_models.${modelName}] section to use custom:${modelName}`,
		);
	}

	// Only accept keys the user actually defined: a plain property lookup would
	// let names like "toString" or "constructor" resolve to inherited
	// Object.prototype members and then crash inside buildCustomModel.
	const customConfig = Object.prototype.hasOwnProperty.call(customModels, modelName)
		? customModels[modelName]
		: undefined;
	if (!customConfig) {
		const names = Object.keys(customModels);
		const available = names.length > 0 ? names.join(", ") : "(none)";
		throw new ConfigError(
			`Custom model '${modelName}' not found. Available custom models: ${available}`,
		);
	}

	return buildCustomModel(customConfig);
}
55
/**
 * Translate a snake_case `[custom_models.<name>]` config entry into the
 * camelCase model object pi-ai expects.
 *
 * `headers` are attached only when `resolveHeaders` yields a value, and
 * `compat` only when the entry defines one.
 *
 * @param config - A single custom model entry from the rumilo config.
 * @returns The model definition for that entry.
 */
function buildCustomModel(config: CustomModelConfig): Model<any> {
	const model: any = {
		id: config.id,
		name: config.name,
		api: config.api as any,
		provider: config.provider as any,
		baseUrl: config.base_url,
		reasoning: config.reasoning,
		input: config.input,
		cost: {
			input: config.cost.input,
			output: config.cost.output,
			// pi-ai multiplies these rates when computing usage cost; leaving them
			// undefined turns the computed cost into NaN, so an omitted cache rate
			// is treated as free.
			cacheRead: config.cost.cache_read ?? 0,
			cacheWrite: config.cost.cache_write ?? 0,
		},
		contextWindow: config.context_window,
		maxTokens: config.max_tokens,
	};

	const resolvedHeaders = resolveHeaders(config.headers);
	if (resolvedHeaders) {
		model.headers = resolvedHeaders;
	}

	if (config.compat) {
		model.compat = convertCompatConfig(config.compat);
	}

	return model;
}
96
/**
 * Map the snake_case `compat` table of a custom model entry onto the
 * camelCase compat object attached to the model.
 *
 * All ten keys are always present on the result; options the config leaves
 * out are carried through as `undefined`.
 *
 * @param compat - The `compat` section of a custom model entry.
 * @returns The camelCase compat object.
 * @throws Error when called without a compat section (callers check first).
 */
function convertCompatConfig(
	compat: CustomModelConfig["compat"],
): any {
	if (compat) {
		const {
			supports_store,
			supports_developer_role,
			supports_reasoning_effort,
			supports_usage_in_streaming,
			max_tokens_field,
			requires_tool_result_name,
			requires_assistant_after_tool_result,
			requires_thinking_as_text,
			requires_mistral_tool_ids,
			thinking_format,
		} = compat;

		return {
			supportsStore: supports_store,
			supportsDeveloperRole: supports_developer_role,
			supportsReasoningEffort: supports_reasoning_effort,
			supportsUsageInStreaming: supports_usage_in_streaming,
			maxTokensField: max_tokens_field,
			requiresToolResultName: requires_tool_result_name,
			requiresAssistantAfterToolResult: requires_assistant_after_tool_result,
			requiresThinkingAsText: requires_thinking_as_text,
			requiresMistralToolIds: requires_mistral_tool_ids,
			thinkingFormat: thinking_format,
		};
	}

	throw new Error("Compat config is expected to be defined");
}