diff --git a/claude_cli.go b/claude_cli.go index f2d4761..733d675 100644 --- a/claude_cli.go +++ b/claude_cli.go @@ -16,15 +16,40 @@ type CLIEvent struct { Type string `json:"type"` Subtype string `json:"subtype,omitempty"` + // For system init event + SessionID string `json:"session_id,omitempty"` + // For assistant message events Message *CLIMessage `json:"message,omitempty"` - // For content_block_delta - Index int `json:"index,omitempty"` - Delta *CLIDelta `json:"delta,omitempty"` + // For stream_event wrapper + Event *StreamEvent `json:"event,omitempty"` - // For result events - Result *CLIResult `json:"result,omitempty"` + // For result events — can be object or string, use RawMessage + RawResult json.RawMessage `json:"result,omitempty"` + Result *CLIResult `json:"-"` + + // Top-level cost field (present when result is string) + TotalCostUSD float64 `json:"total_cost_usd,omitempty"` +} + +// StreamEvent is the inner event inside a stream_event wrapper. +type StreamEvent struct { + Type string `json:"type"` + Index int `json:"index,omitempty"` + Delta *StreamDelta `json:"delta,omitempty"` + ContentBlock *ContentBlock `json:"content_block,omitempty"` +} + +type ContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +type StreamDelta struct { + Type string `json:"type,omitempty"` + Text string `json:"text,omitempty"` + StopReason string `json:"stop_reason,omitempty"` } type CLIMessage struct { @@ -40,18 +65,24 @@ type CLIContent struct { ID string `json:"id,omitempty"` } -type CLIDelta struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` -} - type CLIResult struct { Duration float64 `json:"duration_ms,omitempty"` NumTurns int `json:"num_turns,omitempty"` - CostUSD float64 `json:"cost_usd,omitempty"` + CostUSD float64 `json:"total_cost_usd,omitempty"` SessionID string `json:"session_id,omitempty"` } +// CLIInputMessage is the JSON format for sending messages via stream-json input. +type CLIInputMessage struct { + Type string `json:"type"` + Message CLIInputContent `json:"message"` +} + +type CLIInputContent struct { + Role string `json:"role"` + Content string `json:"content"` +} + // CLIProcess manages a running claude CLI process. type CLIProcess struct { cmd *exec.Cmd @@ -59,26 +90,33 @@ type CLIProcess struct { stdout io.ReadCloser stderr io.ReadCloser - Events chan CLIEvent - Errors chan error + Events chan CLIEvent + Errors chan error + CLISessionID string // session_id from init event, for resume done chan struct{} mu sync.Mutex } // SpawnCLI starts a new claude CLI process for the given project directory. +// Uses --input-format stream-json for multi-turn conversation support. func SpawnCLI(projectDir string) (*CLIProcess, error) { args := []string{ "-p", + "--input-format", "stream-json", "--output-format", "stream-json", "--verbose", + "--include-partial-messages", } cmd := exec.Command("claude", args...) cmd.Dir = projectDir - // Filter CLAUDECODE env var to prevent nested session detection - env := filterEnv(os.Environ(), "CLAUDECODE") + // Filter all Claude-related env vars to prevent nested session detection + env := filterEnvMulti(os.Environ(), []string{ + "CLAUDECODE", + "CLAUDE_CODE_ENTRYPOINT", + }) cmd.Env = env stdin, err := cmd.StdinPipe() @@ -116,13 +154,25 @@ func SpawnCLI(projectDir string) (*CLIProcess, error) { return cp, nil } -// Send writes a message to the claude CLI process stdin. +// Send writes a user message to the claude CLI process stdin as stream-json. func (cp *CLIProcess) Send(message string) error { cp.mu.Lock() defer cp.mu.Unlock() - msg := strings.TrimSpace(message) + "\n" - _, err := io.WriteString(cp.stdin, msg) + msg := CLIInputMessage{ + Type: "user", + Message: CLIInputContent{ + Role: "user", + Content: strings.TrimSpace(message), + }, + } + + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal message: %w", err) + } + + _, err = cp.stdin.Write(append(data, '\n')) return err } @@ -132,7 +182,10 @@ func (cp *CLIProcess) Close() error { defer cp.mu.Unlock() cp.stdin.Close() - return cp.cmd.Process.Kill() + if cp.cmd.Process != nil { + return cp.cmd.Process.Kill() + } + return nil } // Done returns a channel that's closed when the process exits. @@ -158,6 +211,24 @@ func (cp *CLIProcess) readOutput() { cp.Errors <- fmt.Errorf("parse event: %w (line: %s)", err, truncate(line, 200)) continue } + + // Parse result field — can be string or object + if len(event.RawResult) > 0 { + var result CLIResult + if err := json.Unmarshal(event.RawResult, &result); err == nil { + event.Result = &result + } + // If it's a string, Result stays nil — cost comes from top-level field + if event.Result == nil && event.TotalCostUSD > 0 { + event.Result = &CLIResult{CostUSD: event.TotalCostUSD} + } + } + + // Capture session_id from init event + if event.Type == "system" && event.Subtype == "init" && event.SessionID != "" { + cp.CLISessionID = event.SessionID + } + cp.Events <- event } @@ -176,18 +247,34 @@ func (cp *CLIProcess) readErrors() { } } -// filterEnv returns a copy of env with the named variable removed. -func filterEnv(env []string, name string) []string { - prefix := name + "=" +// filterEnvMulti returns a copy of env with all named variables removed. +func filterEnvMulti(env []string, names []string) []string { + prefixes := make([]string, len(names)) + for i, n := range names { + prefixes[i] = n + "=" + } + result := make([]string, 0, len(env)) for _, e := range env { - if !strings.HasPrefix(e, prefix) { + skip := false + for _, p := range prefixes { + if strings.HasPrefix(e, p) { + skip = true + break + } + } + if !skip { result = append(result, e) } } return result } +// filterEnv returns a copy of env with the named variable removed. +func filterEnv(env []string, name string) []string { + return filterEnvMulti(env, []string{name}) +} + func truncate(s string, n int) string { if len(s) <= n { return s diff --git a/claude_cli_test.go b/claude_cli_test.go index cb0e013..53f5913 100644 --- a/claude_cli_test.go +++ b/claude_cli_test.go @@ -26,6 +26,28 @@ func TestFilterEnv(t *testing.T) { } } +func TestFilterEnvMulti(t *testing.T) { + env := []string{ + "PATH=/usr/bin", + "HOME=/root", + "CLAUDECODE=1", + "CLAUDE_CODE_ENTRYPOINT=cli", + "OTHER=value", + } + + filtered := filterEnvMulti(env, []string{"CLAUDECODE", "CLAUDE_CODE_ENTRYPOINT"}) + + if len(filtered) != 3 { + t.Fatalf("got %d entries, want 3", len(filtered)) + } + + for _, e := range filtered { + if e == "CLAUDECODE=1" || e == "CLAUDE_CODE_ENTRYPOINT=cli" { + t.Errorf("should be filtered: %s", e) + } + } +} + func TestFilterEnvNotPresent(t *testing.T) { env := []string{"PATH=/usr/bin", "HOME=/root"} filtered := filterEnv(env, "CLAUDECODE") @@ -53,6 +75,23 @@ func TestTruncate(t *testing.T) { } func TestCLIEventParsing(t *testing.T) { + t.Run("system init", func(t *testing.T) { + raw := `{"type":"system","subtype":"init","session_id":"abc-123","cwd":"/"}` + var event CLIEvent + if err := json.Unmarshal([]byte(raw), &event); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if event.Type != "system" { + t.Errorf("type = %q", event.Type) + } + if event.Subtype != "init" { + t.Errorf("subtype = %q", event.Subtype) + } + if event.SessionID != "abc-123" { + t.Errorf("session_id = %q", event.SessionID) + } + }) + t.Run("assistant message", func(t *testing.T) { raw := `{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello!"}]}}` var event CLIEvent @@ -73,29 +112,53 @@ func TestCLIEventParsing(t *testing.T) { } }) - t.Run("content_block_delta", func(t *testing.T) { - raw := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"chunk"}}` + t.Run("stream_event content_block_delta", func(t *testing.T) { + raw := `{"type":"stream_event","event":{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"chunk"}}}` var event CLIEvent if err := json.Unmarshal([]byte(raw), &event); err != nil { t.Fatalf("unmarshal: %v", err) } - if event.Type != "content_block_delta" { + if event.Type != "stream_event" { t.Errorf("type = %q", event.Type) } - if event.Delta == nil { + if event.Event == nil { + t.Fatal("event is nil") + } + if event.Event.Type != "content_block_delta" { + t.Errorf("event.type = %q", event.Event.Type) + } + if event.Event.Delta == nil { t.Fatal("delta is nil") } - if event.Delta.Text != "chunk" { - t.Errorf("text = %q", event.Delta.Text) + if event.Event.Delta.Text != "chunk" { + t.Errorf("text = %q", event.Event.Delta.Text) } }) - t.Run("result event", func(t *testing.T) { - raw := `{"type":"result","subtype":"success","result":{"duration_ms":1234.5,"num_turns":3,"cost_usd":0.05,"session_id":"abc123"}}` + t.Run("stream_event content_block_start", func(t *testing.T) { + raw := `{"type":"stream_event","event":{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}}` var event CLIEvent if err := json.Unmarshal([]byte(raw), &event); err != nil { t.Fatalf("unmarshal: %v", err) } + if event.Event.Type != "content_block_start" { + t.Errorf("event.type = %q", event.Event.Type) + } + }) + + t.Run("result event object", func(t *testing.T) { + raw := `{"type":"result","subtype":"success","result":{"duration_ms":1234.5,"num_turns":3,"total_cost_usd":0.05,"session_id":"abc123"}}` + var event CLIEvent + if err := json.Unmarshal([]byte(raw), &event); err != nil { + t.Fatalf("unmarshal: %v", err) + } + // Simulate readOutput parsing of RawResult + if len(event.RawResult) > 0 { + var result CLIResult + if err := json.Unmarshal(event.RawResult, &result); err == nil { + event.Result = &result + } + } if event.Type != "result" { t.Errorf("type = %q", event.Type) } @@ -110,6 +173,30 @@ func TestCLIEventParsing(t *testing.T) { } }) + t.Run("result event string", func(t *testing.T) { + raw := `{"type":"result","subtype":"success","total_cost_usd":0.046,"result":"hello world"}` + var event CLIEvent + if err := json.Unmarshal([]byte(raw), &event); err != nil { + t.Fatalf("unmarshal: %v", err) + } + // Simulate readOutput parsing — string result won't parse as CLIResult + if len(event.RawResult) > 0 { + var result CLIResult + if err := json.Unmarshal(event.RawResult, &result); err == nil { + event.Result = &result + } + if event.Result == nil && event.TotalCostUSD > 0 { + event.Result = &CLIResult{CostUSD: event.TotalCostUSD} + } + } + if event.Result == nil { + t.Fatal("result should not be nil for string result with cost") + } + if event.Result.CostUSD != 0.046 { + t.Errorf("cost = %f, want 0.046", event.Result.CostUSD) + } + }) + t.Run("tool_use content", func(t *testing.T) { raw := `{"type":"assistant","message":{"role":"assistant","content":[{"type":"tool_use","name":"Read","id":"tool1","input":{"file_path":"/tmp/test.go"}}]}}` var event CLIEvent @@ -124,3 +211,23 @@ func TestCLIEventParsing(t *testing.T) { } }) } + +func TestCLIInputMessage(t *testing.T) { + msg := CLIInputMessage{ + Type: "user", + Message: CLIInputContent{ + Role: "user", + Content: "hello world", + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + expected := `{"type":"user","message":{"role":"user","content":"hello world"}}` + if string(data) != expected { + t.Errorf("got %s, want %s", string(data), expected) + } +} diff --git a/fragments.go b/fragments.go index f8e2569..c59f99d 100644 --- a/fragments.go +++ b/fragments.go @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "fmt" "html" "strings" @@ -31,11 +32,82 @@ func FragmentAssistantComplete(msgID, htmlContent string) string { // FragmentToolCall returns an HTML fragment for a tool use notification. func FragmentToolCall(toolName string, toolInput string) string { escapedName := html.EscapeString(toolName) - escapedInput := html.EscapeString(toolInput) - if len(escapedInput) > 200 { - escapedInput = escapedInput[:200] + "..." + summary := formatToolSummary(toolName, toolInput) + return fmt.Sprintf(`
" + html.EscapeString(s) + ""
+ }
+ case "Glob":
+ if pat, ok := inputMap["pattern"].(string); ok {
+ return html.EscapeString(pat)
+ }
+ case "Grep":
+ parts := []string{}
+ if pat, ok := inputMap["pattern"].(string); ok {
+ parts = append(parts, pat)
+ }
+ if p, ok := inputMap["path"].(string); ok {
+ parts = append(parts, "in "+p)
+ }
+ if len(parts) > 0 {
+ return html.EscapeString(strings.Join(parts, " "))
+ }
+ case "WebSearch":
+ if q, ok := inputMap["query"].(string); ok {
+ return html.EscapeString(q)
+ }
+ case "WebFetch":
+ if u, ok := inputMap["url"].(string); ok {
+ return html.EscapeString(u)
+ }
+ }
+
+ // Fallback: show key=value pairs
+ var parts []string
+ for k, v := range inputMap {
+ s := fmt.Sprintf("%v", v)
+ if len(s) > 80 {
+ s = s[:80] + "..."
+ }
+ parts = append(parts, html.EscapeString(fmt.Sprintf("%s: %s", k, s)))
+ }
+ result := strings.Join(parts, ", ")
+ if len(result) > 300 {
+ result = result[:300] + "..."
+ }
+ return result
}
// FragmentSystemMessage returns an HTML fragment for a system message.
diff --git a/fragments_test.go b/fragments_test.go
index 2522539..f6dfa41 100644
--- a/fragments_test.go
+++ b/fragments_test.go
@@ -49,18 +49,65 @@ func TestFragmentAssistantComplete(t *testing.T) {
}
func TestFragmentToolCall(t *testing.T) {
- f := FragmentToolCall("Read", "/tmp/test.go")
+ f := FragmentToolCall("Read", `{"file_path":"/tmp/test.go"}`)
if !strings.Contains(f, "message-tool") {
t.Error("missing message-tool class")
}
if !strings.Contains(f, "Read") {
t.Error("missing tool name")
}
+ if !strings.Contains(f, "/tmp/test.go") {
+ t.Error("should show file path, not raw JSON")
+ }
+ if strings.Contains(f, "file_path") {
+ t.Error("should not show JSON key name for Read")
+ }
+}
+
+func TestFragmentToolCallBash(t *testing.T) {
+ f := FragmentToolCall("Bash", `{"command":"git status"}`)
+ if !strings.Contains(f, "") {
+ t.Error("Bash command should be in code tag")
+ }
+ if !strings.Contains(f, "git status") {
+ t.Error("should show command")
+ }
+}
+
+func TestFragmentToolCallGrep(t *testing.T) {
+ f := FragmentToolCall("Grep", `{"pattern":"TODO","path":"/root/projects"}`)
+ if !strings.Contains(f, "TODO") {
+ t.Error("should show pattern")
+ }
+ if !strings.Contains(f, "/root/projects") {
+ t.Error("should show path")
+ }
+}
+
+func TestFragmentToolCallEdit(t *testing.T) {
+ f := FragmentToolCall("Edit", `{"file_path":"/tmp/main.go","old_string":"foo","new_string":"bar"}`)
+ if !strings.Contains(f, "/tmp/main.go") {
+ t.Error("should show file path")
+ }
+}
+
+func TestFragmentToolCallFallback(t *testing.T) {
+ f := FragmentToolCall("UnknownTool", `{"key1":"value1","key2":"value2"}`)
+ if !strings.Contains(f, "key1") || !strings.Contains(f, "value1") {
+ t.Error("fallback should show key=value pairs")
+ }
+}
+
+func TestFragmentToolCallNonJSON(t *testing.T) {
+ f := FragmentToolCall("Something", "plain text input")
+ if !strings.Contains(f, "plain text input") {
+ t.Error("should show plain text for non-JSON")
+ }
}
func TestFragmentToolCallTruncation(t *testing.T) {
longInput := strings.Repeat("x", 300)
- f := FragmentToolCall("Write", longInput)
+ f := FragmentToolCall("Something", longInput)
if !strings.Contains(f, "...") {
t.Error("should truncate long input")
}
@@ -81,17 +128,11 @@ func TestFragmentTypingIndicator(t *testing.T) {
if !strings.Contains(show, "typing-indicator") {
t.Error("missing typing indicator")
}
- if !strings.Contains(show, "razmišlja") {
- t.Error("missing thinking text")
- }
hide := FragmentTypingIndicator(false)
if !strings.Contains(hide, `id="typing-indicator"`) {
t.Error("missing ID")
}
- if strings.Contains(hide, "razmišlja") {
- t.Error("should not contain thinking text when hidden")
- }
}
func TestFragmentStatus(t *testing.T) {
@@ -99,12 +140,9 @@ func TestFragmentStatus(t *testing.T) {
if !strings.Contains(connected, "connected") {
t.Error("missing connected class")
}
- if !strings.Contains(connected, "Povezan") {
- t.Error("missing connected text")
- }
disconnected := FragmentStatus(false)
- if strings.Contains(disconnected, "connected") {
+ if strings.Contains(disconnected, `class="status connected"`) {
t.Error("should not have connected class when disconnected")
}
}
@@ -114,9 +152,6 @@ func TestFragmentClearInput(t *testing.T) {
if !strings.Contains(f, `id="message-input"`) {
t.Error("missing input ID")
}
- if !strings.Contains(f, "hx-swap-oob") {
- t.Error("missing OOB swap")
- }
}
func TestFragmentCombine(t *testing.T) {
diff --git a/ws.go b/ws.go
index 6b24519..16c825c 100644
--- a/ws.go
+++ b/ws.go
@@ -51,28 +51,52 @@ func (wh *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
defer conn.Close()
+ // Single write channel — all writes go through here to avoid concurrent writes
+ writeCh := make(chan string, 100)
+ writeDone := make(chan struct{})
+ go func() {
+ defer close(writeDone)
+ for msg := range writeCh {
+ if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
+ log.Printf("WebSocket write error: %v", err)
+ return
+ }
+ }
+ }()
+
+ // Helper to send via write channel
+ send := func(text string) {
+ select {
+ case writeCh <- text:
+ default:
+ log.Printf("Write channel full, dropping message")
+ }
+ }
+
sessionID := fmt.Sprintf("%s-%s", project, r.RemoteAddr)
subID := fmt.Sprintf("ws-%d", time.Now().UnixNano())
sess, isNew, err := wh.sessionMgr.GetOrCreate(sessionID, projectDir)
if err != nil {
log.Printf("Session create error: %v", err)
- writeWSText(conn, FragmentSystemMessage(fmt.Sprintf("Greška pri pokretanju Claude-a: %v", err)))
+ send(FragmentSystemMessage(fmt.Sprintf("Greška pri pokretanju Claude-a: %v", err)))
+ close(writeCh)
+ <-writeDone
return
}
// Send status
- writeWSText(conn, FragmentStatus(true))
+ send(FragmentStatus(true))
if isNew {
- writeWSText(conn, FragmentSystemMessage("Claude sesija pokrenuta. Možeš da pišeš."))
+ send(FragmentSystemMessage("Claude sesija pokrenuta. Možeš da pišeš."))
} else {
// Replay buffer
buffer := sess.GetBuffer()
for _, msg := range buffer {
- writeWSText(conn, msg.Content)
+ send(msg.Content)
}
- writeWSText(conn, FragmentSystemMessage("Ponovo povezan. Istorija učitana."))
+ send(FragmentSystemMessage("Ponovo povezan. Istorija učitana."))
}
// Subscribe to session broadcasts
@@ -84,14 +108,10 @@ func (wh *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
go wh.listenEvents(sess)
}
- // Write pump: forward broadcast messages to this WebSocket
- wsDone := make(chan struct{})
+ // Forward broadcast messages to the write channel
go func() {
- defer close(wsDone)
for fragment := range sub.Ch {
- if err := conn.WriteMessage(websocket.TextMessage, []byte(fragment)); err != nil {
- return
- }
+ send(fragment)
}
}()
@@ -124,17 +144,19 @@ func (wh *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Timestamp: time.Now(),
})
- // Clear input and show typing — send directly to this connection only
- writeWSText(conn, FragmentCombine(FragmentClearInput(), FragmentTypingIndicator(true)))
+ // Clear input and show typing
+ send(FragmentCombine(FragmentClearInput(), FragmentTypingIndicator(true)))
// Send to claude CLI
if err := sess.Process.Send(text); err != nil {
log.Printf("Send to CLI error: %v", err)
- writeWSText(conn, FragmentSystemMessage("Greška pri slanju poruke"))
+ send(FragmentSystemMessage("Greška pri slanju poruke"))
}
}
- // Don't close session — it stays alive for reconnect
+ // Cleanup
+ close(writeCh)
+ <-writeDone
}
// listenEvents reads events from the CLI process and broadcasts via AddMessage.
@@ -156,10 +178,22 @@ func (wh *WSHandler) listenEvents(sess *ChatSession) {
}
switch event.Type {
+ case "system":
+ if event.Subtype == "init" {
+ log.Printf("CLI session started: %s", event.SessionID)
+ }
+
+ case "stream_event":
+ if event.Event == nil {
+ continue
+ }
+ wh.handleStreamEvent(sess, event.Event, ¤tMsgID, ¤tText, &msgCounter)
+
case "assistant":
if event.Message != nil {
for _, c := range event.Message.Content {
- if c.Type == "tool_use" {
+ switch c.Type {
+ case "tool_use":
inputStr := ""
if c.Input != nil {
if b, err := json.Marshal(c.Input); err == nil {
@@ -172,43 +206,21 @@ func (wh *WSHandler) listenEvents(sess *ChatSession) {
Content: fragment,
Timestamp: time.Now(),
})
+ case "text":
+ if currentMsgID != "" && currentText.Len() > 0 {
+ rendered := renderMarkdown(currentText.String())
+ fragment := FragmentAssistantComplete(currentMsgID, rendered)
+ sess.AddMessage(ChatMessage{
+ Role: "assistant",
+ Content: fragment,
+ Timestamp: time.Now(),
+ })
+ currentText.Reset()
+ }
}
}
}
- case "content_block_start":
- msgCounter++
- currentMsgID = fmt.Sprintf("msg-%d-%d", time.Now().UnixMilli(), msgCounter)
- currentText.Reset()
- fragment := FragmentAssistantStart(currentMsgID)
- sess.AddMessage(ChatMessage{
- Role: "assistant",
- Content: fragment,
- Timestamp: time.Now(),
- })
-
- case "content_block_delta":
- if event.Delta != nil && event.Delta.Text != "" {
- currentText.WriteString(event.Delta.Text)
- fragment := FragmentAssistantChunk(currentMsgID, event.Delta.Text)
- sess.AddMessage(ChatMessage{
- Role: "assistant",
- Content: fragment,
- Timestamp: time.Now(),
- })
- }
-
- case "content_block_stop":
- if currentText.Len() > 0 && currentMsgID != "" {
- rendered := renderMarkdown(currentText.String())
- fragment := FragmentAssistantComplete(currentMsgID, rendered)
- sess.AddMessage(ChatMessage{
- Role: "assistant",
- Content: fragment,
- Timestamp: time.Now(),
- })
- }
-
case "result":
fragment := FragmentTypingIndicator(false)
sess.AddMessage(ChatMessage{
@@ -235,6 +247,32 @@ func (wh *WSHandler) listenEvents(sess *ChatSession) {
}
}
+func (wh *WSHandler) handleStreamEvent(sess *ChatSession, se *StreamEvent, currentMsgID *string, currentText *strings.Builder, msgCounter *int) {
+ switch se.Type {
+ case "content_block_start":
+ *msgCounter++
+ *currentMsgID = fmt.Sprintf("msg-%d-%d", time.Now().UnixMilli(), *msgCounter)
+ currentText.Reset()
+ fragment := FragmentAssistantStart(*currentMsgID)
+ sess.AddMessage(ChatMessage{
+ Role: "assistant",
+ Content: fragment,
+ Timestamp: time.Now(),
+ })
+
+ case "content_block_delta":
+ if se.Delta != nil && se.Delta.Text != "" {
+ currentText.WriteString(se.Delta.Text)
+ fragment := FragmentAssistantChunk(*currentMsgID, se.Delta.Text)
+ sess.AddMessage(ChatMessage{
+ Role: "assistant",
+ Content: fragment,
+ Timestamp: time.Now(),
+ })
+ }
+ }
+}
+
func renderMarkdown(text string) string {
var buf bytes.Buffer
if err := goldmark.Convert([]byte(text), &buf); err != nil {
@@ -242,7 +280,3 @@ func renderMarkdown(text string) string {
}
return buf.String()
}
-
-func writeWSText(conn *websocket.Conn, text string) {
- conn.WriteMessage(websocket.TextMessage, []byte(text))
-}