llm.go

  1// Package llm provides a unified interface for interacting with LLMs.
  2package llm
  3
  4import (
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"net/http"
 10	"os"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14	"time"
 15)
 16
 17type Service interface {
 18	// Do sends a request to an LLM.
 19	Do(context.Context, *Request) (*Response, error)
 20	// TokenContextWindow returns the maximum token context window size for this service
 21	TokenContextWindow() int
 22	// MaxImageDimension returns the maximum allowed dimension (width or height) for images.
 23	// For multi-image requests, some providers enforce stricter limits.
 24	// Returns 0 if there is no limit.
 25	MaxImageDimension() int
 26}
 27
 28type SimplifiedPatcher interface {
 29	// UseSimplifiedPatch reports whether the service should use the simplified patch input schema.
 30	UseSimplifiedPatch() bool
 31}
 32
 33func UseSimplifiedPatch(svc Service) bool {
 34	if sp, ok := svc.(SimplifiedPatcher); ok {
 35		return sp.UseSimplifiedPatch()
 36	}
 37	return false
 38}
 39
 40// MustSchema validates that schema is a valid JSON schema and returns it as a json.RawMessage.
 41// It panics if the schema is invalid.
 42// The schema must have at least type="object" and a properties key.
 43func MustSchema(schema string) json.RawMessage {
 44	schema = strings.TrimSpace(schema)
 45	bytes := []byte(schema)
 46	var obj map[string]any
 47	if err := json.Unmarshal(bytes, &obj); err != nil {
 48		panic("failed to parse JSON schema: " + schema + ": " + err.Error())
 49	}
 50	if typ, ok := obj["type"]; !ok || typ != "object" {
 51		panic("JSON schema must have type='object': " + schema)
 52	}
 53	if _, ok := obj["properties"]; !ok {
 54		panic("JSON schema must have 'properties' key: " + schema)
 55	}
 56	return json.RawMessage(bytes)
 57}
 58
 59func EmptySchema() json.RawMessage {
 60	return MustSchema(`{"type": "object", "properties": {}}`)
 61}
 62
 63// ErrorType identifies system-generated error messages (not LLM content).
 64type ErrorType string
 65
 66const (
 67	ErrorTypeNone       ErrorType = ""            // Not an error
 68	ErrorTypeTruncation ErrorType = "truncation"  // Response truncated due to max tokens
 69	ErrorTypeLLMRequest ErrorType = "llm_request" // LLM request failed
 70)
 71
// Request is a provider-independent request passed to Service.Do.
type Request struct {
	// Messages is the conversation to send to the LLM.
	Messages []Message
	// ToolChoice optionally constrains tool use (see ToolChoiceType).
	// NOTE(review): nil presumably leaves the choice to the provider default — confirm.
	ToolChoice *ToolChoice
	// Tools lists the tools available to the LLM for this request.
	Tools []*Tool
	// System holds system content blocks (presumably the system prompt — confirm).
	System []SystemContent
}
 78
// Message represents a message in the conversation.
//
// Unlike most types in this package, Message carries explicit JSON tags, so its
// tag names form a serialization format.
// NOTE(review): presumably messages are persisted and/or sent to the front end —
// confirm before renaming fields or tags.
type Message struct {
	// Role says whether the user or the assistant produced this message.
	Role MessageRole `json:"Role"`
	// Content holds the message's content blocks.
	Content []Content `json:"Content"`
	// ToolUse, when set, is used to control whether/which tool to use.
	ToolUse *ToolUse `json:"ToolUse,omitempty"`
	// EndOfTurn is true if this message completes the agent's turn (no tool
	// calls to make). Response.ToMessage sets it unless the response stopped
	// with StopReasonToolUse.
	EndOfTurn bool `json:"EndOfTurn"`

	// ExcludedFromContext indicates this message should be stored but not sent back to the LLM.
	// Used for truncated responses we want to keep for cost tracking but that would confuse the LLM.
	ExcludedFromContext bool `json:"ExcludedFromContext,omitempty"`

	// ErrorType indicates this is a system-generated error message (not LLM content).
	// Empty string (ErrorTypeNone) means not an error. Values: "truncation", "llm_request".
	ErrorType ErrorType `json:"ErrorType,omitempty"`
}
 94
