package builtin import ( "context" "database/sql" "encoding/json" "fmt" "strings" "testing" _ "modernc.org/sqlite" // test-only driver; prod build excludes. "git.flytoex.net/yuanwei/flyto-agent/pkg/permission" "git.flytoex.net/yuanwei/flyto-agent/pkg/tools" ) // newCASTestDB opens an in-memory sqlite DB, creates an items table // with an integer version column, and seeds one row (id=1, version=0). // // newCASTestDB 打开内存 sqlite, 建带整数 version 列的 items 表, 种入一行 // (id=1, version=0). func newCASTestDB(t *testing.T) StagingDB { t.Helper() db, err := sql.Open("sqlite", ":memory:") if err != nil { t.Fatalf("open sqlite: %v", err) } t.Cleanup(func() { _ = db.Close() }) if _, err := db.Exec(`CREATE TABLE items ( id INTEGER PRIMARY KEY, name TEXT, qty INTEGER, version INTEGER NOT NULL DEFAULT 0 )`); err != nil { t.Fatalf("create: %v", err) } if _, err := db.Exec(`INSERT INTO items (id, name, qty, version) VALUES (1, 'widget', 10, 0)`); err != nil { t.Fatalf("seed: %v", err) } return StagingDB{DB: db} } func casInput(t *testing.T, raw map[string]any) json.RawMessage { t.Helper() b, err := json.Marshal(raw) if err != nil { t.Fatalf("marshal input: %v", err) } return b } func readVersion(t *testing.T, db StagingDB) int64 { t.Helper() var v int64 if err := db.QueryRow("SELECT version FROM items WHERE id=1").Scan(&v); err != nil { t.Fatalf("read version: %v", err) } return v } func TestSQLCAS_HappyPath(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 0) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 20}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if res.IsError { t.Fatalf("expected success, got error: %s", res.Output) } if !strings.Contains(res.Output, "version 0 -> 1") { t.Errorf("expected Output to describe version 0 -> 1, got %q", res.Output) } data, ok := res.Data.(map[string]any) if !ok { t.Fatalf("expected Data map, got %T", res.Data) } if data["new_version"] != int64(1) { t.Errorf("expected new_version=1, got %v", data["new_version"]) } if data["attempts"] != 1 { t.Errorf("expected attempts=1, got %v", data["attempts"]) } if v := readVersion(t, db); v != 1 { t.Errorf("db version: want 1, got %d", v) } } func TestSQLCAS_ConflictRetrySuccess(t *testing.T) { db := newCASTestDB(t) // Pre-conflict: bump version to 5 without going through CAS. if _, err := db.Exec("UPDATE items SET version=5 WHERE id=1"); err != nil { t.Fatalf("pre-bump: %v", err) } tool := NewSQLCASTool(db, 3) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 42}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if res.IsError { t.Fatalf("expected success after retry, got error: %s", res.Output) } data := res.Data.(map[string]any) if data["attempts"] != 2 { t.Errorf("expected attempts=2 (first conflict + second success), got %v", data["attempts"]) } if data["new_version"] != int64(6) { t.Errorf("expected new_version=6 (5+1), got %v", data["new_version"]) } if v := readVersion(t, db); v != 6 { t.Errorf("db version: want 6, got %d", v) } } func TestSQLCAS_MaxRetriesZeroFailsFast(t *testing.T) { db := newCASTestDB(t) if _, err := db.Exec("UPDATE items SET version=5 WHERE id=1"); err != nil { t.Fatalf("pre-bump: %v", err) } tool := NewSQLCASTool(db, 0) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 42}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Fatalf("expected fail-fast error, got success: %s", res.Output) } if !strings.Contains(res.Output, "after 0 retries") { t.Errorf("expected 'after 0 retries' in Output, got %q", res.Output) } if v := readVersion(t, db); v != 5 { t.Errorf("db version must be unchanged (5), got %d", v) } } func TestSQLCAS_MaxRetriesExhausted(t *testing.T) { db := newCASTestDB(t) // BEFORE UPDATE trigger bumps version, so every CAS attempt sees a // changed version and retries. With maxRetries=2 the tool should // exhaust after 3 attempts. // BEFORE UPDATE trigger aborts outer UPDATE via RAISE(IGNORE) and // bumps version via its own nested UPDATE (sqlite defaults to // recursive_triggers=off, so the nested UPDATE does not re-fire the // trigger). Every CAS attempt therefore sees RowsAffected=0 plus a // freshly-bumped version, forcing the retry loop to exhaust. if _, err := db.Exec(`CREATE TRIGGER stomp BEFORE UPDATE ON items WHEN OLD.id = 1 BEGIN UPDATE items SET version = version + 100 WHERE id = 1; SELECT RAISE(IGNORE); END`); err != nil { t.Fatalf("trigger: %v", err) } tool := NewSQLCASTool(db, 2) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 42}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Fatalf("expected retries-exhausted, got success: %s", res.Output) } if !strings.Contains(res.Output, "after 2 retries") { t.Errorf("expected 'after 2 retries' in Output, got %q", res.Output) } } func TestSQLCAS_RowNotFound(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 2) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 999, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 42}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Fatalf("expected row-not-found, got success: %s", res.Output) } if !strings.Contains(res.Output, "no row where") { t.Errorf("expected 'no row where' in Output, got %q", res.Output) } } func TestSQLCAS_InvalidIdentifier(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 0) base := map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 20}, } cases := []struct { name, field, bad string }{ {"table with semicolon", "table", "items;drop"}, {"pk_col with space", "pk_col", "id col"}, {"version_col with dash", "version_col", "version-col"}, {"update_cols key with quote", "update_cols_key", `qty"`}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { input := map[string]any{} for k, v := range base { input[k] = v } if tc.field == "update_cols_key" { input["update_cols"] = map[string]any{tc.bad: 20} } else { input[tc.field] = tc.bad } res, err := tool.Execute(context.Background(), casInput(t, input), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Fatalf("expected reject for %s=%q, got success: %s", tc.field, tc.bad, res.Output) } if !strings.Contains(res.Output, "identifier") { t.Errorf("expected 'identifier' in Output, got %q", res.Output) } }) } } func TestSQLCAS_EmptyUpdateCols(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 0) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Fatalf("expected reject, got success: %s", res.Output) } if !strings.Contains(res.Output, "empty") { t.Errorf("expected 'empty' in Output, got %q", res.Output) } } func TestSQLCAS_MultiColumnUpdate(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 0) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"name": "gadget", "qty": 99}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if res.IsError { t.Fatalf("expected success, got: %s", res.Output) } var name string var qty, version int64 if err := db.QueryRow("SELECT name, qty, version FROM items WHERE id=1").Scan(&name, &qty, &version); err != nil { t.Fatalf("readback: %v", err) } if name != "gadget" || qty != 99 || version != 1 { t.Errorf("readback mismatch: name=%q qty=%d version=%d", name, qty, version) } } func TestSQLCAS_DataTypes(t *testing.T) { cases := []struct { name string value any }{ {"string", "hello"}, {"int", 42}, {"bool true", true}, {"null", nil}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 0) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"name": tc.value}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if res.IsError { t.Fatalf("expected success for %s, got: %s", tc.name, res.Output) } }) } } func TestSQLCAS_VersionZeroToOne(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 0) _, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 1}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if v := readVersion(t, db); v != 1 { t.Errorf("version after CAS: want 1, got %d", v) } } func TestSQLCAS_UnsupportedPKValType(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 0) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": []any{1, 2, 3}, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 1}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Fatalf("expected reject for array pk_val") } if !strings.Contains(res.Output, "pk_val") { t.Errorf("expected 'pk_val' in Output, got %q", res.Output) } } func TestSQLCAS_UnsupportedUpdateValType(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 0) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "items", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": map[string]any{"nested": 1}}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Fatalf("expected reject for object value") } if !strings.Contains(res.Output, "update_cols") { t.Errorf("expected 'update_cols' in Output, got %q", res.Output) } } func TestSQLCAS_NonIntVersionColumn(t *testing.T) { // Declare version as TEXT to force Scan to return string, triggering // the non-int version reject path. db, err := sql.Open("sqlite", ":memory:") if err != nil { t.Fatalf("open: %v", err) } t.Cleanup(func() { _ = db.Close() }) if _, err := db.Exec(`CREATE TABLE things ( id INTEGER PRIMARY KEY, qty INTEGER, version TEXT )`); err != nil { t.Fatalf("create: %v", err) } if _, err := db.Exec(`INSERT INTO things (id, qty, version) VALUES (1, 10, 'v3')`); err != nil { t.Fatalf("seed: %v", err) } tool := NewSQLCASTool(StagingDB{DB: db}, 2) res, err := tool.Execute(context.Background(), casInput(t, map[string]any{ "table": "things", "pk_col": "id", "pk_val": 1, "version_col": "version", "expected_version": 0, "update_cols": map[string]any{"qty": 42}, }), nil) if err != nil { t.Fatalf("unexpected err: %v", err) } if !res.IsError { t.Fatalf("expected reject for non-int version, got success: %s", res.Output) } if !strings.Contains(res.Output, "non-int type") { t.Errorf("expected 'non-int type' in Output, got %q", res.Output) } } func TestSQLCAS_ConstructorPanicNilDB(t *testing.T) { defer func() { r := recover() if r == nil { t.Fatal("expected panic on nil db") } if !strings.Contains(fmt.Sprint(r), "db.DB must not be nil") { t.Errorf("unexpected panic message: %v", r) } }() _ = NewSQLCASTool(StagingDB{}, 3) } func TestSQLCAS_ConstructorPanicNegativeRetries(t *testing.T) { defer func() { r := recover() if r == nil { t.Fatal("expected panic on negative maxRetries") } if !strings.Contains(fmt.Sprint(r), ">= 0") { t.Errorf("unexpected panic message: %v", r) } }() db, _ := sql.Open("sqlite", ":memory:") _ = NewSQLCASTool(StagingDB{DB: db}, -1) } func TestSQLCAS_ConstructorPanicExcessiveRetries(t *testing.T) { defer func() { r := recover() if r == nil { t.Fatal("expected panic on excessive maxRetries") } if !strings.Contains(fmt.Sprint(r), "<= 10") { t.Errorf("unexpected panic message: %v", r) } }() db, _ := sql.Open("sqlite", ":memory:") _ = NewSQLCASTool(StagingDB{DB: db}, 11) } func TestSQLCAS_ToolInterface(t *testing.T) { db := newCASTestDB(t) tool := NewSQLCASTool(db, 3) if tool.Name() != "SQLCAS" { t.Errorf("expected Name=SQLCAS, got %q", tool.Name()) } desc := tool.Description(context.Background()) for _, frag := range []string{"optimistic-lock", "STAGING", "3 retries", "integer version"} { if !strings.Contains(desc, frag) { t.Errorf("Description missing %q: %s", frag, desc) } } var schema map[string]any if err := json.Unmarshal(tool.InputSchema(), &schema); err != nil { t.Fatalf("InputSchema JSON: %v", err) } required, _ := schema["required"].([]any) if len(required) != 6 { t.Errorf("expected 6 required fields, got %d", len(required)) } md := tool.Metadata() if md.ReadOnly { t.Error("expected ReadOnly=false") } if !md.ConcurrencySafe { t.Error("expected ConcurrencySafe=true") } if md.PermissionClass != permission.PermClassGeneric { t.Errorf("expected PermissionClass=generic, got %q", md.PermissionClass) } if md.AuditOperation != "edit" { t.Errorf("expected AuditOperation=edit, got %q", md.AuditOperation) } var _ tools.Tool = tool var _ tools.MetadataProvider = tool }