// uds_server_test.go 测试 UDSServer 的核心路径. package inbox import ( "encoding/json" "fmt" "net" "os" "path/filepath" "testing" "time" ) // dial 是测试辅助函数:连接 UDS 并发送一条 JSON 消息. func dial(t *testing.T, sockPath string, msg UDSInboxMessage) error { t.Helper() conn, err := net.Dial("unix", sockPath) if err != nil { return fmt.Errorf("dial: %w", err) } defer conn.Close() return json.NewEncoder(conn).Encode(msg) } // TestUDSServer_Normal 测试正常路径:Start → 发消息 → 收到消息. func TestUDSServer_Normal(t *testing.T) { srv, err := NewUDSServer("test-session-normal") if err != nil { t.Fatalf("NewUDSServer: %v", err) } defer srv.Close() if err := srv.Start(); err != nil { t.Fatalf("Start: %v", err) } // 确认 socket 文件已创建 if _, err := os.Stat(srv.SockPath()); err != nil { t.Fatalf("socket file not created: %v", err) } // 发送消息 msg := UDSInboxMessage{ Type: "progress", ToolUseID: "tool_abc_123", Data: "50%", } if err := dial(t, srv.SockPath(), msg); err != nil { t.Fatalf("dial: %v", err) } // 等待接收 select { case got := <-srv.Messages(): if got.Type != "progress" { t.Errorf("got type=%q, want %q", got.Type, "progress") } if got.ToolUseID != "tool_abc_123" { t.Errorf("got tool_use_id=%q, want %q", got.ToolUseID, "tool_abc_123") } if got.Data != "50%" { t.Errorf("got data=%q, want %q", got.Data, "50%") } case <-time.After(2 * time.Second): t.Fatal("timeout waiting for message") } } // TestUDSServer_MultipleMessages 测试多条消息依次到达. func TestUDSServer_MultipleMessages(t *testing.T) { srv, err := NewUDSServer("test-session-multi") if err != nil { t.Fatalf("NewUDSServer: %v", err) } defer srv.Close() if err := srv.Start(); err != nil { t.Fatalf("Start: %v", err) } for i := 0; i < 5; i++ { msg := UDSInboxMessage{ Type: "log", Data: fmt.Sprintf("step %d", i), } if err := dial(t, srv.SockPath(), msg); err != nil { t.Fatalf("dial[%d]: %v", i, err) } } received := 0 timeout := time.After(3 * time.Second) for received < 5 { select { case <-srv.Messages(): received++ case <-timeout: t.Fatalf("timeout: only received %d/5 messages", received) } } } // TestUDSServer_InvalidJSON 测试发送无效 JSON 时服务器不崩溃. func TestUDSServer_InvalidJSON(t *testing.T) { srv, err := NewUDSServer("test-session-invalid-json") if err != nil { t.Fatalf("NewUDSServer: %v", err) } defer srv.Close() if err := srv.Start(); err != nil { t.Fatalf("Start: %v", err) } // 发送无效 JSON conn, err := net.Dial("unix", srv.SockPath()) if err != nil { t.Fatalf("dial: %v", err) } conn.Write([]byte("this is not json\n")) //nolint:errcheck conn.Close() // 随后发送合法消息,确认服务器仍然工作 time.Sleep(50 * time.Millisecond) msg := UDSInboxMessage{Type: "log", Data: "after invalid"} if err := dial(t, srv.SockPath(), msg); err != nil { t.Fatalf("dial after invalid json: %v", err) } select { case got := <-srv.Messages(): if got.Data != "after invalid" { t.Errorf("got data=%q, want %q", got.Data, "after invalid") } case <-time.After(2 * time.Second): t.Fatal("timeout: server not recovering after invalid json") } } // TestUDSServer_Close 测试 Close() 后 Messages() channel 被关闭. func TestUDSServer_Close(t *testing.T) { srv, err := NewUDSServer("test-session-close") if err != nil { t.Fatalf("NewUDSServer: %v", err) } if err := srv.Start(); err != nil { t.Fatalf("Start: %v", err) } sockPath := srv.SockPath() srv.Close() // socket 文件应被删除 if _, err := os.Stat(sockPath); !os.IsNotExist(err) { t.Errorf("socket file should be deleted after Close, got: %v", err) } // Messages() channel 应该已关闭(可以读取但最终 drain 后关闭) timeout := time.After(1 * time.Second) for { select { case _, ok := <-srv.Messages(): if !ok { return // channel 正常关闭 } case <-timeout: t.Fatal("messages channel not closed after Close()") } } } // TestUDSServer_CloseIdempotent 测试 Close() 幂等(多次调用不 panic). func TestUDSServer_CloseIdempotent(t *testing.T) { srv, err := NewUDSServer("test-session-idem") if err != nil { t.Fatalf("NewUDSServer: %v", err) } if err := srv.Start(); err != nil { t.Fatalf("Start: %v", err) } // 多次 Close 不应 panic srv.Close() srv.Close() srv.Close() } // TestUDSServer_StartAfterClose 测试 Close 后 Start 返回错误. func TestUDSServer_StartAfterClose(t *testing.T) { srv, err := NewUDSServer("test-session-start-after-close") if err != nil { t.Fatalf("NewUDSServer: %v", err) } if err := srv.Start(); err != nil { t.Fatalf("Start: %v", err) } srv.Close() // Close 后再 Start 应报错 if err := srv.Start(); err == nil { t.Error("Start after Close should return error") } } // TestUDSServer_DoubleStart 测试重复 Start 返回错误. func TestUDSServer_DoubleStart(t *testing.T) { srv, err := NewUDSServer("test-session-double-start") if err != nil { t.Fatalf("NewUDSServer: %v", err) } defer srv.Close() if err := srv.Start(); err != nil { t.Fatalf("first Start: %v", err) } if err := srv.Start(); err == nil { t.Error("second Start should return error") } } // TestUDSServer_WithMeta 测试包含 Meta 字段的消息. func TestUDSServer_WithMeta(t *testing.T) { srv, err := NewUDSServer("test-session-meta") if err != nil { t.Fatalf("NewUDSServer: %v", err) } defer srv.Close() if err := srv.Start(); err != nil { t.Fatalf("Start: %v", err) } metaBytes, _ := json.Marshal(map[string]any{"source": "importer", "row": 42}) msg := UDSInboxMessage{ Type: "result", Data: "import complete", Meta: json.RawMessage(metaBytes), } if err := dial(t, srv.SockPath(), msg); err != nil { t.Fatalf("dial: %v", err) } select { case got := <-srv.Messages(): if got.Type != "result" { t.Errorf("type=%q, want %q", got.Type, "result") } var meta map[string]any if err := json.Unmarshal(got.Meta, &meta); err != nil { t.Fatalf("unmarshal meta: %v", err) } if meta["source"] != "importer" { t.Errorf("meta.source=%v, want importer", meta["source"]) } case <-time.After(2 * time.Second): t.Fatal("timeout") } } // TestNewUDSServer_EmptySessionID 测试空 sessionID 报错. func TestNewUDSServer_EmptySessionID(t *testing.T) { _, err := NewUDSServer("") if err == nil { t.Error("NewUDSServer with empty sessionID should return error") } } // TestUDSServer_SockPathLocation 验证 socket 路径选择策略: // 优先落在用户 cache 目录下,fallback 才是 /tmp,且路径不超过内核限制. func TestUDSServer_SockPathLocation(t *testing.T) { srv, err := NewUDSServer("test-session-location") if err != nil { t.Fatalf("NewUDSServer: %v", err) } defer srv.Close() p := srv.SockPath() // 约束 1: 不超过 104 字节(macOS sockaddr_un.sun_path 上限) if len(p) > maxSockPathLen { t.Errorf("sock path too long: %d > %d: %q", len(p), maxSockPathLen, p) } // 约束 2: 优先位于用户 cache 目录;仅当 UserCacheDir 不可用或越界时才 /tmp. cacheDir, cerr := os.UserCacheDir() if cerr == nil { expectedDir := cacheDir + string(os.PathSeparator) + "flyto" candidate := expectedDir + string(os.PathSeparator) + "flyto-test-session-location.sock" if len(candidate) <= maxSockPathLen { // 路径理论上能放进 cache 目录 - 就不应 fallback 到 /tmp fallback := filepath.Join(os.TempDir(), "flyto-test-session-location.sock") if p == fallback { t.Errorf("sock path should prefer UserCacheDir but got fallback: %q", p) } } } // 约束 3: 真能 listen(端到端验证路径合法) if err := srv.Start(); err != nil { t.Fatalf("Start with resolved path %q: %v", p, err) } if _, err := os.Stat(p); err != nil { t.Errorf("socket file not created at %q: %v", p, err) } }