1// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3package ssestream
4
5import (
6 "bufio"
7 "bytes"
8 "encoding/json"
9 "fmt"
10 "io"
11 "net/http"
12 "strings"
13)
14
15type Decoder interface {
16 Event() Event
17 Next() bool
18 Close() error
19 Err() error
20}
21
22func NewDecoder(res *http.Response) Decoder {
23 if res == nil || res.Body == nil {
24 return nil
25 }
26
27 var decoder Decoder
28 contentType := res.Header.Get("content-type")
29 if t, ok := decoderTypes[contentType]; ok {
30 decoder = t(res.Body)
31 } else {
32 scn := bufio.NewScanner(res.Body)
33 scn.Buffer(nil, bufio.MaxScanTokenSize<<4)
34 decoder = &eventStreamDecoder{rc: res.Body, scn: scn}
35 }
36 return decoder
37}
38
39var decoderTypes = map[string](func(io.ReadCloser) Decoder){}
40
41func RegisterDecoder(contentType string, decoder func(io.ReadCloser) Decoder) {
42 decoderTypes[strings.ToLower(contentType)] = decoder
43}
44
// Event is a single event produced by a Decoder.
type Event struct {
	// Type is the event name. For text/event-stream it comes from the
	// "event:" field and is empty when the event had none.
	Type string
	// Data is the event payload. For text/event-stream it is the values of
	// the event's "data:" lines joined with "\n".
	Data []byte
}

// eventStreamDecoder is the default Decoder. It parses a text/event-stream
// body line by line. Only the "event" and "data" fields are interpreted;
// comment lines and any other field names (such as "id" or "retry") are
// ignored.
type eventStreamDecoder struct {
	evt Event
	rc  io.ReadCloser
	scn *bufio.Scanner
	err error
}

// Next reads lines until a blank line completes an event, stores that event
// for Event, and returns true.
//
// It returns false at end of input or on a read error, which Err then
// reports. An event left unterminated at EOF is discarded, as the SSE
// processing model requires. Once an error has occurred, Next keeps
// returning false.
func (s *eventStreamDecoder) Next() bool {
	if s.err != nil {
		return false
	}

	event := ""
	data := bytes.NewBuffer(nil)

	for s.scn.Scan() {
		// Scanner.Bytes is only valid until the next Scan. Every value kept
		// below is copied, either by string conversion or by Buffer.Write.
		txt := s.scn.Bytes()

		// A blank line ends the current event and dispatches it.
		if len(txt) == 0 {
			s.evt = Event{
				Type: event,
				// Each data line was followed by "\n". The SSE spec removes the
				// final one, so Data holds the lines joined by "\n".
				Data: bytes.TrimSuffix(data.Bytes(), []byte("\n")),
			}
			return true
		}

		// Split a line like "event: bar" into name "event" and value " bar".
		// A line with no colon is a field name with an empty value.
		name, value, _ := bytes.Cut(txt, []byte(":"))

		// A single space after the colon is optional and not part of the value.
		value = bytes.TrimPrefix(value, []byte(" "))

		switch string(name) {
		case "":
			// A line of the form ": something" is a comment and is ignored.
			continue
		case "event":
			event = string(value)
		case "data":
			// Stop scanning on a write error. A bare `break` here would only
			// leave the switch, not the loop.
			if _, s.err = data.Write(value); s.err != nil {
				return false
			}
			if s.err = data.WriteByte('\n'); s.err != nil {
				return false
			}
		}
	}

	// Scan stopped. Err is nil at a clean EOF and non-nil for read failures
	// or lines longer than the scanner's buffer limit.
	s.err = s.scn.Err()
	return false
}

// Event returns the event most recently dispatched by Next.
func (s *eventStreamDecoder) Event() Event {
	return s.evt
}

// Close closes the underlying body.
func (s *eventStreamDecoder) Close() error {
	return s.rc.Close()
}

// Err returns the error that stopped Next, or nil if the stream ended cleanly.
func (s *eventStreamDecoder) Err() error {
	return s.err
}
122
123type Stream[T any] struct {
124 decoder Decoder
125 cur T
126 err error
127}
128
129func NewStream[T any](decoder Decoder, err error) *Stream[T] {
130 return &Stream[T]{
131 decoder: decoder,
132 err: err,
133 }
134}
135
136// Next returns false if the stream has ended or an error occurred.
137// Call Stream.Current() to get the current value.
138// Call Stream.Err() to get the error.
139//
140// for stream.Next() {
141// data := stream.Current()
142// }
143//
144// if stream.Err() != nil {
145// ...
146// }
147func (s *Stream[T]) Next() bool {
148 if s.err != nil {
149 return false
150 }
151
152 for s.decoder.Next() {
153 switch s.decoder.Event().Type {
154 case "completion":
155 var nxt T
156 s.err = json.Unmarshal(s.decoder.Event().Data, &nxt)
157 if s.err != nil {
158 return false
159 }
160 s.cur = nxt
161 return true
162 case "message_start", "message_delta", "message_stop", "content_block_start", "content_block_delta", "content_block_stop":
163 var nxt T
164 s.err = json.Unmarshal(s.decoder.Event().Data, &nxt)
165 if s.err != nil {
166 return false
167 }
168 s.cur = nxt
169 return true
170 case "ping":
171 continue
172 case "error":
173 s.err = fmt.Errorf("received error while streaming: %s", string(s.decoder.Event().Data))
174 return false
175 }
176 }
177
178 // decoder.Next() may be false because of an error
179 s.err = s.decoder.Err()
180
181 return false
182}
183
184func (s *Stream[T]) Current() T {
185 return s.cur
186}
187
188func (s *Stream[T]) Err() error {
189 return s.err
190}
191
192func (s *Stream[T]) Close() error {
193 if s.decoder == nil {
194 // already closed
195 return nil
196 }
197 return s.decoder.Close()
198}