Popravljen CLI mod: stream-json input, concurrent write fix, lepsi tool prikaz
All checks were successful
Tests / unit-tests (push) Successful in 41s

- Koristi --input-format stream-json za multi-turn razgovor
- Koristi --include-partial-messages za streaming chunk-ove
- Filtrira CLAUDECODE i CLAUDE_CODE_ENTRYPOINT env varijable
- Svi WS write-ovi idu kroz jedan kanal (nema concurrent write panic)
- Tool call prikaz: Read prikazuje putanju, Bash prikazuje komandu, itd
- result polje moze biti string ili objekat (oba obradjena)
- Subscriber/broadcast model za real-time push

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
djuka 2026-02-18 05:27:34 +00:00
parent 3283888738
commit 9d0e507689
5 changed files with 438 additions and 103 deletions

View File

@ -16,15 +16,40 @@ type CLIEvent struct {
Type string `json:"type"` Type string `json:"type"`
Subtype string `json:"subtype,omitempty"` Subtype string `json:"subtype,omitempty"`
// For system init event
SessionID string `json:"session_id,omitempty"`
// For assistant message events // For assistant message events
Message *CLIMessage `json:"message,omitempty"` Message *CLIMessage `json:"message,omitempty"`
// For content_block_delta // For stream_event wrapper
Index int `json:"index,omitempty"` Event *StreamEvent `json:"event,omitempty"`
Delta *CLIDelta `json:"delta,omitempty"`
// For result events // For result events — can be object or string, use RawMessage
Result *CLIResult `json:"result,omitempty"` RawResult json.RawMessage `json:"result,omitempty"`
Result *CLIResult `json:"-"`
// Top-level cost field (present when result is string)
TotalCostUSD float64 `json:"total_cost_usd,omitempty"`
}
// StreamEvent is the inner event inside a stream_event wrapper.
type StreamEvent struct {
Type string `json:"type"`
Index int `json:"index,omitempty"`
Delta *StreamDelta `json:"delta,omitempty"`
ContentBlock *ContentBlock `json:"content_block,omitempty"`
}
type ContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}
type StreamDelta struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
} }
type CLIMessage struct { type CLIMessage struct {
@ -40,18 +65,24 @@ type CLIContent struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
} }
type CLIDelta struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}
type CLIResult struct { type CLIResult struct {
Duration float64 `json:"duration_ms,omitempty"` Duration float64 `json:"duration_ms,omitempty"`
NumTurns int `json:"num_turns,omitempty"` NumTurns int `json:"num_turns,omitempty"`
CostUSD float64 `json:"cost_usd,omitempty"` CostUSD float64 `json:"total_cost_usd,omitempty"`
SessionID string `json:"session_id,omitempty"` SessionID string `json:"session_id,omitempty"`
} }
// CLIInputMessage is the JSON format for sending messages via stream-json input.
type CLIInputMessage struct {
Type string `json:"type"`
Message CLIInputContent `json:"message"`
}
type CLIInputContent struct {
Role string `json:"role"`
Content string `json:"content"`
}
// CLIProcess manages a running claude CLI process. // CLIProcess manages a running claude CLI process.
type CLIProcess struct { type CLIProcess struct {
cmd *exec.Cmd cmd *exec.Cmd
@ -59,26 +90,33 @@ type CLIProcess struct {
stdout io.ReadCloser stdout io.ReadCloser
stderr io.ReadCloser stderr io.ReadCloser
Events chan CLIEvent Events chan CLIEvent
Errors chan error Errors chan error
CLISessionID string // session_id from init event, for resume
done chan struct{} done chan struct{}
mu sync.Mutex mu sync.Mutex
} }
// SpawnCLI starts a new claude CLI process for the given project directory. // SpawnCLI starts a new claude CLI process for the given project directory.
// Uses --input-format stream-json for multi-turn conversation support.
func SpawnCLI(projectDir string) (*CLIProcess, error) { func SpawnCLI(projectDir string) (*CLIProcess, error) {
args := []string{ args := []string{
"-p", "-p",
"--input-format", "stream-json",
"--output-format", "stream-json", "--output-format", "stream-json",
"--verbose", "--verbose",
"--include-partial-messages",
} }
cmd := exec.Command("claude", args...) cmd := exec.Command("claude", args...)
cmd.Dir = projectDir cmd.Dir = projectDir
// Filter CLAUDECODE env var to prevent nested session detection // Filter all Claude-related env vars to prevent nested session detection
env := filterEnv(os.Environ(), "CLAUDECODE") env := filterEnvMulti(os.Environ(), []string{
"CLAUDECODE",
"CLAUDE_CODE_ENTRYPOINT",
})
cmd.Env = env cmd.Env = env
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
@ -116,13 +154,25 @@ func SpawnCLI(projectDir string) (*CLIProcess, error) {
return cp, nil return cp, nil
} }
// Send writes a message to the claude CLI process stdin. // Send writes a user message to the claude CLI process stdin as stream-json.
func (cp *CLIProcess) Send(message string) error { func (cp *CLIProcess) Send(message string) error {
cp.mu.Lock() cp.mu.Lock()
defer cp.mu.Unlock() defer cp.mu.Unlock()
msg := strings.TrimSpace(message) + "\n" msg := CLIInputMessage{
_, err := io.WriteString(cp.stdin, msg) Type: "user",
Message: CLIInputContent{
Role: "user",
Content: strings.TrimSpace(message),
},
}
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("marshal message: %w", err)
}
_, err = cp.stdin.Write(append(data, '\n'))
return err return err
} }
@ -132,7 +182,10 @@ func (cp *CLIProcess) Close() error {
defer cp.mu.Unlock() defer cp.mu.Unlock()
cp.stdin.Close() cp.stdin.Close()
return cp.cmd.Process.Kill() if cp.cmd.Process != nil {
return cp.cmd.Process.Kill()
}
return nil
} }
// Done returns a channel that's closed when the process exits. // Done returns a channel that's closed when the process exits.
@ -158,6 +211,24 @@ func (cp *CLIProcess) readOutput() {
cp.Errors <- fmt.Errorf("parse event: %w (line: %s)", err, truncate(line, 200)) cp.Errors <- fmt.Errorf("parse event: %w (line: %s)", err, truncate(line, 200))
continue continue
} }
// Parse result field — can be string or object
if len(event.RawResult) > 0 {
var result CLIResult
if err := json.Unmarshal(event.RawResult, &result); err == nil {
event.Result = &result
}
// If it's a string, Result stays nil — cost comes from top-level field
if event.Result == nil && event.TotalCostUSD > 0 {
event.Result = &CLIResult{CostUSD: event.TotalCostUSD}
}
}
// Capture session_id from init event
if event.Type == "system" && event.Subtype == "init" && event.SessionID != "" {
cp.CLISessionID = event.SessionID
}
cp.Events <- event cp.Events <- event
} }
@ -176,18 +247,34 @@ func (cp *CLIProcess) readErrors() {
} }
} }
// filterEnv returns a copy of env with the named variable removed. // filterEnvMulti returns a copy of env with all named variables removed.
func filterEnv(env []string, name string) []string { func filterEnvMulti(env []string, names []string) []string {
prefix := name + "=" prefixes := make([]string, len(names))
for i, n := range names {
prefixes[i] = n + "="
}
result := make([]string, 0, len(env)) result := make([]string, 0, len(env))
for _, e := range env { for _, e := range env {
if !strings.HasPrefix(e, prefix) { skip := false
for _, p := range prefixes {
if strings.HasPrefix(e, p) {
skip = true
break
}
}
if !skip {
result = append(result, e) result = append(result, e)
} }
} }
return result return result
} }
// filterEnv returns a copy of env with the named variable removed.
func filterEnv(env []string, name string) []string {
return filterEnvMulti(env, []string{name})
}
func truncate(s string, n int) string { func truncate(s string, n int) string {
if len(s) <= n { if len(s) <= n {
return s return s

View File

@ -26,6 +26,28 @@ func TestFilterEnv(t *testing.T) {
} }
} }
func TestFilterEnvMulti(t *testing.T) {
env := []string{
"PATH=/usr/bin",
"HOME=/root",
"CLAUDECODE=1",
"CLAUDE_CODE_ENTRYPOINT=cli",
"OTHER=value",
}
filtered := filterEnvMulti(env, []string{"CLAUDECODE", "CLAUDE_CODE_ENTRYPOINT"})
if len(filtered) != 3 {
t.Fatalf("got %d entries, want 3", len(filtered))
}
for _, e := range filtered {
if e == "CLAUDECODE=1" || e == "CLAUDE_CODE_ENTRYPOINT=cli" {
t.Errorf("should be filtered: %s", e)
}
}
}
func TestFilterEnvNotPresent(t *testing.T) { func TestFilterEnvNotPresent(t *testing.T) {
env := []string{"PATH=/usr/bin", "HOME=/root"} env := []string{"PATH=/usr/bin", "HOME=/root"}
filtered := filterEnv(env, "CLAUDECODE") filtered := filterEnv(env, "CLAUDECODE")
@ -53,6 +75,23 @@ func TestTruncate(t *testing.T) {
} }
func TestCLIEventParsing(t *testing.T) { func TestCLIEventParsing(t *testing.T) {
t.Run("system init", func(t *testing.T) {
raw := `{"type":"system","subtype":"init","session_id":"abc-123","cwd":"/"}`
var event CLIEvent
if err := json.Unmarshal([]byte(raw), &event); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if event.Type != "system" {
t.Errorf("type = %q", event.Type)
}
if event.Subtype != "init" {
t.Errorf("subtype = %q", event.Subtype)
}
if event.SessionID != "abc-123" {
t.Errorf("session_id = %q", event.SessionID)
}
})
t.Run("assistant message", func(t *testing.T) { t.Run("assistant message", func(t *testing.T) {
raw := `{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello!"}]}}` raw := `{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello!"}]}}`
var event CLIEvent var event CLIEvent
@ -73,29 +112,53 @@ func TestCLIEventParsing(t *testing.T) {
} }
}) })
t.Run("content_block_delta", func(t *testing.T) { t.Run("stream_event content_block_delta", func(t *testing.T) {
raw := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"chunk"}}` raw := `{"type":"stream_event","event":{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"chunk"}}}`
var event CLIEvent var event CLIEvent
if err := json.Unmarshal([]byte(raw), &event); err != nil { if err := json.Unmarshal([]byte(raw), &event); err != nil {
t.Fatalf("unmarshal: %v", err) t.Fatalf("unmarshal: %v", err)
} }
if event.Type != "content_block_delta" { if event.Type != "stream_event" {
t.Errorf("type = %q", event.Type) t.Errorf("type = %q", event.Type)
} }
if event.Delta == nil { if event.Event == nil {
t.Fatal("event is nil")
}
if event.Event.Type != "content_block_delta" {
t.Errorf("event.type = %q", event.Event.Type)
}
if event.Event.Delta == nil {
t.Fatal("delta is nil") t.Fatal("delta is nil")
} }
if event.Delta.Text != "chunk" { if event.Event.Delta.Text != "chunk" {
t.Errorf("text = %q", event.Delta.Text) t.Errorf("text = %q", event.Event.Delta.Text)
} }
}) })
t.Run("result event", func(t *testing.T) { t.Run("stream_event content_block_start", func(t *testing.T) {
raw := `{"type":"result","subtype":"success","result":{"duration_ms":1234.5,"num_turns":3,"cost_usd":0.05,"session_id":"abc123"}}` raw := `{"type":"stream_event","event":{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}}`
var event CLIEvent var event CLIEvent
if err := json.Unmarshal([]byte(raw), &event); err != nil { if err := json.Unmarshal([]byte(raw), &event); err != nil {
t.Fatalf("unmarshal: %v", err) t.Fatalf("unmarshal: %v", err)
} }
if event.Event.Type != "content_block_start" {
t.Errorf("event.type = %q", event.Event.Type)
}
})
t.Run("result event object", func(t *testing.T) {
raw := `{"type":"result","subtype":"success","result":{"duration_ms":1234.5,"num_turns":3,"total_cost_usd":0.05,"session_id":"abc123"}}`
var event CLIEvent
if err := json.Unmarshal([]byte(raw), &event); err != nil {
t.Fatalf("unmarshal: %v", err)
}
// Simulate readOutput parsing of RawResult
if len(event.RawResult) > 0 {
var result CLIResult
if err := json.Unmarshal(event.RawResult, &result); err == nil {
event.Result = &result
}
}
if event.Type != "result" { if event.Type != "result" {
t.Errorf("type = %q", event.Type) t.Errorf("type = %q", event.Type)
} }
@ -110,6 +173,30 @@ func TestCLIEventParsing(t *testing.T) {
} }
}) })
t.Run("result event string", func(t *testing.T) {
raw := `{"type":"result","subtype":"success","total_cost_usd":0.046,"result":"hello world"}`
var event CLIEvent
if err := json.Unmarshal([]byte(raw), &event); err != nil {
t.Fatalf("unmarshal: %v", err)
}
// Simulate readOutput parsing — string result won't parse as CLIResult
if len(event.RawResult) > 0 {
var result CLIResult
if err := json.Unmarshal(event.RawResult, &result); err == nil {
event.Result = &result
}
if event.Result == nil && event.TotalCostUSD > 0 {
event.Result = &CLIResult{CostUSD: event.TotalCostUSD}
}
}
if event.Result == nil {
t.Fatal("result should not be nil for string result with cost")
}
if event.Result.CostUSD != 0.046 {
t.Errorf("cost = %f, want 0.046", event.Result.CostUSD)
}
})
t.Run("tool_use content", func(t *testing.T) { t.Run("tool_use content", func(t *testing.T) {
raw := `{"type":"assistant","message":{"role":"assistant","content":[{"type":"tool_use","name":"Read","id":"tool1","input":{"file_path":"/tmp/test.go"}}]}}` raw := `{"type":"assistant","message":{"role":"assistant","content":[{"type":"tool_use","name":"Read","id":"tool1","input":{"file_path":"/tmp/test.go"}}]}}`
var event CLIEvent var event CLIEvent
@ -124,3 +211,23 @@ func TestCLIEventParsing(t *testing.T) {
} }
}) })
} }
func TestCLIInputMessage(t *testing.T) {
msg := CLIInputMessage{
Type: "user",
Message: CLIInputContent{
Role: "user",
Content: "hello world",
},
}
data, err := json.Marshal(msg)
if err != nil {
t.Fatalf("marshal: %v", err)
}
expected := `{"type":"user","message":{"role":"user","content":"hello world"}}`
if string(data) != expected {
t.Errorf("got %s, want %s", string(data), expected)
}
}

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"html" "html"
"strings" "strings"
@ -31,11 +32,82 @@ func FragmentAssistantComplete(msgID, htmlContent string) string {
// FragmentToolCall returns an HTML fragment for a tool use notification. // FragmentToolCall returns an HTML fragment for a tool use notification.
func FragmentToolCall(toolName string, toolInput string) string { func FragmentToolCall(toolName string, toolInput string) string {
escapedName := html.EscapeString(toolName) escapedName := html.EscapeString(toolName)
escapedInput := html.EscapeString(toolInput) summary := formatToolSummary(toolName, toolInput)
if len(escapedInput) > 200 { return fmt.Sprintf(`<div id="chat-messages" hx-swap-oob="beforeend"><div class="message message-tool"><div class="tool-name">%s</div><div>%s</div></div></div>`, escapedName, summary)
escapedInput = escapedInput[:200] + "..." }
// formatToolSummary produces a human-readable summary of a tool call.
func formatToolSummary(toolName, rawInput string) string {
var inputMap map[string]any
if err := json.Unmarshal([]byte(rawInput), &inputMap); err != nil {
// Not JSON, just escape and truncate
s := html.EscapeString(rawInput)
if len(s) > 200 {
s = s[:200] + "..."
}
return s
} }
return fmt.Sprintf(`<div id="chat-messages" hx-swap-oob="beforeend"><div class="message message-tool"><div class="tool-name">%s</div><div>%s</div></div></div>`, escapedName, escapedInput)
switch toolName {
case "Read":
if fp, ok := inputMap["file_path"].(string); ok {
return html.EscapeString(fp)
}
case "Edit":
if fp, ok := inputMap["file_path"].(string); ok {
return html.EscapeString(fmt.Sprintf("%s", fp))
}
case "Write":
if fp, ok := inputMap["file_path"].(string); ok {
return html.EscapeString(fp)
}
case "Bash":
if cmd, ok := inputMap["command"].(string); ok {
s := cmd
if len(s) > 150 {
s = s[:150] + "..."
}
return "<code>" + html.EscapeString(s) + "</code>"
}
case "Glob":
if pat, ok := inputMap["pattern"].(string); ok {
return html.EscapeString(pat)
}
case "Grep":
parts := []string{}
if pat, ok := inputMap["pattern"].(string); ok {
parts = append(parts, pat)
}
if p, ok := inputMap["path"].(string); ok {
parts = append(parts, "in "+p)
}
if len(parts) > 0 {
return html.EscapeString(strings.Join(parts, " "))
}
case "WebSearch":
if q, ok := inputMap["query"].(string); ok {
return html.EscapeString(q)
}
case "WebFetch":
if u, ok := inputMap["url"].(string); ok {
return html.EscapeString(u)
}
}
// Fallback: show key=value pairs
var parts []string
for k, v := range inputMap {
s := fmt.Sprintf("%v", v)
if len(s) > 80 {
s = s[:80] + "..."
}
parts = append(parts, html.EscapeString(fmt.Sprintf("%s: %s", k, s)))
}
result := strings.Join(parts, ", ")
if len(result) > 300 {
result = result[:300] + "..."
}
return result
} }
// FragmentSystemMessage returns an HTML fragment for a system message. // FragmentSystemMessage returns an HTML fragment for a system message.

View File

@ -49,18 +49,65 @@ func TestFragmentAssistantComplete(t *testing.T) {
} }
func TestFragmentToolCall(t *testing.T) { func TestFragmentToolCall(t *testing.T) {
f := FragmentToolCall("Read", "/tmp/test.go") f := FragmentToolCall("Read", `{"file_path":"/tmp/test.go"}`)
if !strings.Contains(f, "message-tool") { if !strings.Contains(f, "message-tool") {
t.Error("missing message-tool class") t.Error("missing message-tool class")
} }
if !strings.Contains(f, "Read") { if !strings.Contains(f, "Read") {
t.Error("missing tool name") t.Error("missing tool name")
} }
if !strings.Contains(f, "/tmp/test.go") {
t.Error("should show file path, not raw JSON")
}
if strings.Contains(f, "file_path") {
t.Error("should not show JSON key name for Read")
}
}
func TestFragmentToolCallBash(t *testing.T) {
f := FragmentToolCall("Bash", `{"command":"git status"}`)
if !strings.Contains(f, "<code>") {
t.Error("Bash command should be in code tag")
}
if !strings.Contains(f, "git status") {
t.Error("should show command")
}
}
func TestFragmentToolCallGrep(t *testing.T) {
f := FragmentToolCall("Grep", `{"pattern":"TODO","path":"/root/projects"}`)
if !strings.Contains(f, "TODO") {
t.Error("should show pattern")
}
if !strings.Contains(f, "/root/projects") {
t.Error("should show path")
}
}
func TestFragmentToolCallEdit(t *testing.T) {
f := FragmentToolCall("Edit", `{"file_path":"/tmp/main.go","old_string":"foo","new_string":"bar"}`)
if !strings.Contains(f, "/tmp/main.go") {
t.Error("should show file path")
}
}
func TestFragmentToolCallFallback(t *testing.T) {
f := FragmentToolCall("UnknownTool", `{"key1":"value1","key2":"value2"}`)
if !strings.Contains(f, "key1") || !strings.Contains(f, "value1") {
t.Error("fallback should show key=value pairs")
}
}
func TestFragmentToolCallNonJSON(t *testing.T) {
f := FragmentToolCall("Something", "plain text input")
if !strings.Contains(f, "plain text input") {
t.Error("should show plain text for non-JSON")
}
} }
func TestFragmentToolCallTruncation(t *testing.T) { func TestFragmentToolCallTruncation(t *testing.T) {
longInput := strings.Repeat("x", 300) longInput := strings.Repeat("x", 300)
f := FragmentToolCall("Write", longInput) f := FragmentToolCall("Something", longInput)
if !strings.Contains(f, "...") { if !strings.Contains(f, "...") {
t.Error("should truncate long input") t.Error("should truncate long input")
} }
@ -81,17 +128,11 @@ func TestFragmentTypingIndicator(t *testing.T) {
if !strings.Contains(show, "typing-indicator") { if !strings.Contains(show, "typing-indicator") {
t.Error("missing typing indicator") t.Error("missing typing indicator")
} }
if !strings.Contains(show, "razmišlja") {
t.Error("missing thinking text")
}
hide := FragmentTypingIndicator(false) hide := FragmentTypingIndicator(false)
if !strings.Contains(hide, `id="typing-indicator"`) { if !strings.Contains(hide, `id="typing-indicator"`) {
t.Error("missing ID") t.Error("missing ID")
} }
if strings.Contains(hide, "razmišlja") {
t.Error("should not contain thinking text when hidden")
}
} }
func TestFragmentStatus(t *testing.T) { func TestFragmentStatus(t *testing.T) {
@ -99,12 +140,9 @@ func TestFragmentStatus(t *testing.T) {
if !strings.Contains(connected, "connected") { if !strings.Contains(connected, "connected") {
t.Error("missing connected class") t.Error("missing connected class")
} }
if !strings.Contains(connected, "Povezan") {
t.Error("missing connected text")
}
disconnected := FragmentStatus(false) disconnected := FragmentStatus(false)
if strings.Contains(disconnected, "connected") { if strings.Contains(disconnected, `class="status connected"`) {
t.Error("should not have connected class when disconnected") t.Error("should not have connected class when disconnected")
} }
} }
@ -114,9 +152,6 @@ func TestFragmentClearInput(t *testing.T) {
if !strings.Contains(f, `id="message-input"`) { if !strings.Contains(f, `id="message-input"`) {
t.Error("missing input ID") t.Error("missing input ID")
} }
if !strings.Contains(f, "hx-swap-oob") {
t.Error("missing OOB swap")
}
} }
func TestFragmentCombine(t *testing.T) { func TestFragmentCombine(t *testing.T) {

140
ws.go
View File

@ -51,28 +51,52 @@ func (wh *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
defer conn.Close() defer conn.Close()
// Single write channel — all writes go through here to avoid concurrent writes
writeCh := make(chan string, 100)
writeDone := make(chan struct{})
go func() {
defer close(writeDone)
for msg := range writeCh {
if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
log.Printf("WebSocket write error: %v", err)
return
}
}
}()
// Helper to send via write channel
send := func(text string) {
select {
case writeCh <- text:
default:
log.Printf("Write channel full, dropping message")
}
}
sessionID := fmt.Sprintf("%s-%s", project, r.RemoteAddr) sessionID := fmt.Sprintf("%s-%s", project, r.RemoteAddr)
subID := fmt.Sprintf("ws-%d", time.Now().UnixNano()) subID := fmt.Sprintf("ws-%d", time.Now().UnixNano())
sess, isNew, err := wh.sessionMgr.GetOrCreate(sessionID, projectDir) sess, isNew, err := wh.sessionMgr.GetOrCreate(sessionID, projectDir)
if err != nil { if err != nil {
log.Printf("Session create error: %v", err) log.Printf("Session create error: %v", err)
writeWSText(conn, FragmentSystemMessage(fmt.Sprintf("Greška pri pokretanju Claude-a: %v", err))) send(FragmentSystemMessage(fmt.Sprintf("Greška pri pokretanju Claude-a: %v", err)))
close(writeCh)
<-writeDone
return return
} }
// Send status // Send status
writeWSText(conn, FragmentStatus(true)) send(FragmentStatus(true))
if isNew { if isNew {
writeWSText(conn, FragmentSystemMessage("Claude sesija pokrenuta. Možeš da pišeš.")) send(FragmentSystemMessage("Claude sesija pokrenuta. Možeš da pišeš."))
} else { } else {
// Replay buffer // Replay buffer
buffer := sess.GetBuffer() buffer := sess.GetBuffer()
for _, msg := range buffer { for _, msg := range buffer {
writeWSText(conn, msg.Content) send(msg.Content)
} }
writeWSText(conn, FragmentSystemMessage("Ponovo povezan. Istorija učitana.")) send(FragmentSystemMessage("Ponovo povezan. Istorija učitana."))
} }
// Subscribe to session broadcasts // Subscribe to session broadcasts
@ -84,14 +108,10 @@ func (wh *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
go wh.listenEvents(sess) go wh.listenEvents(sess)
} }
// Write pump: forward broadcast messages to this WebSocket // Forward broadcast messages to the write channel
wsDone := make(chan struct{})
go func() { go func() {
defer close(wsDone)
for fragment := range sub.Ch { for fragment := range sub.Ch {
if err := conn.WriteMessage(websocket.TextMessage, []byte(fragment)); err != nil { send(fragment)
return
}
} }
}() }()
@ -124,17 +144,19 @@ func (wh *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Timestamp: time.Now(), Timestamp: time.Now(),
}) })
// Clear input and show typing — send directly to this connection only // Clear input and show typing
writeWSText(conn, FragmentCombine(FragmentClearInput(), FragmentTypingIndicator(true))) send(FragmentCombine(FragmentClearInput(), FragmentTypingIndicator(true)))
// Send to claude CLI // Send to claude CLI
if err := sess.Process.Send(text); err != nil { if err := sess.Process.Send(text); err != nil {
log.Printf("Send to CLI error: %v", err) log.Printf("Send to CLI error: %v", err)
writeWSText(conn, FragmentSystemMessage("Greška pri slanju poruke")) send(FragmentSystemMessage("Greška pri slanju poruke"))
} }
} }
// Don't close session — it stays alive for reconnect // Cleanup
close(writeCh)
<-writeDone
} }
// listenEvents reads events from the CLI process and broadcasts via AddMessage. // listenEvents reads events from the CLI process and broadcasts via AddMessage.
@ -156,10 +178,22 @@ func (wh *WSHandler) listenEvents(sess *ChatSession) {
} }
switch event.Type { switch event.Type {
case "system":
if event.Subtype == "init" {
log.Printf("CLI session started: %s", event.SessionID)
}
case "stream_event":
if event.Event == nil {
continue
}
wh.handleStreamEvent(sess, event.Event, &currentMsgID, &currentText, &msgCounter)
case "assistant": case "assistant":
if event.Message != nil { if event.Message != nil {
for _, c := range event.Message.Content { for _, c := range event.Message.Content {
if c.Type == "tool_use" { switch c.Type {
case "tool_use":
inputStr := "" inputStr := ""
if c.Input != nil { if c.Input != nil {
if b, err := json.Marshal(c.Input); err == nil { if b, err := json.Marshal(c.Input); err == nil {
@ -172,43 +206,21 @@ func (wh *WSHandler) listenEvents(sess *ChatSession) {
Content: fragment, Content: fragment,
Timestamp: time.Now(), Timestamp: time.Now(),
}) })
case "text":
if currentMsgID != "" && currentText.Len() > 0 {
rendered := renderMarkdown(currentText.String())
fragment := FragmentAssistantComplete(currentMsgID, rendered)
sess.AddMessage(ChatMessage{
Role: "assistant",
Content: fragment,
Timestamp: time.Now(),
})
currentText.Reset()
}
} }
} }
} }
case "content_block_start":
msgCounter++
currentMsgID = fmt.Sprintf("msg-%d-%d", time.Now().UnixMilli(), msgCounter)
currentText.Reset()
fragment := FragmentAssistantStart(currentMsgID)
sess.AddMessage(ChatMessage{
Role: "assistant",
Content: fragment,
Timestamp: time.Now(),
})
case "content_block_delta":
if event.Delta != nil && event.Delta.Text != "" {
currentText.WriteString(event.Delta.Text)
fragment := FragmentAssistantChunk(currentMsgID, event.Delta.Text)
sess.AddMessage(ChatMessage{
Role: "assistant",
Content: fragment,
Timestamp: time.Now(),
})
}
case "content_block_stop":
if currentText.Len() > 0 && currentMsgID != "" {
rendered := renderMarkdown(currentText.String())
fragment := FragmentAssistantComplete(currentMsgID, rendered)
sess.AddMessage(ChatMessage{
Role: "assistant",
Content: fragment,
Timestamp: time.Now(),
})
}
case "result": case "result":
fragment := FragmentTypingIndicator(false) fragment := FragmentTypingIndicator(false)
sess.AddMessage(ChatMessage{ sess.AddMessage(ChatMessage{
@ -235,6 +247,32 @@ func (wh *WSHandler) listenEvents(sess *ChatSession) {
} }
} }
func (wh *WSHandler) handleStreamEvent(sess *ChatSession, se *StreamEvent, currentMsgID *string, currentText *strings.Builder, msgCounter *int) {
switch se.Type {
case "content_block_start":
*msgCounter++
*currentMsgID = fmt.Sprintf("msg-%d-%d", time.Now().UnixMilli(), *msgCounter)
currentText.Reset()
fragment := FragmentAssistantStart(*currentMsgID)
sess.AddMessage(ChatMessage{
Role: "assistant",
Content: fragment,
Timestamp: time.Now(),
})
case "content_block_delta":
if se.Delta != nil && se.Delta.Text != "" {
currentText.WriteString(se.Delta.Text)
fragment := FragmentAssistantChunk(*currentMsgID, se.Delta.Text)
sess.AddMessage(ChatMessage{
Role: "assistant",
Content: fragment,
Timestamp: time.Now(),
})
}
}
}
func renderMarkdown(text string) string { func renderMarkdown(text string) string {
var buf bytes.Buffer var buf bytes.Buffer
if err := goldmark.Convert([]byte(text), &buf); err != nil { if err := goldmark.Convert([]byte(text), &buf); err != nil {
@ -242,7 +280,3 @@ func renderMarkdown(text string) string {
} }
return buf.String() return buf.String()
} }
func writeWSText(conn *websocket.Conn, text string) {
conn.WriteMessage(websocket.TextMessage, []byte(text))
}