package anthropic

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/anthropics/anthropic-sdk-go/packages/param"
	"github.com/anthropics/anthropic-sdk-go/vertex"
	"github.com/charmbracelet/fantasy/ai"
	"golang.org/x/oauth2/google"
)

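// Name is the default provider name, and DefaultURL is the default
// Anthropic API endpoint.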
const (
	Name       = "anthropic"
	DefaultURL = "https://api.anthropic.com"
)

type options struct {
	baseURL string
	apiKey  string
	name    string
	headers map[string]string
	client  option.HTTPClient

	vertexProject  string
	vertexLocation string
	skipGoogleAuth bool
}

type provider struct {
	options options
}

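// Option configures the provider options.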
type Option = func(*options)

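// New creates an Anthropic provider with the given options, falling back
// to Name and DefaultURL when no name or base URL is configured.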
func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, Name)
	return &provider{options: providerOptions}
}

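// WithBaseURL sets the base URL for the Anthropic API.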
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

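// WithAPIKey sets the API key used to authenticate requests.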
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

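// WithVertex routes requests through Google Vertex AI using the given
// project and location.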
func WithVertex(project, location string) Option {
	return func(o *options) {
		o.vertexProject = project
		o.vertexLocation = location
	}
}

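// WithSkipGoogleAuth disables Google credential discovery for Vertex AI,
// substituting a dummy token source.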
func WithSkipGoogleAuth(skip bool) Option {
	return func(o *options) {
		o.skipGoogleAuth = skip
	}
}

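// WithName overrides the provider name reported by its models.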
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

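// WithHeaders adds extra headers to every request.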
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

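// WithHTTPClient sets a custom HTTP client for the underlying SDK.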
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

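// LanguageModel implements ai.Provider.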
func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	clientOptions := make([]option.RequestOption, 0, 5+len(a.options.headers))
	if a.options.apiKey != "" {
		clientOptions = append(clientOptions, option.WithAPIKey(a.options.apiKey))
	}
	if a.options.baseURL != "" {
		clientOptions = append(clientOptions, option.WithBaseURL(a.options.baseURL))
	}
	for key, value := range a.options.headers {
		clientOptions = append(clientOptions, option.WithHeader(key, value))
	}
	if a.options.client != nil {
		clientOptions = append(clientOptions, option.WithHTTPClient(a.options.client))
	}
	if a.options.vertexProject != "" && a.options.vertexLocation != "" {
		var credentials *google.Credentials
		if a.options.skipGoogleAuth {
			credentials = &google.Credentials{TokenSource: &googleDummyTokenSource{}}
		} else {
			var err error
			credentials, err = google.FindDefaultCredentials(context.Background())
			if err != nil {
				return nil, err
			}
		}

		clientOptions = append(
			clientOptions,
			vertex.WithCredentials(
				context.Background(),
				a.options.vertexLocation,
				a.options.vertexProject,
				credentials,
			),
		)
	}
	return languageModel{
		modelID:  modelID,
		provider: a.options.name,
		options:  a.options,
		client:   anthropic.NewClient(clientOptions...),
	}, nil
}

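// languageModel adapts an Anthropic SDK client to the ai.LanguageModel
// interface.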
type languageModel struct {
	provider string
	modelID  string
	client   anthropic.Client
	options  options
}

// Model implements ai.LanguageModel.
func (a languageModel) Model() string {
	return a.modelID
}

// Provider implements ai.LanguageModel.
func (a languageModel) Provider() string {
	return a.provider
}

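// prepareParams converts an ai.Call into Anthropic message parameters,
// collecting warnings for any settings the API does not support.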
func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
	params := &anthropic.MessageNewParams{}
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
		}
	}
	sendReasoning := true
	if providerOptions.SendReasoning != nil {
		sendReasoning = *providerOptions.SendReasoning
	}
	systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning)

	if call.FrequencyPenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "FrequencyPenalty",
		})
	}
	if call.PresencePenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "PresencePenalty",
		})
	}

	params.System = systemBlocks
	params.Messages = messages
	params.Model = anthropic.Model(a.modelID)
	params.MaxTokens = 4096

	if call.MaxOutputTokens != nil {
		params.MaxTokens = *call.MaxOutputTokens
	}

	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopK != nil {
		params.TopK = param.NewOpt(*call.TopK)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}

	isThinking := false
	var thinkingBudget int64
	if providerOptions.Thinking != nil {
		isThinking = true
		thinkingBudget = providerOptions.Thinking.BudgetTokens
	}
	if isThinking {
		if thinkingBudget == 0 {
			return nil, nil, ai.NewUnsupportedFunctionalityError("thinking requires a budget greater than zero", "")
		}
		params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget)
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported when thinking is enabled",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported when thinking is enabled",
			})
		}
		if call.TopK != nil {
			params.TopK = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopK",
				Details: "TopK is not supported when thinking is enabled",
			})
		}
		// Thinking tokens count against max_tokens, so reserve room for the
		// budget on top of the requested output budget.
		params.MaxTokens += thinkingBudget
	}

	if len(call.Tools) > 0 {
		disableParallelToolUse := false
		if providerOptions.DisableParallelToolUse != nil {
			disableParallelToolUse = *providerOptions.DisableParallelToolUse
		}
		tools, toolChoice, toolWarnings := toTools(call.Tools, call.ToolChoice, disableParallelToolUse)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}

	return params, warnings, nil
}

// Name implements ai.Provider.
func (a *provider) Name() string {
	return Name
}

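// getCacheControl extracts Anthropic cache-control options from provider
// options, if any.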
func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl {
	if anthropicOptions, ok := providerOptions[Name]; ok {
		if options, ok := anthropicOptions.(*ProviderCacheControlOptions); ok {
			return &options.CacheControl
		}
	}
	return nil
}

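// getReasoningMetadata extracts Anthropic reasoning metadata from provider
// options, if any.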
func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMetadata {
	if anthropicOptions, ok := providerOptions[Name]; ok {
		if reasoning, ok := anthropicOptions.(*ReasoningOptionMetadata); ok {
			return reasoning
		}
	}
	return nil
}

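// messageBlock is a run of consecutive prompt messages that map to the same
// Anthropic role.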
type messageBlock struct {
	Role     ai.MessageRole
	Messages []ai.Message
}

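// groupIntoBlocks coalesces consecutive messages of the same effective role
// into blocks. Tool messages are grouped under the user role, since the
// Anthropic API expects tool results inside user messages.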
func groupIntoBlocks(prompt ai.Prompt) []*messageBlock {
	var blocks []*messageBlock
	var currentBlock *messageBlock

	for _, msg := range prompt {
		role := msg.Role
		switch role {
		case ai.MessageRoleSystem, ai.MessageRoleUser, ai.MessageRoleAssistant, ai.MessageRoleTool:
			if role == ai.MessageRoleTool {
				// Tool results are delivered as user messages.
				role = ai.MessageRoleUser
			}
			if currentBlock == nil || currentBlock.Role != role {
				currentBlock = &messageBlock{Role: role}
				blocks = append(blocks, currentBlock)
			}
			currentBlock.Messages = append(currentBlock.Messages, msg)
		}
	}
	return blocks
}

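// toTools converts generic tool definitions and the tool choice into their
// Anthropic equivalents, warning about unsupported tools.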
func toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			required := []string{}
			var properties any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties = props
			}
			if req, ok := ft.InputSchema["required"]; ok {
				switch reqVal := req.(type) {
				case []string:
					required = reqVal
				case []any:
					// Schemas decoded from JSON carry []any, not []string.
					for _, r := range reqVal {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			cacheControl := getCacheControl(ft.ProviderOptions)

			anthropicTool := anthropic.ToolParam{
				Name:        ft.Name,
				Description: anthropic.String(ft.Description),
				InputSchema: anthropic.ToolInputSchemaParam{
					Properties: properties,
					Required:   required,
				},
			}
			if cacheControl != nil {
				anthropicTool.CacheControl = anthropic.NewCacheControlEphemeralParam()
			}
			anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &anthropicTool})
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		if disableParallelToolCalls {
			anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
				OfAuto: &anthropic.ToolChoiceAutoParam{
					Type:                   "auto",
					DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
				},
			}
		}
		return anthropicTools, anthropicToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAuto: &anthropic.ToolChoiceAutoParam{
				Type:                   "auto",
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	case ai.ToolChoiceRequired:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAny: &anthropic.ToolChoiceAnyParam{
				Type:                   "any",
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	case ai.ToolChoiceNone:
		return anthropicTools, anthropicToolChoice, warnings
	default:
		// Any other value names a specific tool the model must call.
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfTool: &anthropic.ToolChoiceToolParam{
				Type:                   "tool",
				Name:                   string(*toolChoice),
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	}
	return anthropicTools, anthropicToolChoice, warnings
}

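// toPrompt converts a prompt into Anthropic system blocks and messages.
// Reasoning parts are forwarded only when sendReasoningData is true.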
func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) {
	var systemBlocks []anthropic.TextBlockParam
	var messages []anthropic.MessageParam
	var warnings []ai.CallWarning

	blocks := groupIntoBlocks(prompt)
	finishedSystemBlock := false
	for _, block := range blocks {
		switch block.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Skip system messages that appear after user/assistant
				// messages; only the leading system block is honored.
				// TODO: decide whether this should be an error instead.
				continue
			}
			finishedSystemBlock = true
			for _, msg := range block.Messages {
				for i, part := range msg.Content {
					// Message-level cache control applies to the last part.
					isLastPart := i == len(msg.Content)-1
					cacheControl := getCacheControl(part.Options())
					if cacheControl == nil && isLastPart {
						cacheControl = getCacheControl(msg.ProviderOptions)
					}
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok {
						continue
					}
					textBlock := anthropic.TextBlockParam{
						Text: text.Text,
					}
					if cacheControl != nil {
						textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
					}
					systemBlocks = append(systemBlocks, textBlock)
				}
			}

		case ai.MessageRoleUser:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				if msg.Role == ai.MessageRoleUser {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						switch part.GetType() {
						case ai.ContentTypeText:
							text, ok := ai.AsMessagePart[ai.TextPart](part)
							if !ok {
								continue
							}
							textBlock := &anthropic.TextBlockParam{
								Text: text.Text,
							}
							if cacheControl != nil {
								textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
								OfText: textBlock,
							})
						case ai.ContentTypeFile:
							file, ok := ai.AsMessagePart[ai.FilePart](part)
							if !ok {
								continue
							}
							// TODO: handle other file types
							if !strings.HasPrefix(file.MediaType, "image/") {
								continue
							}

							base64Encoded := base64.StdEncoding.EncodeToString(file.Data)
							imageBlock := anthropic.NewImageBlockBase64(file.MediaType, base64Encoded)
							if cacheControl != nil {
								imageBlock.OfImage.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, imageBlock)
						}
					}
				} else if msg.Role == ai.MessageRoleTool {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
						if !ok {
							continue
						}
						toolResultBlock := anthropic.ToolResultBlockParam{
							ToolUseID: result.ToolCallID,
						}
						switch result.Output.GetType() {
						case ai.ToolResultContentTypeText:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Text,
									},
								},
							}
						case ai.ToolResultContentTypeMedia:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentMedia](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage,
								},
							}
						case ai.ToolResultContentTypeError:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Error.Error(),
									},
								},
							}
							toolResultBlock.IsError = param.NewOpt(true)
						}
						if cacheControl != nil {
							toolResultBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfToolResult: &toolResultBlock,
						})
					}
				}
			}
			messages = append(messages, anthropic.NewUserMessage(anthropicContent...))
		case ai.MessageRoleAssistant:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				for i, part := range msg.Content {
					isLastPart := i == len(msg.Content)-1
					cacheControl := getCacheControl(part.Options())
					if cacheControl == nil && isLastPart {
						cacheControl = getCacheControl(msg.ProviderOptions)
					}
					switch part.GetType() {
					case ai.ContentTypeText:
						text, ok := ai.AsMessagePart[ai.TextPart](part)
						if !ok {
							continue
						}
						textBlock := &anthropic.TextBlockParam{
							Text: text.Text,
						}
						if cacheControl != nil {
							textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfText: textBlock,
						})
					case ai.ContentTypeReasoning:
						reasoning, ok := ai.AsMessagePart[ai.ReasoningPart](part)
						if !ok {
							continue
						}
						if !sendReasoningData {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "sending reasoning content is disabled for this model",
							})
							continue
						}
						reasoningMetadata := getReasoningMetadata(part.Options())
						if reasoningMetadata == nil {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}

						if reasoningMetadata.Signature != "" {
							anthropicContent = append(anthropicContent, anthropic.NewThinkingBlock(reasoningMetadata.Signature, reasoning.Text))
						} else if reasoningMetadata.RedactedData != "" {
							anthropicContent = append(anthropicContent, anthropic.NewRedactedThinkingBlock(reasoningMetadata.RedactedData))
						} else {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}
					case ai.ContentTypeToolCall:
						toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
						if !ok {
							continue
						}
						if toolCall.ProviderExecuted {
							// TODO: implement provider executed call
							continue
						}

						var inputMap map[string]any
						err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
						if err != nil {
							continue
						}
						toolUseBlock := anthropic.NewToolUseBlock(toolCall.ToolCallID, inputMap, toolCall.ToolName)
						if cacheControl != nil {
							toolUseBlock.OfToolUse.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, toolUseBlock)
					case ai.ContentTypeToolResult:
						// TODO: implement provider executed tool result
					}
				}
			}
			messages = append(messages, anthropic.NewAssistantMessage(anthropicContent...))
		}
	}
	return systemBlocks, messages, warnings
}

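// handleError converts Anthropic API errors into ai.APICallError values,
// preserving the request/response dumps and response headers.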
func (a languageModel) handleError(err error) error {
	var apiErr *anthropic.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			if len(h) == 0 {
				continue
			}
			// Keep only the last value for repeated headers.
			headers[strings.ToLower(k)] = h[len(h)-1]
		}
		return ai.NewAPICallError(
			apiErr.Error(),
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

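// mapFinishReason maps an Anthropic stop reason onto the corresponding
// ai.FinishReason.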
func mapFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "end_turn", "pause_turn", "stop_sequence":
		return ai.FinishReasonStop
	case "max_tokens":
		return ai.FinishReasonLength
	case "tool_use":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

// Generate implements ai.LanguageModel.
func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := a.client.Messages.New(ctx, *params)
	if err != nil {
		return nil, a.handleError(err)
	}

	var content []ai.Content
	for _, block := range response.Content {
		switch block.Type {
		case "text":
			text, ok := block.AsAny().(anthropic.TextBlock)
			if !ok {
				continue
			}
			content = append(content, ai.TextContent{
				Text: text.Text,
			})
		case "thinking":
			reasoning, ok := block.AsAny().(anthropic.ThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: reasoning.Thinking,
				ProviderMetadata: ai.ProviderMetadata{
					Name: &ReasoningOptionMetadata{
						Signature: reasoning.Signature,
					},
				},
			})
		case "redacted_thinking":
			reasoning, ok := block.AsAny().(anthropic.RedactedThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: "",
				ProviderMetadata: ai.ProviderMetadata{
					Name: &ReasoningOptionMetadata{
						RedactedData: reasoning.Data,
					},
				},
			})
		case "tool_use":
			toolUse, ok := block.AsAny().(anthropic.ToolUseBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ToolCallContent{
				ToolCallID:       toolUse.ID,
				ToolName:         toolUse.Name,
				Input:            string(toolUse.Input),
				ProviderExecuted: false,
			})
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:         response.Usage.InputTokens,
			OutputTokens:        response.Usage.OutputTokens,
			TotalTokens:         response.Usage.InputTokens + response.Usage.OutputTokens,
			CacheCreationTokens: response.Usage.CacheCreationInputTokens,
			CacheReadTokens:     response.Usage.CacheReadInputTokens,
		},
		FinishReason:     mapFinishReason(string(response.StopReason)),
		ProviderMetadata: ai.ProviderMetadata{},
		Warnings:         warnings,
	}, nil
}

// Stream implements ai.LanguageModel.
func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}

	stream := a.client.Messages.NewStreaming(ctx, *params)
	// acc accumulates streamed chunks so completed blocks can be read back
	// when their content_block_stop event arrives.
	acc := anthropic.Message{}
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		for stream.Next() {
			chunk := stream.Current()
			_ = acc.Accumulate(chunk)
			switch chunk.Type {
			case "content_block_start":
				switch chunk.ContentBlock.Type {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "redacted_thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							Name: &ReasoningOptionMetadata{
								RedactedData: chunk.ContentBlock.Data,
							},
						},
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputStart,
						ID:            chunk.ContentBlock.ID,
						ToolCallName:  chunk.ContentBlock.Name,
						ToolCallInput: "",
					}) {
						return
					}
				}
			case "content_block_stop":
				if int(chunk.Index) >= len(acc.Content) {
					continue
				}
				contentBlock := acc.Content[int(chunk.Index)]
				switch contentBlock.Type {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeToolInputEnd,
						ID:   contentBlock.ID,
					}) {
						return
					}
					// Emit the completed tool call with its accumulated input.
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolCall,
						ID:            contentBlock.ID,
						ToolCallName:  contentBlock.Name,
						ToolCallInput: string(contentBlock.Input),
					}) {
						return
					}
				}
			case "content_block_delta":
				switch chunk.Delta.Type {
				case "text_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Text,
					}) {
						return
					}
				case "thinking_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeReasoningDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Thinking,
					}) {
						return
					}
				case "signature_delta":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningDelta,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							Name: &ReasoningOptionMetadata{
								Signature: chunk.Delta.Signature,
							},
						},
					}) {
						return
					}
				case "input_json_delta":
					if int(chunk.Index) >= len(acc.Content) {
						continue
					}
					contentBlock := acc.Content[int(chunk.Index)]
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputDelta,
						ID:            contentBlock.ID,
						ToolCallInput: chunk.Delta.PartialJSON,
					}) {
						return
					}
				}
			case "message_stop":
				// Final usage and stop reason are read from the accumulator
				// once the stream is drained.
			}
		}

		if err := stream.Err(); err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: a.handleError(err),
			})
			return
		}
		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			ID:           acc.ID,
			FinishReason: mapFinishReason(string(acc.StopReason)),
			Usage: ai.Usage{
				InputTokens:         acc.Usage.InputTokens,
				OutputTokens:        acc.Usage.OutputTokens,
				TotalTokens:         acc.Usage.InputTokens + acc.Usage.OutputTokens,
				CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
				CacheReadTokens:     acc.Usage.CacheReadInputTokens,
			},
			ProviderMetadata: ai.ProviderMetadata{},
		})
	}, nil
}