// ToolUse represents a tool use in the message content.
// Message.ToolUse uses it to control whether/which tool to use.
type ToolUse struct {
	// ID identifies the tool use.
	// NOTE(review): presumably matches a Content.ID / Content.ToolUseID — confirm.
	ID string
	// Name is the name of the tool.
	Name string
}
100
// ToolChoice constrains whether, and which, tool the model must use.
type ToolChoice struct {
	// Type selects the constraint (auto, any, none, or a specific tool).
	Type ToolChoiceType
	// Name is the tool that must be used; it applies when Type is ToolChoiceTypeTool.
	Name string
}
105
// SystemContent is one block of system content in a Request.
type SystemContent struct {
	// Text is the block's text.
	Text string
	// Type is a provider-facing type string.
	// NOTE(review): valid values are not visible in this file — confirm with the providers.
	Type string
	// Cache presumably requests prompt caching for this block, like Tool.Cache — confirm.
	Cache bool
}
111
// Tool represents a tool available to an LLM.
type Tool struct {
	// Name identifies the tool.
	Name string
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type string
	// Description describes the tool (presumably shown to the LLM — confirm).
	Description string
	// InputSchema is the JSON schema the tool's input must satisfy; build it with
	// MustSchema (or EmptySchema for tools without input).
	InputSchema json.RawMessage
	// EndsTurn indicates that this tool should cause the model to end its turn when used
	EndsTurn bool
	// Cache indicates whether to use prompt caching for this tool
	Cache bool

	// The Run function is automatically called when the tool is used.
	// Run functions may be called concurrently with each other and themselves.
	// The input to Run is the input to the tool, as provided by the LLM, in compliance with the input schema.
	// The outputs from Run will be sent back to the LLM.
	// If you do not want to respond to the tool call request, return a ToolOut whose Error is ErrDoNotRespond.
	// ctx contains extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
	// NOTE(review): ErrDoNotRespond and ToolCallInfoFromContext are not defined in this file — confirm they exist.
	// The json:"-" tag keeps this function out of JSON encoding.
	Run func(ctx context.Context, input json.RawMessage) ToolOut `json:"-"`
}
133
// ToolOut represents the output of a tool run.
// Use ErrorToolOut or ErrorfToolOut to build an error result.
type ToolOut struct {
	// LLMContent is the output of the tool to be sent back to the LLM.
	// May be nil on error.
	LLMContent []Content
	// Display is content to be displayed to the user.
	// The type of content is set by the tool and coordinated with the UIs.
	// It should be JSON-serializable.
	Display any
	// Error is the error (if any) that occurred during the tool run.
	// The text contents of the error will be sent back to the LLM.
	// If non-nil, LLMContent will be ignored.
	Error error
}
148
// Content is a single block of message content. Type determines which of the
// field groups below apply.
type Content struct {
	// ID identifies this content block.
	// NOTE(review): for tool_use blocks this is presumably what ToolUseID refers back to — confirm.
	ID   string
	Type ContentType
	// Text is the text of text content.
	Text string

	// Media type for image content
	MediaType string

	// for thinking: Thinking holds the model's thinking text.
	// NOTE(review): Data and Signature are presumably opaque provider values
	// (e.g. redacted thinking payload, verification signature) — confirm.
	Thinking  string
	Data      string
	Signature string

	// for tool_use: the tool's name and its raw JSON input
	ToolName  string
	ToolInput json.RawMessage

	// for tool_result: ToolUseID presumably names the tool_use block being answered;
	// ToolError flags a failed tool run; ToolResult holds the nested result content.
	ToolUseID  string
	ToolError  bool
	ToolResult []Content

	// timing information for tool_result; added externally; not sent to the LLM
	ToolUseStartTime *time.Time
	ToolUseEndTime   *time.Time

	// Display is content to be displayed to the user, copied from ToolOut
	Display any

	// Cache presumably requests prompt caching for this block, like Tool.Cache — confirm.
	Cache bool
}
180
181func StringContent(s string) Content {
182	return Content{Type: ContentTypeText, Text: s}
183}
184
185// ContentsAttr returns contents as a slog.Attr.
186// It is meant for logging.
187func ContentsAttr(contents []Content) slog.Attr {
188	var contentAttrs []any // slog.Attr
189	for _, content := range contents {
190		var attrs []any // slog.Attr
191		switch content.Type {
192		case ContentTypeText:
193			attrs = append(attrs, slog.String("text", content.Text))
194		case ContentTypeToolUse:
195			attrs = append(attrs, slog.String("tool_name", content.ToolName))
196			attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
197		case ContentTypeToolResult:
198			attrs = append(attrs, slog.Any("tool_result", content.ToolResult))
199			attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
200		case ContentTypeThinking:
201			attrs = append(attrs, slog.String("thinking", content.Text))
202		default:
203			attrs = append(attrs, slog.String("unknown_content_type", content.Type.String()))
204			attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
205		}
206		contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
207	}
208	return slog.Group("contents", contentAttrs...)
209}
210
// MessageRole, ContentType, ToolChoiceType and StopReason are small integer
// enums. Their String methods are generated by stringer into llm_string.go
// (see the go:generate directive below); rerun `go generate` after changing
// the constants.
type (
	MessageRole    int
	ContentType    int
	ToolChoiceType int
	StopReason     int
)

