package api

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"net/http/httptest"
	"net/http/httptrace"
	"sync/atomic"
	"testing"
	"time"
)

// pollUntil evaluates cond every few milliseconds until it returns true or
// timeout elapses, and reports whether cond became true.
//
// It replaces fixed time.Sleep calls when waiting on fire-and-forget
// goroutines (Preconnector.Warmup, DNSCache.Prefetch). This keeps the tests
// from flaking on a slow machine and from waiting longer than needed on a
// fast one.
func pollUntil(timeout time.Duration, cond func() bool) bool {
	deadline := time.Now().Add(timeout)
	for {
		if cond() {
			return true
		}
		if time.Now().After(deadline) {
			return false
		}
		time.Sleep(5 * time.Millisecond)
	}
}

// ============================================================
// TransportConfig tests
// ============================================================

// TestDefaultTransportConfig checks the pool sizes and timeouts returned by
// DefaultTransportConfig.
func TestDefaultTransportConfig(t *testing.T) {
	cfg := DefaultTransportConfig()
	if cfg.MaxIdleConnsPerHost != 2 {
		t.Errorf("MaxIdleConnsPerHost = %d, want 2", cfg.MaxIdleConnsPerHost)
	}
	if cfg.MaxIdleConns != 10 {
		t.Errorf("MaxIdleConns = %d, want 10", cfg.MaxIdleConns)
	}
	if cfg.IdleConnTimeout != 90*time.Second {
		t.Errorf("IdleConnTimeout = %v, want 90s", cfg.IdleConnTimeout)
	}
	if cfg.TLSHandshakeTimeout != 10*time.Second {
		t.Errorf("TLSHandshakeTimeout = %v, want 10s", cfg.TLSHandshakeTimeout)
	}
	if cfg.DialTimeout != 30*time.Second {
		t.Errorf("DialTimeout = %v, want 30s", cfg.DialTimeout)
	}
}

// TestNewTransport_DefaultConfig checks that a nil config yields a usable
// transport with the default per-host pool size and HTTP/2 attempts enabled.
func TestNewTransport_DefaultConfig(t *testing.T) {
	tr := NewTransport(nil)
	if tr == nil {
		t.Fatal("NewTransport(nil) should return non-nil")
	}
	if tr.MaxIdleConnsPerHost != 2 {
		t.Errorf("MaxIdleConnsPerHost = %d, want 2", tr.MaxIdleConnsPerHost)
	}
	if !tr.ForceAttemptHTTP2 {
		t.Error("ForceAttemptHTTP2 should be true")
	}
}

// TestNewTransport_CustomConfig checks that every field set on a custom
// TransportConfig is copied onto the resulting transport.
func TestNewTransport_CustomConfig(t *testing.T) {
	cfg := &TransportConfig{
		MaxIdleConnsPerHost: 5,
		MaxIdleConns:        20,
		IdleConnTimeout:     60 * time.Second,
		TLSHandshakeTimeout: 5 * time.Second,
		DisableKeepAlives:   true,
		// Test-only: used solely to verify the tls.Config is propagated;
		// no connection is made with it.
		TLSConfig: &tls.Config{InsecureSkipVerify: true},
	}
	tr := NewTransport(cfg)
	if tr.MaxIdleConnsPerHost != 5 {
		t.Errorf("MaxIdleConnsPerHost = %d, want 5", tr.MaxIdleConnsPerHost)
	}
	if tr.MaxIdleConns != 20 {
		t.Errorf("MaxIdleConns = %d, want 20", tr.MaxIdleConns)
	}
	if tr.IdleConnTimeout != 60*time.Second {
		t.Errorf("IdleConnTimeout = %v, want 60s", tr.IdleConnTimeout)
	}
	if tr.TLSHandshakeTimeout != 5*time.Second {
		t.Errorf("TLSHandshakeTimeout = %v, want 5s", tr.TLSHandshakeTimeout)
	}
	if !tr.DisableKeepAlives {
		t.Error("DisableKeepAlives should be true")
	}
	if tr.TLSClientConfig == nil || !tr.TLSClientConfig.InsecureSkipVerify {
		t.Error("TLSConfig should be applied")
	}
}

