1// Package llm provides a unified interface for interacting with LLMs.
2package llm
3
4import (
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "net/http"
10 "os"
11 "path/filepath"
12 "strconv"
13 "strings"
14 "time"
15)
16
// Service is the interface through which callers interact with an LLM.
type Service interface {
	// Do sends a request to an LLM and returns its response.
	Do(context.Context, *Request) (*Response, error)
	// TokenContextWindow returns the maximum token context window size for this service.
	TokenContextWindow() int
	// MaxImageDimension returns the maximum allowed dimension (width or height) for images.
	// For multi-image requests, some providers enforce stricter limits.
	// It returns 0 if there is no limit.
	MaxImageDimension() int
}
27
// SimplifiedPatcher is an optional interface that a Service may also implement.
// The package-level UseSimplifiedPatch function detects it with a type assertion.
type SimplifiedPatcher interface {
	// UseSimplifiedPatch reports whether the service should use the simplified patch input schema.
	UseSimplifiedPatch() bool
}
32
33func UseSimplifiedPatch(svc Service) bool {
34 if sp, ok := svc.(SimplifiedPatcher); ok {
35 return sp.UseSimplifiedPatch()
36 }
37 return false
38}
39
// MustSchema checks that schema looks like a JSON object schema and returns it,
// with surrounding whitespace trimmed, as a json.RawMessage.
//
// It panics if schema is not a JSON object, if its "type" is not the string
// "object", or if it has no "properties" key. No deeper JSON Schema validation
// is performed.
func MustSchema(schema string) json.RawMessage {
	trimmed := strings.TrimSpace(schema)
	raw := json.RawMessage(trimmed)

	var fields map[string]any
	if err := json.Unmarshal(raw, &fields); err != nil {
		panic("failed to parse JSON schema: " + trimmed + ": " + err.Error())
	}
	// A missing key yields nil, which also compares unequal to "object".
	if fields["type"] != "object" {
		panic("JSON schema must have type='object': " + trimmed)
	}
	if _, found := fields["properties"]; !found {
		panic("JSON schema must have 'properties' key: " + trimmed)
	}
	return raw
}
58
59func EmptySchema() json.RawMessage {
60 return MustSchema(`{"type": "object", "properties": {}}`)
61}
62
// ErrorType identifies system-generated error messages (not LLM content).
// It is the type of the Message.ErrorType field.
type ErrorType string

