1import { getModel, type Model } from "@mariozechner/pi-ai";
2import {
3 type CustomModelConfig,
4 type RumiloConfig,
5} from "../config/schema.js";
6import { ConfigError } from "../util/errors.js";
7import { resolveHeaders } from "../util/env.js";
8
/**
 * Resolves a `provider:model` reference into a pi-ai `Model`.
 *
 * The string is split at the FIRST colon, so the model part may itself contain
 * colons. The `custom` provider is looked up in `config.custom_models`; every
 * other provider is passed to pi-ai's built-in `getModel` registry.
 *
 * @param modelString - Model reference such as `custom:my-model`.
 * @param config - Loaded configuration, used for custom model lookup.
 * @returns The resolved model definition (never `undefined`).
 * @throws ConfigError if the string is not `provider:model`, if a custom model
 *   is not configured, or if the built-in provider/model pair is unknown.
 */
export function resolveModel(
  modelString: string,
  config: RumiloConfig,
): Model<any> {
  const colonIndex = modelString.indexOf(":");
  if (colonIndex === -1) {
    throw new ConfigError("Model must be in provider:model format");
  }

  const provider = modelString.slice(0, colonIndex);
  const modelName = modelString.slice(colonIndex + 1);

  // Rejects ":model" and "provider:" (empty halves around the colon).
  if (!provider || !modelName) {
    throw new ConfigError("Model must be in provider:model format");
  }

  // Handle custom models
  if (provider === "custom") {
    return resolveCustomModel(modelName, config);
  }

  // Handle built-in providers.
  // pi-ai's getModel is a registry lookup: an unknown provider or model id
  // comes back as undefined instead of throwing (re-verify on pi-ai upgrades),
  // so fail loudly here rather than hand `undefined` to callers typed as Model.
  const model = getModel(provider as any, modelName);
  if (!model) {
    throw new ConfigError(
      `Unknown model '${modelName}' for provider '${provider}'`,
    );
  }
  return model;
}
33
/**
 * Looks up a user-defined model by name in `config.custom_models` and converts
 * it into a pi-ai `Model`.
 *
 * @param modelName - The part after `custom:` in the model reference.
 * @param config - Loaded configuration holding the `custom_models` table.
 * @returns The model built from the matching custom definition.
 * @throws ConfigError if no custom models are configured or `modelName` is not
 *   one of them.
 */
function resolveCustomModel(modelName: string, config: RumiloConfig): Model<any> {
  const customModels = config.custom_models;
  if (!customModels) {
    throw new ConfigError(
      `No custom models configured. Use 'custom:' prefix only with custom model definitions in config.`,
    );
  }

  // Own-property check: a bare `customModels[modelName]` would also resolve
  // inherited members, so e.g. "custom:constructor" or "custom:toString" would
  // yield an Object.prototype function instead of a "not found" error.
  const customConfig = Object.prototype.hasOwnProperty.call(customModels, modelName)
    ? customModels[modelName]
    : undefined;

  if (!customConfig) {
    const names = Object.keys(customModels);
    // Avoid a dangling "Available custom models: " when the table is empty.
    const available = names.length > 0 ? names.join(", ") : "(none)";
    throw new ConfigError(
      `Custom model '${modelName}' not found. Available custom models: ${available}`,
    );
  }

  return buildCustomModel(customConfig);
}
51
/**
 * Translates a snake_case custom model definition from config into the
 * camelCase model object consumed by @mariozechner/pi-ai.
 *
 * `headers` (after `resolveHeaders`) and `compat` (after `convertCompatConfig`)
 * are attached only when they are present.
 *
 * @param config - A single entry from `custom_models`.
 * @returns The pi-ai model definition.
 */
function buildCustomModel(config: CustomModelConfig): Model<any> {
  // Cache prices default to 0 instead of being left off the object: pi-ai's
  // Model cost fields are numeric (the previous `cost: any` hid that), and any
  // arithmetic on a missing price would produce NaN.
  // NOTE(review): confirm against the pi-ai Model type when upgrading.
  const cost = {
    input: config.cost.input,
    output: config.cost.output,
    cacheRead: config.cost.cache_read ?? 0,
    cacheWrite: config.cost.cache_write ?? 0,
  };

  // Loosely typed: the schema's string-valued fields (api, provider, input)
  // presumably do not statically match pi-ai's narrower literal unions.
  const model: any = {
    id: config.id,
    name: config.name,
    api: config.api,
    provider: config.provider,
    baseUrl: config.base_url,
    reasoning: config.reasoning,
    input: config.input,
    cost,
    contextWindow: config.context_window,
    maxTokens: config.max_tokens,
  };

  const resolvedHeaders = resolveHeaders(config.headers);
  if (resolvedHeaders) {
    model.headers = resolvedHeaders;
  }

  if (config.compat) {
    model.compat = convertCompatConfig(config.compat);
  }

  return model;
}
92
/**
 * Maps the snake_case `compat` section of a custom model definition onto the
 * camelCase compatibility flags placed on the pi-ai model.
 *
 * All ten target keys are always emitted, in a fixed order, even when the
 * corresponding source value is undefined.
 *
 * @param compat - The `compat` section of a custom model definition.
 * @returns A new object holding the camelCase compat flags.
 * @throws Error if `compat` is missing. This guards an internal invariant:
 *   the caller only converts compat config that is actually set.
 */
function convertCompatConfig(
  compat: CustomModelConfig["compat"],
): any {
  if (!compat) {
    throw new Error("Compat config is expected to be defined");
  }

  const {
    supports_store: supportsStore,
    supports_developer_role: supportsDeveloperRole,
    supports_reasoning_effort: supportsReasoningEffort,
    supports_usage_in_streaming: supportsUsageInStreaming,
    max_tokens_field: maxTokensField,
    requires_tool_result_name: requiresToolResultName,
    requires_assistant_after_tool_result: requiresAssistantAfterToolResult,
    requires_thinking_as_text: requiresThinkingAsText,
    requires_mistral_tool_ids: requiresMistralToolIds,
    thinking_format: thinkingFormat,
  } = compat;

  return {
    supportsStore,
    supportsDeveloperRole,
    supportsReasoningEffort,
    supportsUsageInStreaming,
    maxTokensField,
    requiresToolResultName,
    requiresAssistantAfterToolResult,
    requiresThinkingAsText,
    requiresMistralToolIds,
    thinkingFormat,
  };
}