// TestNewTransport_CustomDNSResolver only checks that construction succeeds
// with a custom resolver.
// NOTE(review): it does not verify the resolver is actually used for dialing.
func TestNewTransport_CustomDNSResolver(t *testing.T) {
	resolver := &net.Resolver{}
	cfg := &TransportConfig{DNSResolver: resolver}
	tr := NewTransport(cfg)
	if tr == nil {
		t.Fatal("should create transport with custom resolver")
	}
}

// ============================================================
// Preconnector tests
// ============================================================

// TestPreconnector_Warmup checks that Warmup issues exactly one HEAD request
// to the target URL.
func TestPreconnector_Warmup(t *testing.T) {
	var hitCount int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hitCount, 1)
		if r.Method != http.MethodHead {
			t.Errorf("expected HEAD, got %s", r.Method)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	tr := NewTransport(nil)
	p := NewPreconnector(server.URL, tr)
	p.Warmup(context.Background())

	// Warmup is asynchronous: wait for the request to land rather than
	// sleeping a fixed amount.
	if !pollUntil(2*time.Second, func() bool { return atomic.LoadInt32(&hitCount) >= 1 }) {
		t.Fatal("timed out waiting for warmup HEAD request")
	}
	// Short grace window so any unexpected extra request is counted too.
	time.Sleep(50 * time.Millisecond)
	if count := atomic.LoadInt32(&hitCount); count != 1 {
		t.Errorf("expected 1 HEAD request, got %d", count)
	}
}

// TestPreconnector_WarmupIdempotent checks that repeated Warmup calls trigger
// only a single request.
func TestPreconnector_WarmupIdempotent(t *testing.T) {
	var hitCount int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hitCount, 1)
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	tr := NewTransport(nil)
	p := NewPreconnector(server.URL, tr)

	// Multiple calls should trigger only one request.
	p.Warmup(context.Background())
	p.Warmup(context.Background())
	p.Warmup(context.Background())

	if !pollUntil(2*time.Second, func() bool { return atomic.LoadInt32(&hitCount) >= 1 }) {
		t.Fatal("timed out waiting for warmup request")
	}
	// Grace window: if Warmup were not idempotent, the duplicate requests
	// would arrive here.
	time.Sleep(50 * time.Millisecond)
	if count := atomic.LoadInt32(&hitCount); count != 1 {
		t.Errorf("expected exactly 1 request (idempotent), got %d", count)
	}
}

// TestPreconnector_FailureSilent checks that a warmup against an unreachable
// address fails quietly instead of panicking.
func TestPreconnector_FailureSilent(t *testing.T) {
	// Connect to an address that does not exist; this must not panic.
	tr := NewTransport(&TransportConfig{DialTimeout: 100 * time.Millisecond})
	p := NewPreconnector("http://192.0.2.1:1", tr) // RFC 5737 TEST-NET
	p.timeout = 200 * time.Millisecond

	p.Warmup(context.Background())
	// Wait past the warmup timeout; a panic in the background goroutine
	// would crash the test binary during this window.
	time.Sleep(500 * time.Millisecond)
}

// TestPreconnector_SharedTransport checks that the warmup connection and the
// subsequent real request share the same Transport connection pool. The real
// request should reuse the idle connection opened by Warmup.
func TestPreconnector_SharedTransport(t *testing.T) {
	var hitCount int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hitCount, 1)
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	tr := NewTransport(nil)

	// Preconnect.
	p := NewPreconnector(server.URL, tr)
	p.Warmup(context.Background())
	if !pollUntil(2*time.Second, func() bool { return atomic.LoadInt32(&hitCount) >= 1 }) {
		t.Fatal("timed out waiting for warmup request")
	}
	// The server counts the hit before the client has read the response and
	// returned the connection to the idle pool; allow a short settle window.
	time.Sleep(100 * time.Millisecond)

	// Real request through the same Transport, traced so we can see whether
	// the pooled connection was reused.
	var reused bool
	trace := &httptrace.ClientTrace{
		// GotConn runs on the goroutine performing the round trip, before
		// client.Do returns, so reading `reused` afterwards is race-free.
		GotConn: func(info httptrace.GotConnInfo) { reused = info.Reused },
	}
	ctx := httptrace.WithClientTrace(context.Background(), trace)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+"/test", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	client := &http.Client{Transport: tr, Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}
	resp.Body.Close()

	if !reused {
		t.Error("expected the request to reuse the warmed-up connection from the shared Transport pool")
	}
}

