diff --git a/examples/stream_aggregate_example.go b/examples/stream_aggregate_example.go index fd5894d..f91d90f 100644 --- a/examples/stream_aggregate_example.go +++ b/examples/stream_aggregate_example.go @@ -36,9 +36,11 @@ func main() { docs = append(docs, doc) } - if err := store.InsertMany(collection, docs); err != nil { - log.Printf("Error inserting documents: %v", err) - return + for _, doc := range docs { + if err := store.Insert(collection, doc); err != nil { + log.Printf("Error inserting document: %v", err) + return + } } // 定义聚合管道 diff --git a/internal/engine/crud_handler.go b/internal/engine/crud_handler.go index 6141c1a..24b16e6 100644 --- a/internal/engine/crud_handler.go +++ b/internal/engine/crud_handler.go @@ -54,8 +54,8 @@ func (h *CRUDHandler) Insert(ctx context.Context, collection string, docs []map[ } // Update 更新文档 -func (h *CRUDHandler) Update(ctx context.Context, collection string, filter types.Filter, update types.Update) (*types.UpdateResult, error) { - matched, modified, _, err := h.store.Update(collection, filter, update, false, nil) +func (h *CRUDHandler) Update(ctx context.Context, collection string, filter types.Filter, update types.Update, upsert bool) (*types.UpdateResult, error) { + matched, modified, _, err := h.store.Update(collection, filter, update, upsert, nil) if err != nil { return nil, err } diff --git a/internal/engine/memory_store.go b/internal/engine/memory_store.go index 847d3db..7907dca 100644 --- a/internal/engine/memory_store.go +++ b/internal/engine/memory_store.go @@ -388,6 +388,12 @@ func (ms *MemoryStore) Update(collection string, filter types.Filter, update typ if matched == 0 && upsert { // 创建新文档 newID := generateID() + // 优先使用 filter 中的 _id + if idVal, ok := filter["_id"]; ok { + if idStr, ok := idVal.(string); ok && idStr != "" { + newID = idStr + } + } newDoc := make(map[string]interface{}) // 应用更新($setOnInsert 会生效) diff --git a/internal/protocol/http/server.go b/internal/protocol/http/server.go index 2447d63..fa7f4e8 100644 --- a/internal/protocol/http/server.go +++ b/internal/protocol/http/server.go @@ -271,7 +271,7 @@ func (h *RequestHandler) HandleUpdate(w http.ResponseWriter, r *http.Request, db upserted := make([]types.UpsertID, 0) for _, op := range req.Updates { - result, err := h.crud.Update(context.Background(), fullCollection, op.Q, op.U) + result, err := h.crud.Update(context.Background(), fullCollection, op.Q, op.U, op.Upsert) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return