package anthropic

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/anthropics/anthropic-sdk-go/packages/param"
	"github.com/anthropics/anthropic-sdk-go/vertex"
	"github.com/charmbracelet/fantasy/ai"
	"golang.org/x/oauth2/google"
)

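// Name is the provider identifier, and DefaultURL is the Anthropic API
// endpoint used when no custom base URL is configured.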
const (
	Name       = "anthropic"
	DefaultURL = "https://api.anthropic.com"
)

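// options holds the provider configuration collected from Option functions.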
type options struct {
	baseURL string
	apiKey  string
	name    string
	headers map[string]string
	client  option.HTTPClient

	vertexProject  string
	vertexLocation string
	skipGoogleAuth bool
}

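// provider implements ai.Provider on top of the Anthropic SDK.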
type provider struct {
	options options
}

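// Option configures the provider returned by New.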
type Option = func(*options)

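// New creates an Anthropic [ai.Provider]. Unset options fall back to their
// defaults (DefaultURL and Name). A minimal sketch, with a placeholder key
// and an illustrative model ID:
//
//	provider := anthropic.New(anthropic.WithAPIKey("sk-ant-..."))
//	model, err := provider.LanguageModel("claude-sonnet-4-20250514")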
func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, Name)
	return &provider{options: providerOptions}
}

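// WithBaseURL overrides the API base URL, e.g. for proxies or gateways.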
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

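// WithAPIKey sets the Anthropic API key used to authenticate requests.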
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

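// WithVertex routes requests through Google Vertex AI using the given
// project and location instead of the public Anthropic API.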
func WithVertex(project, location string) Option {
	return func(o *options) {
		o.vertexProject = project
		o.vertexLocation = location
	}
}

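// WithSkipGoogleAuth disables Google credential discovery for Vertex,
// substituting a dummy token source, e.g. when authentication is handled
// elsewhere.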
func WithSkipGoogleAuth(skip bool) Option {
	return func(o *options) {
		o.skipGoogleAuth = skip
	}
}

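// WithName overrides the provider name reported by the language model.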
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

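// WithHeaders adds extra HTTP headers to every request.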
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

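// WithHTTPClient sets a custom HTTP client for the underlying SDK.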
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

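// LanguageModel implements ai.Provider, building an Anthropic client from
// the configured options. When Vertex is configured, default Google
// credentials are resolved unless skipGoogleAuth is set.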
func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	clientOptions := make([]option.RequestOption, 0, 5+len(a.options.headers))
	if a.options.apiKey != "" {
		clientOptions = append(clientOptions, option.WithAPIKey(a.options.apiKey))
	}
	if a.options.baseURL != "" {
		clientOptions = append(clientOptions, option.WithBaseURL(a.options.baseURL))
	}
	for key, value := range a.options.headers {
		clientOptions = append(clientOptions, option.WithHeader(key, value))
	}
	if a.options.client != nil {
		clientOptions = append(clientOptions, option.WithHTTPClient(a.options.client))
	}
	if a.options.vertexProject != "" && a.options.vertexLocation != "" {
		var credentials *google.Credentials
		if a.options.skipGoogleAuth {
			credentials = &google.Credentials{TokenSource: &googleDummyTokenSource{}}
		} else {
			var err error
			credentials, err = google.FindDefaultCredentials(context.Background())
			if err != nil {
				return nil, err
			}
		}

		clientOptions = append(
			clientOptions,
			vertex.WithCredentials(
				context.Background(),
				a.options.vertexLocation,
				a.options.vertexProject,
				credentials,
			),
		)
	}
	return languageModel{
		modelID:  modelID,
		provider: a.options.name,
		options:  a.options,
		client:   anthropic.NewClient(clientOptions...),
	}, nil
}

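// languageModel is the ai.LanguageModel implementation backed by the
// Anthropic Messages API.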
type languageModel struct {
	provider string
	modelID  string
	client   anthropic.Client
	options  options
}

// Model implements ai.LanguageModel.
func (a languageModel) Model() string {
	return a.modelID
}

// Provider implements ai.LanguageModel.
func (a languageModel) Provider() string {
	return a.provider
}

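// prepareParams translates an ai.Call into anthropic.MessageNewParams,
// collecting warnings for settings the API does not support.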
func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
	params := &anthropic.MessageNewParams{}
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
		}
	}
	sendReasoning := true
	if providerOptions.SendReasoning != nil {
		sendReasoning = *providerOptions.SendReasoning
	}
	systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning)

	if call.FrequencyPenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "FrequencyPenalty",
		})
	}
	if call.PresencePenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "PresencePenalty",
		})
	}

	params.System = systemBlocks
	params.Messages = messages
	params.Model = anthropic.Model(a.modelID)
	params.MaxTokens = 4096

	if call.MaxOutputTokens != nil {
		params.MaxTokens = *call.MaxOutputTokens
	}

	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopK != nil {
		params.TopK = param.NewOpt(*call.TopK)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}

	isThinking := false
	var thinkingBudget int64
	if providerOptions.Thinking != nil {
		isThinking = true
		thinkingBudget = providerOptions.Thinking.BudgetTokens
	}
	if isThinking {
		if thinkingBudget == 0 {
			return nil, nil, ai.NewUnsupportedFunctionalityError("thinking requires a budget", "")
		}
		params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget)
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Temperature",
				Details: "Temperature is not supported when thinking is enabled",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported when thinking is enabled",
			})
		}
		if call.TopK != nil {
			params.TopK = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopK",
				Details: "TopK is not supported when thinking is enabled",
			})
		}
		// The thinking budget counts against MaxTokens, so grow it to
		// leave room for the visible output.
		params.MaxTokens += thinkingBudget
	}

	if len(call.Tools) > 0 {
		disableParallelToolUse := false
		if providerOptions.DisableParallelToolUse != nil {
			disableParallelToolUse = *providerOptions.DisableParallelToolUse
		}
		tools, toolChoice, toolWarnings := toTools(call.Tools, call.ToolChoice, disableParallelToolUse)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}

	return params, warnings, nil
}

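// ParseOptions implements ai.Provider, decoding raw provider options into
// *ProviderOptions.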
func (a *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
	var options ProviderOptions
	if err := ai.ParseOptions(data, &options); err != nil {
		return nil, err
	}
	return &options, nil
}

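// Name implements ai.Provider.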
func (a *provider) Name() string {
	return Name
}

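// getCacheControl extracts Anthropic cache-control settings from per-part
// or per-message provider options, if present.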
func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl {
	if anthropicOptions, ok := providerOptions[Name]; ok {
		if options, ok := anthropicOptions.(*ProviderCacheControlOptions); ok {
			return &options.CacheControl
		}
	}
	return nil
}

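// getReasoningMetadata extracts reasoning metadata (thinking signature or
// redacted data) from provider options, if present.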
func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMetadata {
	if anthropicOptions, ok := providerOptions[Name]; ok {
		if reasoning, ok := anthropicOptions.(*ReasoningOptionMetadata); ok {
			return reasoning
		}
	}
	return nil
}

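// messageBlock groups consecutive messages that map to a single Anthropic
// message role.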
type messageBlock struct {
	Role     ai.MessageRole
	Messages []ai.Message
}

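// groupIntoBlocks coalesces consecutive same-role messages into blocks.
// Tool messages are grouped with user messages, since tool results are
// sent to Anthropic as user content.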
func groupIntoBlocks(prompt ai.Prompt) []*messageBlock {
	var blocks []*messageBlock
	var currentBlock *messageBlock

	for _, msg := range prompt {
		role := msg.Role
		switch role {
		case ai.MessageRoleSystem, ai.MessageRoleUser, ai.MessageRoleAssistant:
		case ai.MessageRoleTool:
			// Tool results travel in user messages.
			role = ai.MessageRoleUser
		default:
			continue
		}
		if currentBlock == nil || currentBlock.Role != role {
			currentBlock = &messageBlock{
				Role:     role,
				Messages: []ai.Message{},
			}
			blocks = append(blocks, currentBlock)
		}
		currentBlock.Messages = append(currentBlock.Messages, msg)
	}
	return blocks
}

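// toTools converts ai tool definitions and tool choice into their Anthropic
// equivalents, warning on unsupported tool types.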
func toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			required := []string{}
			var properties any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties = props
			}
			if req, ok := ft.InputSchema["required"]; ok {
				switch reqTyped := req.(type) {
				case []string:
					required = reqTyped
				case []any:
					// JSON-decoded schemas yield []any rather than
					// []string, so convert the string entries.
					for _, r := range reqTyped {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			cacheControl := getCacheControl(ft.ProviderOptions)

			anthropicTool := anthropic.ToolParam{
				Name:        ft.Name,
				Description: anthropic.String(ft.Description),
				InputSchema: anthropic.ToolInputSchemaParam{
					Properties: properties,
					Required:   required,
				},
			}
			if cacheControl != nil {
				anthropicTool.CacheControl = anthropic.NewCacheControlEphemeralParam()
			}
			anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &anthropicTool})
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		if disableParallelToolCalls {
			anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
				OfAuto: &anthropic.ToolChoiceAutoParam{
					Type:                   "auto",
					DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
				},
			}
		}
		return anthropicTools, anthropicToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAuto: &anthropic.ToolChoiceAutoParam{
				Type:                   "auto",
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	case ai.ToolChoiceRequired:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAny: &anthropic.ToolChoiceAnyParam{
				Type:                   "any",
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	case ai.ToolChoiceNone:
		return anthropicTools, anthropicToolChoice, warnings
	default:
		// Any other value names a specific tool the model must call.
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfTool: &anthropic.ToolChoiceToolParam{
				Type:                   "tool",
				Name:                   string(*toolChoice),
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	}
	return anthropicTools, anthropicToolChoice, warnings
}

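// toPrompt converts the generic prompt into Anthropic system blocks and
// messages, applying cache control and, optionally, reasoning content.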
func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) {
	var systemBlocks []anthropic.TextBlockParam
	var messages []anthropic.MessageParam
	var warnings []ai.CallWarning

	blocks := groupIntoBlocks(prompt)
	finishedSystemBlock := false
	for _, block := range blocks {
		switch block.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Skip system messages that are separated from the first
				// system block by user/assistant messages.
				// TODO: see if we need to send an error here?
				continue
			}
			finishedSystemBlock = true
			for _, msg := range block.Messages {
				for _, part := range msg.Content {
					cacheControl := getCacheControl(part.Options())
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok {
						continue
					}
					textBlock := anthropic.TextBlockParam{
						Text: text.Text,
					}
					if cacheControl != nil {
						textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
					}
					systemBlocks = append(systemBlocks, textBlock)
				}
			}

		case ai.MessageRoleUser:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				if msg.Role == ai.MessageRoleUser {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						switch part.GetType() {
						case ai.ContentTypeText:
							text, ok := ai.AsMessagePart[ai.TextPart](part)
							if !ok {
								continue
							}
							textBlock := &anthropic.TextBlockParam{
								Text: text.Text,
							}
							if cacheControl != nil {
								textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
								OfText: textBlock,
							})
						case ai.ContentTypeFile:
							file, ok := ai.AsMessagePart[ai.FilePart](part)
							if !ok {
								continue
							}
							// TODO: handle other file types
							if !strings.HasPrefix(file.MediaType, "image/") {
								continue
							}

							base64Encoded := base64.StdEncoding.EncodeToString(file.Data)
							imageBlock := anthropic.NewImageBlockBase64(file.MediaType, base64Encoded)
							if cacheControl != nil {
								imageBlock.OfImage.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, imageBlock)
						}
					}
				} else if msg.Role == ai.MessageRoleTool {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
						if !ok {
							continue
						}
						toolResultBlock := anthropic.ToolResultBlockParam{
							ToolUseID: result.ToolCallID,
						}
						switch result.Output.GetType() {
						case ai.ToolResultContentTypeText:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Text,
									},
								},
							}
						case ai.ToolResultContentTypeMedia:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentMedia](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage,
								},
							}
						case ai.ToolResultContentTypeError:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Error.Error(),
									},
								},
							}
							toolResultBlock.IsError = param.NewOpt(true)
						}
						if cacheControl != nil {
							toolResultBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfToolResult: &toolResultBlock,
						})
					}
				}
			}
			messages = append(messages, anthropic.NewUserMessage(anthropicContent...))
		case ai.MessageRoleAssistant:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				for i, part := range msg.Content {
					isLastPart := i == len(msg.Content)-1
					cacheControl := getCacheControl(part.Options())
					if cacheControl == nil && isLastPart {
						cacheControl = getCacheControl(msg.ProviderOptions)
					}
					switch part.GetType() {
					case ai.ContentTypeText:
						text, ok := ai.AsMessagePart[ai.TextPart](part)
						if !ok {
							continue
						}
						textBlock := &anthropic.TextBlockParam{
							Text: text.Text,
						}
						if cacheControl != nil {
							textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfText: textBlock,
						})
					case ai.ContentTypeReasoning:
						reasoning, ok := ai.AsMessagePart[ai.ReasoningPart](part)
						if !ok {
							continue
						}
						if !sendReasoningData {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "sending reasoning content is disabled for this model",
							})
							continue
						}
						reasoningMetadata := getReasoningMetadata(part.Options())
						if reasoningMetadata == nil {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}

						if reasoningMetadata.Signature != "" {
							anthropicContent = append(anthropicContent, anthropic.NewThinkingBlock(reasoningMetadata.Signature, reasoning.Text))
						} else if reasoningMetadata.RedactedData != "" {
							anthropicContent = append(anthropicContent, anthropic.NewRedactedThinkingBlock(reasoningMetadata.RedactedData))
						} else {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}
					case ai.ContentTypeToolCall:
						toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
						if !ok {
							continue
						}
						if toolCall.ProviderExecuted {
							// TODO: implement provider executed call
							continue
						}

						var inputMap map[string]any
						err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
						if err != nil {
							continue
						}
						toolUseBlock := anthropic.NewToolUseBlock(toolCall.ToolCallID, inputMap, toolCall.ToolName)
						if cacheControl != nil {
							toolUseBlock.OfToolUse.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, toolUseBlock)
					case ai.ContentTypeToolResult:
						// TODO: implement provider executed tool result
					}
				}
			}
			messages = append(messages, anthropic.NewAssistantMessage(anthropicContent...))
		}
	}
	return systemBlocks, messages, warnings
}

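// handleError converts SDK errors into ai.APICallError values, preserving
// the request/response dumps and response headers.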
func (a languageModel) handleError(err error) error {
	var apiErr *anthropic.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			if len(h) == 0 {
				continue
			}
			// Keep the last value for repeated headers.
			headers[strings.ToLower(k)] = h[len(h)-1]
		}
		return ai.NewAPICallError(
			apiErr.Error(),
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

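// mapFinishReason maps Anthropic stop reasons onto ai.FinishReason values.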
func mapFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "end_turn", "pause_turn", "stop_sequence":
		return ai.FinishReasonStop
	case "max_tokens":
		return ai.FinishReasonLength
	case "tool_use":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

// Generate implements ai.LanguageModel.
func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := a.client.Messages.New(ctx, *params)
	if err != nil {
		return nil, a.handleError(err)
	}

	var content []ai.Content
	for _, block := range response.Content {
		switch block.Type {
		case "text":
			text, ok := block.AsAny().(anthropic.TextBlock)
			if !ok {
				continue
			}
			content = append(content, ai.TextContent{
				Text: text.Text,
			})
		case "thinking":
			reasoning, ok := block.AsAny().(anthropic.ThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: reasoning.Thinking,
				ProviderMetadata: ai.ProviderMetadata{
					Name: &ReasoningOptionMetadata{
						Signature: reasoning.Signature,
					},
				},
			})
		case "redacted_thinking":
			reasoning, ok := block.AsAny().(anthropic.RedactedThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: "",
				ProviderMetadata: ai.ProviderMetadata{
					Name: &ReasoningOptionMetadata{
						RedactedData: reasoning.Data,
					},
				},
			})
		case "tool_use":
			toolUse, ok := block.AsAny().(anthropic.ToolUseBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ToolCallContent{
				ToolCallID:       toolUse.ID,
				ToolName:         toolUse.Name,
				Input:            string(toolUse.Input),
				ProviderExecuted: false,
			})
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:         response.Usage.InputTokens,
			OutputTokens:        response.Usage.OutputTokens,
			TotalTokens:         response.Usage.InputTokens + response.Usage.OutputTokens,
			CacheCreationTokens: response.Usage.CacheCreationInputTokens,
			CacheReadTokens:     response.Usage.CacheReadInputTokens,
		},
		FinishReason:     mapFinishReason(string(response.StopReason)),
		ProviderMetadata: ai.ProviderMetadata{},
		Warnings:         warnings,
	}, nil
}

// Stream implements ai.LanguageModel.
func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}

	stream := a.client.Messages.NewStreaming(ctx, *params)
	acc := anthropic.Message{}
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		for stream.Next() {
			chunk := stream.Current()
			_ = acc.Accumulate(chunk)
			switch chunk.Type {
			case "content_block_start":
				contentBlockType := chunk.ContentBlock.Type
				switch contentBlockType {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "redacted_thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							Name: &ReasoningOptionMetadata{
								RedactedData: chunk.ContentBlock.Data,
							},
						},
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputStart,
						ID:            chunk.ContentBlock.ID,
						ToolCallName:  chunk.ContentBlock.Name,
						ToolCallInput: "",
					}) {
						return
					}
				}
			case "content_block_stop":
				if int(chunk.Index) >= len(acc.Content) {
					continue
				}
				contentBlock := acc.Content[int(chunk.Index)]
				switch contentBlock.Type {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeToolInputEnd,
						ID:   contentBlock.ID,
					}) {
						return
					}
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolCall,
						ID:            contentBlock.ID,
						ToolCallName:  contentBlock.Name,
						ToolCallInput: string(contentBlock.Input),
					}) {
						return
					}
				}
			case "content_block_delta":
				switch chunk.Delta.Type {
				case "text_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Text,
					}) {
						return
					}
				case "thinking_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeReasoningDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Thinking,
					}) {
						return
					}
				case "signature_delta":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningDelta,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							Name: &ReasoningOptionMetadata{
								Signature: chunk.Delta.Signature,
							},
						},
					}) {
						return
					}
				case "input_json_delta":
					if int(chunk.Index) >= len(acc.Content) {
						continue
					}
					contentBlock := acc.Content[int(chunk.Index)]
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputDelta,
						ID:            contentBlock.ID,
						ToolCallInput: chunk.Delta.PartialJSON,
					}) {
						return
					}
				}
			case "message_stop":
				// Final usage and stop reason are read from the
				// accumulated message once the stream ends.
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: a.handleError(err),
			})
			return
		}
		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			ID:           acc.ID,
			FinishReason: mapFinishReason(string(acc.StopReason)),
			Usage: ai.Usage{
				InputTokens:         acc.Usage.InputTokens,
				OutputTokens:        acc.Usage.OutputTokens,
				TotalTokens:         acc.Usage.InputTokens + acc.Usage.OutputTokens,
				CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
				CacheReadTokens:     acc.Usage.CacheReadInputTokens,
			},
			ProviderMetadata: ai.ProviderMetadata{},
		})
	}, nil
}