// ============================================================
// DNSCache tests
// ============================================================

// TestDNSCache_Lookup checks that localhost resolves and that a second
// lookup returns the same number of addresses.
func TestDNSCache_Lookup(t *testing.T) {
	cache := NewDNSCache(5 * time.Minute)

	// Resolve localhost (every system should be able to).
	addrs, err := cache.Lookup(context.Background(), "localhost")
	if err != nil {
		t.Fatalf("lookup localhost failed: %v", err)
	}
	if len(addrs) == 0 {
		t.Error("expected at least one address for localhost")
	}

	// The second lookup should hit the cache; only consistency is checked.
	addrs2, err := cache.Lookup(context.Background(), "localhost")
	if err != nil {
		t.Fatalf("cached lookup failed: %v", err)
	}
	if len(addrs2) != len(addrs) {
		t.Errorf("cached result differs: got %d addrs, want %d", len(addrs2), len(addrs))
	}
}

// TestDNSCache_TTLExpiry checks that a lookup after the TTL has elapsed
// still succeeds.
func TestDNSCache_TTLExpiry(t *testing.T) {
	cache := NewDNSCache(50 * time.Millisecond) // very short TTL

	_, err := cache.Lookup(context.Background(), "localhost")
	if err != nil {
		t.Fatalf("first lookup failed: %v", err)
	}

	// Wait for the TTL to expire.
	time.Sleep(100 * time.Millisecond)

	// The entry should be re-resolved; only the absence of an error is checked.
	_, err = cache.Lookup(context.Background(), "localhost")
	if err != nil {
		t.Fatalf("expired lookup failed: %v", err)
	}
}

// TestDNSCache_DefaultTTL checks that a zero TTL falls back to the default.
func TestDNSCache_DefaultTTL(t *testing.T) {
	cache := NewDNSCache(0)
	// Should fall back to the 5-minute default.
	if cache.ttl != 5*time.Minute {
		t.Errorf("default TTL = %v, want 5m", cache.ttl)
	}
}

// TestDNSCache_Prefetch checks that a fire-and-forget Prefetch eventually
// populates the cache.
func TestDNSCache_Prefetch(t *testing.T) {
	cache := NewDNSCache(5 * time.Minute)

	// Fire-and-forget; must not panic.
	cache.Prefetch(context.Background(), "localhost")

	cached := func() bool {
		cache.mu.RLock()
		defer cache.mu.RUnlock()
		_, ok := cache.entries["localhost"]
		return ok
	}
	// After the prefetch completes, the cache should hold the entry.
	if !pollUntil(2*time.Second, cached) {
		t.Error("prefetch should populate cache")
	}
}

// TestDNSCache_InvalidHost checks that resolving a reserved .invalid name
// fails. It skips when the local resolver answers anyway, as some DNS
// hijacking resolvers do.
func TestDNSCache_InvalidHost(t *testing.T) {
	cache := NewDNSCache(5 * time.Minute)
	_, err := cache.Lookup(context.Background(), "this.host.does.not.exist.invalid")
	if err == nil {
		t.Skip("DNS resolver returned address for invalid host, skipping")
	}
}

// ============================================================
// WithTransport tests
// ============================================================

// TestWithTransport checks that the WithTransport option installs the given
// transport on the client's underlying http.Client.
func TestWithTransport(t *testing.T) {
	tr := NewTransport(nil)
	client := NewClient("key", "http://test", WithTransport(tr))
	if client.httpClient.Transport != tr {
		t.Error("WithTransport should set the transport on client's httpClient")
	}
}