// Known ErrorType values.
const (
	ErrorTypeNone       ErrorType = ""            // Not an error
	ErrorTypeTruncation ErrorType = "truncation"  // Response truncated due to max tokens
	ErrorTypeLLMRequest ErrorType = "llm_request" // LLM request failed
)
71
// Request is the input passed to Service.Do.
type Request struct {
	// Messages is the conversation to send.
	Messages []Message
	// ToolChoice constrains tool use; see ToolChoiceType for the available modes.
	// NOTE(review): a nil ToolChoice presumably leaves the choice to the provider — confirm.
	ToolChoice *ToolChoice
	// Tools lists the tools made available to the LLM.
	Tools []*Tool
	// System holds system content items.
	// NOTE(review): presumably the system prompt — confirm against the providers.
	System []SystemContent
}
78
// Message represents a message in the conversation.
// Unlike most types in this package it carries explicit JSON tags, so its
// serialized field names are fixed by those tags.
type Message struct {
	Role      MessageRole `json:"Role"`
	Content   []Content   `json:"Content"`
	ToolUse   *ToolUse    `json:"ToolUse,omitempty"` // used to control whether/which tool to use
	EndOfTurn bool        `json:"EndOfTurn"`         // true if this message completes the agent's turn (no tool calls to make)

	// ExcludedFromContext indicates this message should be stored but not sent back to the LLM.
	// Used for truncated responses we want to keep for cost tracking but that would confuse the LLM.
	ExcludedFromContext bool `json:"ExcludedFromContext,omitempty"`

	// ErrorType indicates this is a system-generated error message (not LLM content).
	// Empty string means not an error. Values: "truncation", "llm_request".
	ErrorType ErrorType `json:"ErrorType,omitempty"`
}
94
// ToolUse represents a tool use in the message content.
// It is referenced from Message.ToolUse.
type ToolUse struct {
	ID   string // NOTE(review): presumably the tool call's identifier — confirm
	Name string // NOTE(review): presumably the tool's name — confirm
}
100
// ToolChoice tells the LLM how it may use the tools supplied in a Request.
type ToolChoice struct {
	// Type selects the tool-choice mode.
	Type ToolChoiceType
	// Name is the tool that must be used when Type is ToolChoiceTypeTool.
	Name string
}
105
// SystemContent is one item of Request.System.
type SystemContent struct {
	Text string
	Type string
	// NOTE(review): presumably enables prompt caching for this item, as Tool.Cache does for tools — confirm.
	Cache bool
}
111
// Tool represents a tool available to an LLM.
type Tool struct {
	Name string
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string
	Description string
	// InputSchema is the raw JSON schema for the tool's input.
	// MustSchema and EmptySchema produce values suitable for this field.
	InputSchema json.RawMessage
	// EndsTurn indicates that this tool should cause the model to end its turn when used.
	EndsTurn bool
	// Cache indicates whether to use prompt caching for this tool.
	Cache bool

	// The Run function is automatically called when the tool is used.
	// Run functions may be called concurrently with each other and themselves.
	// The input to Run function is the input to the tool, as provided by Claude, in compliance with the input schema.
	// The outputs from Run will be sent back to Claude.
	// If you do not want to respond to the tool call request from Claude, return ErrDoNotRespond.
	// ctx contains extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
	// The `json:"-"` tag keeps Run out of JSON encodings of Tool.
	Run func(ctx context.Context, input json.RawMessage) ToolOut `json:"-"`
}
133
// ToolOut represents the output of a tool run.
// It is the return type of Tool.Run; ErrorToolOut and ErrorfToolOut build the error form.
type ToolOut struct {
	// LLMContent is the output of the tool to be sent back to the LLM.
	// May be nil on error.
	LLMContent []Content
	// Display is content to be displayed to the user.
	// The type of content is set by the tool and coordinated with the UIs.
	// It should be JSON-serializable.
	Display any
	// Error is the error (if any) that occurred during the tool run.
	// The text contents of the error will be sent back to the LLM.
	// If non-nil, LLMContent will be ignored.
	Error error
}
148
// Content is a single content item. It appears in Message.Content,
// Response.Content, ToolOut.LLMContent, and (nested) in Content.ToolResult.
// The comments below group the fields by the content type that uses them.
type Content struct {
	ID   string
	Type ContentType
	Text string

	// MediaType is the media type for image content.
	MediaType string

	// Fields for thinking content.
	// NOTE(review): Data is presumably the payload of redacted thinking — confirm.
	Thinking  string
	Data      string
	Signature string

	// Fields for tool_use content.
	ToolName  string
	ToolInput json.RawMessage

	// Fields for tool_result content.
	ToolUseID  string
	ToolError  bool
	ToolResult []Content

	// Timing information for tool_result; added externally; not sent to the LLM.
	ToolUseStartTime *time.Time
	ToolUseEndTime   *time.Time

	// Display is content to be displayed to the user, copied from ToolOut.
	Display any

	// NOTE(review): presumably enables prompt caching for this item, as Tool.Cache does for tools — confirm.
	Cache bool
}
180
181func StringContent(s string) Content {
182 return Content{Type: ContentTypeText, Text: s}
183}
184
185// ContentsAttr returns contents as a slog.Attr.
186// It is meant for logging.
187func ContentsAttr(contents []Content) slog.Attr {
188 var contentAttrs []any // slog.Attr
189 for _, content := range contents {
190 var attrs []any // slog.Attr
191 switch content.Type {
192 case ContentTypeText:
193 attrs = append(attrs, slog.String("text", content.Text))
194 case ContentTypeToolUse:
195 attrs = append(attrs, slog.String("tool_name", content.ToolName))
196 attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
197 case ContentTypeToolResult:
198 attrs = append(attrs, slog.Any("tool_result", content.ToolResult))
199 attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
200 case ContentTypeThinking:
201 attrs = append(attrs, slog.String("thinking", content.Text))
202 default:
203 attrs = append(attrs, slog.String("unknown_content_type", content.Type.String()))
204 attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
205 }
206 contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
207 }
208 return slog.Group("contents", contentAttrs...)
209}
210
// Integer enum types used throughout the package. Their String methods are
// generated by stringer into llm_string.go (see the go:generate directive below).
type (
	// MessageRole identifies the author of a message: user or assistant.
	MessageRole int
	// ContentType identifies the kind of a Content item.
	ContentType int
	// ToolChoiceType selects the tool-choice mode of a ToolChoice.
	ToolChoiceType int
	// StopReason is the reason recorded in Response.StopReason.
	StopReason int
)
217
218//go:generate go tool golang.org/x/tools/cmd/stringer -type=MessageRole,ContentType,ToolChoiceType,StopReason -output=llm_string.go
219
// Enum values for MessageRole, ContentType, ToolChoiceType, and StopReason.
//
// All four groups live in one const block, and iota is the index of the
// constant within the whole block: it does NOT restart at each "= iota".
// Only MessageRole therefore starts at 0. ContentTypeText is 2,
// ToolChoiceTypeAuto is 7, and StopReasonStopSequence is 11, so the zero value
// of ContentType, ToolChoiceType, and StopReason matches no named constant.
//
// NOTE(review): inserting, removing, or reordering constants shifts every later
// value. Before doing so, confirm whether these numbers are persisted anywhere
// (Message serializes Role as a number) and regenerate llm_string.go.
const (
	MessageRoleUser MessageRole = iota // 0
	MessageRoleAssistant

	ContentTypeText ContentType = iota // 2
	ContentTypeThinking
	ContentTypeRedactedThinking
	ContentTypeToolUse
	ContentTypeToolResult

	ToolChoiceTypeAuto ToolChoiceType = iota // 7; default (NOTE(review): not the Go zero value — confirm what "default" refers to)
	ToolChoiceTypeAny                        // any tool, but must use one
	ToolChoiceTypeNone                       // no tools allowed
	ToolChoiceTypeTool                       // must use the tool specified in the Name field

	StopReasonStopSequence StopReason = iota // 11
	StopReasonMaxTokens
	StopReasonEndTurn
	StopReasonToolUse
	StopReasonRefusal
)
241
// Response is the result returned by Service.Do.
// ToMessage converts it into a Message for the conversation.
type Response struct {
	ID    string
	Type  string
	Role  MessageRole
	Model string
	// Content holds the content items of the response.
	Content []Content
	// StopReason is why the response ended; ToMessage treats
	// StopReasonToolUse as "turn not finished".
	StopReason StopReason
	// NOTE(review): presumably the matched stop sequence when StopReason is
	// StopReasonStopSequence — confirm.
	StopSequence *string
	// Usage is the token and cost accounting for this response.
	Usage Usage
	// NOTE(review): presumably when the request started and finished — confirm.
	StartTime *time.Time
	EndTime   *time.Time
}
254
255func (m *Response) ToMessage() Message {
256 return Message{
257 Role: m.Role,
258 Content: m.Content,
259 EndOfTurn: m.StopReason != StopReasonToolUse, // End of turn unless there are tools to call
260 }
261}
262
// CostUSDFromResponse extracts the cost of a request, in US dollars, from HTTP
// response headers.
//
// The cost is read from the Skaband-Cost-Microcents header, which must hold an
// unsigned base-10 integer number of microcents. It returns 0 if the header is
// absent or empty. If the value cannot be parsed, it logs a warning that
// includes the parse error and returns 0.
func CostUSDFromResponse(headers http.Header) float64 {
	// 100 cents per dollar, 1e6 microcents per cent.
	const microcentsPerUSD = 100_000_000

	raw := headers.Get("Skaband-Cost-Microcents")
	if raw == "" {
		return 0
	}
	microcents, err := strconv.ParseUint(raw, 10, 64)
	if err != nil {
		slog.Warn("failed to parse cost header", "header", raw, "err", err)
		return 0
	}
	return float64(microcents) / microcentsPerUSD
}
275
// Usage represents the billing and rate-limit usage.
// Most LLM structs do not have JSON tags, to avoid accidental direct use in specific providers.
// However, the front-end uses this struct, and it relies on its JSON serialization.
// Do NOT use this struct directly when implementing an llm.Service.
type Usage struct {
	InputTokens              uint64     `json:"input_tokens"`
	CacheCreationInputTokens uint64     `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64     `json:"cache_read_input_tokens"`
	OutputTokens             uint64     `json:"output_tokens"`
	CostUSD                  float64    `json:"cost_usd"`
	Model                    string     `json:"model,omitempty"`
	StartTime                *time.Time `json:"start_time,omitempty"`
	EndTime                  *time.Time `json:"end_time,omitempty"`
}

// Add accumulates the token counts and cost of other into u.
// Model, StartTime, and EndTime of u are left unchanged.
func (u *Usage) Add(other Usage) {
	// Token counters.
	u.InputTokens += other.InputTokens
	u.OutputTokens += other.OutputTokens
	u.CacheCreationInputTokens += other.CacheCreationInputTokens
	u.CacheReadInputTokens += other.CacheReadInputTokens

	// Monetary cost.
	u.CostUSD += other.CostUSD
}

// String returns a short summary of the uncached input and output token counts.
func (u *Usage) String() string {
	in, out := u.InputTokens, u.OutputTokens
	return fmt.Sprintf("in: %d, out: %d", in, out)
}

// TotalInputTokens returns the total number of input tokens including cached tokens.
// This represents the full context that was sent to the model:
//   - InputTokens: tokens processed (not from cache)
//   - CacheCreationInputTokens: tokens written to cache (also part of input)
//   - CacheReadInputTokens: tokens read from cache (also part of input)
func (u *Usage) TotalInputTokens() uint64 {
	cached := u.CacheCreationInputTokens + u.CacheReadInputTokens
	return u.InputTokens + cached
}

// ContextWindowUsed returns the total context window usage after this response.
// This is the size of the conversation that would be sent to the model for the next turn:
// total input tokens + output tokens (which become part of the conversation).
func (u *Usage) ContextWindowUsed() uint64 {
	return u.OutputTokens + u.TotalInputTokens()
}

// IsZero reports whether u equals the zero Usage.
// StartTime and EndTime are compared as pointers, so any non-nil time makes u non-zero.
func (u *Usage) IsZero() bool {
	var zero Usage
	return *u == zero
}

// Attr returns the token counts and cost as a slog group named "usage".
// Model and the timestamps are not included.
func (u *Usage) Attr() slog.Attr {
	fields := []any{
		slog.Uint64("input_tokens", u.InputTokens),
		slog.Uint64("output_tokens", u.OutputTokens),
		slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
		slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
		slog.Float64("cost_usd", u.CostUSD),
	}
	return slog.Group("usage", fields...)
}
332
333// UserStringMessage creates a user message with a single text content item.
334func UserStringMessage(text string) Message {
335 return Message{
336 Role: MessageRoleUser,
337 Content: []Content{StringContent(text)},
338 }
339}
340
341// TextContent creates a simple text content for tool results.
342// This is a helper function to create the most common type of tool result content.
343func TextContent(text string) []Content {
344 return []Content{{
345 Type: ContentTypeText,
346 Text: text,
347 }}
348}
349
350func ErrorToolOut(err error) ToolOut {
351 if err == nil {
352 panic("ErrorToolOut called with nil error")
353 }
354 return ToolOut{
355 Error: err,
356 }
357}
358
359func ErrorfToolOut(format string, args ...any) ToolOut {
360 return ErrorToolOut(fmt.Errorf(format, args...))
361}
362
// DumpToFile writes LLM communication content to a timestamped file in ~/.cache/sketch/.
// For requests, it includes the URL followed by the content. For responses, it only includes the content.
// The typ parameter is used as a prefix in the filename ("request", "response").
//
// The directory is created with mode 0700 if needed and the file is written
// with mode 0600. The filename is "<typ>_<unix-milliseconds>.txt", so two dumps
// with the same typ in the same millisecond write to the same file and the
// later one replaces the earlier.
//
// It returns an error, wrapping the underlying cause, if the home directory
// cannot be determined, the directory cannot be created, or the file cannot be
// written.
func DumpToFile(typ, url string, content []byte) error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("locating home directory for LLM dump: %w", err)
	}
	cacheDir := filepath.Join(homeDir, ".cache", "sketch")
	if err := os.MkdirAll(cacheDir, 0o700); err != nil {
		return fmt.Errorf("creating LLM dump directory: %w", err)
	}
	filename := fmt.Sprintf("%s_%d.txt", typ, time.Now().UnixMilli())
	filePath := filepath.Join(cacheDir, filename)

	// For requests, start with the URL and a blank line; for responses
	// (empty url), write just the content. Size the buffer once up front.
	data := make([]byte, 0, len(url)+2+len(content))
	if url != "" {
		data = append(data, url...)
		data = append(data, "\n\n"...)
	}
	data = append(data, content...)

	if err := os.WriteFile(filePath, data, 0o600); err != nil {
		return fmt.Errorf("writing LLM dump file: %w", err)
	}
	return nil
}