//go:generate go tool golang.org/x/tools/cmd/stringer -type=MessageRole,ContentType,ToolChoiceType,StopReason -output=llm_string.go

// All enum constants share one const block, so iota does NOT restart for each
// type: MessageRole values are 0-1, ContentType 2-6, ToolChoiceType 7-10 and
// StopReason 11-15. Consequently the zero value of ContentType, ToolChoiceType
// and StopReason is not equal to any named constant.
//
// Do not reorder constants or insert new ones in the middle: every later value
// would be renumbered, which requires regenerating llm_string.go and would
// change the meaning of any values already stored or serialized as integers.
const (
	// MessageRole values.
	MessageRoleUser MessageRole = iota
	MessageRoleAssistant

	// ContentType values.
	ContentTypeText ContentType = iota
	ContentTypeThinking
	ContentTypeRedactedThinking
	ContentTypeToolUse
	ContentTypeToolResult

	// ToolChoiceType values.
	ToolChoiceTypeAuto ToolChoiceType = iota // default (but NOT the zero value of ToolChoiceType; see above)
	ToolChoiceTypeAny                        // any tool, but must use one
	ToolChoiceTypeNone                       // no tools allowed
	ToolChoiceTypeTool                       // must use the tool specified in the Name field

	// StopReason values: why the model stopped generating.
	StopReasonStopSequence StopReason = iota
	StopReasonMaxTokens
	StopReasonEndTurn
	StopReasonToolUse
	StopReasonRefusal
)
241
// Response is a provider-independent response returned by Service.Do.
// Use ToMessage to append it to a conversation.
type Response struct {
	ID    string
	Type  string
	Role  MessageRole
	Model string
	// Content holds the response's content blocks.
	Content []Content
	// StopReason records why generation stopped; StopReasonToolUse means the
	// model wants tool calls run (ToMessage then leaves EndOfTurn false).
	StopReason StopReason
	// NOTE(review): presumably the matched stop sequence when StopReason is
	// StopReasonStopSequence — confirm.
	StopSequence *string
	Usage        Usage
	StartTime    *time.Time
	EndTime      *time.Time
}
254
255func (m *Response) ToMessage() Message {
256	return Message{
257		Role:      m.Role,
258		Content:   m.Content,
259		EndOfTurn: m.StopReason != StopReasonToolUse, // End of turn unless there are tools to call
260	}
261}
262
// CostUSDFromResponse returns the cost, in US dollars, reported by the
// Skaband-Cost-Microcents response header.
//
// The header carries an unsigned integer count of microcents (millionths of a
// cent), so dividing by 100,000,000 converts it to dollars. Surrounding
// whitespace is ignored. A missing header yields 0. An unparseable header is
// logged together with the parse error and also yields 0.
func CostUSDFromResponse(headers http.Header) float64 {
	h := strings.TrimSpace(headers.Get("Skaband-Cost-Microcents"))
	if h == "" {
		return 0
	}
	uc, err := strconv.ParseUint(h, 10, 64)
	if err != nil {
		slog.Warn("failed to parse cost header", "header", h, "error", err)
		return 0
	}
	return float64(uc) / 100_000_000
}
275
// Usage represents the billing and rate-limit usage.
// Most LLM structs do not have JSON tags, to avoid accidental direct use in specific providers.
// However, the front-end uses this struct, and it relies on its JSON serialization,
// so treat the JSON tags below as a stable format.
// Do NOT use this struct directly when implementing an llm.Service.
type Usage struct {
	// InputTokens counts input tokens processed without the cache.
	InputTokens uint64 `json:"input_tokens"`
	// CacheCreationInputTokens counts input tokens written to the cache.
	CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
	// CacheReadInputTokens counts input tokens read from the cache.
	CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
	// OutputTokens counts tokens generated by the model.
	OutputTokens uint64 `json:"output_tokens"`
	// CostUSD is the cost in US dollars.
	CostUSD float64 `json:"cost_usd"`
	// Model, StartTime and EndTime are not combined by Add.
	Model     string     `json:"model,omitempty"`
	StartTime *time.Time `json:"start_time,omitempty"`
	EndTime   *time.Time `json:"end_time,omitempty"`
}
290
291func (u *Usage) Add(other Usage) {
292	u.InputTokens += other.InputTokens
293	u.CacheCreationInputTokens += other.CacheCreationInputTokens
294	u.CacheReadInputTokens += other.CacheReadInputTokens
295	u.OutputTokens += other.OutputTokens
296	u.CostUSD += other.CostUSD
297}
298
299func (u *Usage) String() string {
300	return fmt.Sprintf("in: %d, out: %d", u.InputTokens, u.OutputTokens)
301}
302
303// TotalInputTokens returns the total number of input tokens including cached tokens.
304// This represents the full context that was sent to the model:
305// - InputTokens: tokens processed (not from cache)
306// - CacheCreationInputTokens: tokens written to cache (also part of input)
307// - CacheReadInputTokens: tokens read from cache (also part of input)
308func (u *Usage) TotalInputTokens() uint64 {
309	return u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens
310}
311
312// ContextWindowUsed returns the total context window usage after this response.
313// This is the size of the conversation that would be sent to the model for the next turn:
314// total input tokens + output tokens (which become part of the conversation).
315func (u *Usage) ContextWindowUsed() uint64 {
316	return u.TotalInputTokens() + u.OutputTokens
317}
318
319func (u *Usage) IsZero() bool {
320	return *u == Usage{}
321}
322
323func (u *Usage) Attr() slog.Attr {
324	return slog.Group("usage",
325		slog.Uint64("input_tokens", u.InputTokens),
326		slog.Uint64("output_tokens", u.OutputTokens),
327		slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
328		slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
329		slog.Float64("cost_usd", u.CostUSD),
330	)
331}
332
333// UserStringMessage creates a user message with a single text content item.
334func UserStringMessage(text string) Message {
335	return Message{
336		Role:    MessageRoleUser,
337		Content: []Content{StringContent(text)},
338	}
339}
340
341// TextContent creates a simple text content for tool results.
342// This is a helper function to create the most common type of tool result content.
343func TextContent(text string) []Content {
344	return []Content{{
345		Type: ContentTypeText,
346		Text: text,
347	}}
348}
349
350func ErrorToolOut(err error) ToolOut {
351	if err == nil {
352		panic("ErrorToolOut called with nil error")
353	}
354	return ToolOut{
355		Error: err,
356	}
357}
358
359func ErrorfToolOut(format string, args ...any) ToolOut {
360	return ErrorToolOut(fmt.Errorf(format, args...))
361}
362
// DumpToFile writes LLM communication content to a new timestamped file in
// ~/.cache/sketch/, creating that directory (mode 0700) if needed.
//
// typ is used as the file name prefix ("request", "response"). If url is
// non-empty the file starts with the URL and a blank line, followed by content;
// otherwise (as for responses) it holds only content. The file has mode 0600.
//
// File names have the form <typ>_<unix millis>_<random>.txt. The random suffix
// comes from os.CreateTemp, which creates the file exclusively, so dumps made in
// the same millisecond (e.g. concurrent requests) get distinct files instead of
// silently overwriting each other.
func DumpToFile(typ, url string, content []byte) error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("dump %s: locating home directory: %w", typ, err)
	}
	cacheDir := filepath.Join(homeDir, ".cache", "sketch")
	if err := os.MkdirAll(cacheDir, 0o700); err != nil {
		return fmt.Errorf("dump %s: %w", typ, err)
	}

	// CreateTemp replaces the final "*" with a random string and also rejects
	// patterns containing a path separator.
	pattern := fmt.Sprintf("%s_%d_*.txt", typ, time.Now().UnixMilli())
	f, err := os.CreateTemp(cacheDir, pattern)
	if err != nil {
		return fmt.Errorf("dump %s: %w", typ, err)
	}

	// For requests, start with the URL; for responses, just write the content.
	data := make([]byte, 0, len(url)+len("\n\n")+len(content))
	data = append(data, url...)
	if url != "" {
		data = append(data, "\n\n"...)
	}
	data = append(data, content...)

	if _, err := f.Write(data); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return fmt.Errorf("dump %s: writing %s: %w", typ, f.Name(), err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("dump %s: closing %s: %w", typ, f.Name(), err)
	}
	return nil
}