diff --git a/server/etcdserver/txn/txn.go b/server/etcdserver/txn/txn.go index 220b7d31580..6cfc8717f7d 100644 --- a/server/etcdserver/txn/txn.go +++ b/server/etcdserver/txn/txn.go @@ -43,57 +43,52 @@ func Put(ctx context.Context, lg *zap.Logger, lessor lease.Lessor, kv mvcc.KV, p ) ctx = context.WithValue(ctx, traceutil.TraceKey{}, trace) } - leaseID := lease.LeaseID(p.Lease) - if leaseID != lease.NoLease { - if l := lessor.Lookup(leaseID); l == nil { - return nil, nil, lease.ErrLeaseNotFound - } - } txnWrite := kv.Write(trace) defer txnWrite.End() - resp, err = put(ctx, txnWrite, p) + resp, err = put(ctx, txnWrite, lessor, p) return resp, trace, err } -func put(ctx context.Context, txnWrite mvcc.TxnWrite, p *pb.PutRequest) (resp *pb.PutResponse, err error) { +func put(ctx context.Context, txnWrite mvcc.TxnWrite, lessor lease.Lessor, req *pb.PutRequest) (resp *pb.PutResponse, err error) { trace := traceutil.Get(ctx) resp = &pb.PutResponse{} resp.Header = &pb.ResponseHeader{} - val, leaseID := p.Value, lease.LeaseID(p.Lease) - - var rr *mvcc.RangeResult - if p.IgnoreValue || p.IgnoreLease || p.PrevKv { - trace.StepWithFunction(func() { - rr, err = txnWrite.Range(context.TODO(), p.Key, nil, mvcc.RangeOptions{}) - }, "get previous kv pair") - - if err != nil { - return nil, err - } + prevKV, err := prevKVIfNeeded(ctx, txnWrite, req) + if err != nil { + return nil, err } - if p.IgnoreValue || p.IgnoreLease { - if rr == nil || len(rr.KVs) == 0 { - // ignore_{lease,value} flag expects previous key-value pair - return nil, errors.ErrKeyNotFound - } + val, leaseID := req.Value, lease.LeaseID(req.Lease) + if req.IgnoreValue { + val = prevKV.Value } - if p.IgnoreValue { - val = rr.KVs[0].Value + if req.IgnoreLease { + leaseID = lease.LeaseID(prevKV.Lease) } - if p.IgnoreLease { - leaseID = lease.LeaseID(rr.KVs[0].Lease) + err = checkPut(lessor, req, prevKV) + if err != nil { + return nil, err } - if p.PrevKv { - if rr != nil && len(rr.KVs) != 0 { - resp.PrevKv = &rr.KVs[0] - } + if req.PrevKv { + resp.PrevKv = prevKV } - - resp.Header.Revision = txnWrite.Put(p.Key, val, leaseID) + resp.Header.Revision = txnWrite.Put(req.Key, val, leaseID) trace.AddField(traceutil.Field{Key: "response_revision", Value: resp.Header.Revision}) return resp, nil } +func prevKVIfNeeded(ctx context.Context, rv mvcc.ReadView, req *pb.PutRequest) (*mvccpb.KeyValue, error) { + if req.IgnoreValue || req.IgnoreLease || req.PrevKv { + resp, err := rv.Range(ctx, req.Key, nil, mvcc.RangeOptions{}) + if err != nil { + return nil, err + } + if resp != nil && len(resp.KVs) != 0 { + return &resp.KVs[0], nil + } + } + return nil, nil +} + func DeleteRange(ctx context.Context, lg *zap.Logger, kv mvcc.KV, dr *pb.DeleteRangeRequest) (resp *pb.DeleteRangeResponse, trace *traceutil.Trace, err error) { trace = traceutil.Get(ctx) // create delete tracing if the trace in context is empty @@ -292,7 +287,7 @@ func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWit } else { txnWrite = mvcc.NewReadOnlyTxnWrite(txnRead) } - txnResp, err := txn(ctx, lg, txnWrite, rt, isWrite, txnPath) + txnResp, err := txn(ctx, lg, txnWrite, lessor, rt, isWrite, txnPath) txnWrite.End() trace.AddField( @@ -302,9 +297,9 @@ func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWit return txnResp, trace, err } -func txn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt *pb.TxnRequest, isWrite bool, txnPath []bool) (*pb.TxnResponse, error) { +func txn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, lessor lease.Lessor, rt *pb.TxnRequest, isWrite bool, txnPath []bool) (*pb.TxnResponse, error) { txnResp, _ := newTxnResp(rt, txnPath) - _, err := executeTxn(ctx, lg, txnWrite, rt, txnPath, txnResp) + _, err := executeTxn(ctx, lg, txnWrite, lessor, rt, txnPath, txnResp) if err != nil { if isWrite { // CAUTION: When a txn performing write operations starts, we always expect it to be successful. @@ -356,7 +351,7 @@ func newTxnResp(rt *pb.TxnRequest, txnPath []bool) (txnResp *pb.TxnResponse, txn return txnResp, txnCount } -func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int, err error) { +func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, lessor lease.Lessor, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int, err error) { trace := traceutil.Get(ctx) reqs := rt.Success if !txnPath[0] { @@ -382,7 +377,7 @@ func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt traceutil.Field{Key: "req_type", Value: "put"}, traceutil.Field{Key: "key", Value: string(tv.RequestPut.Key)}, traceutil.Field{Key: "req_size", Value: tv.RequestPut.Size()}) - resp, err := put(ctx, txnWrite, tv.RequestPut) + resp, err := put(ctx, txnWrite, lessor, tv.RequestPut) if err != nil { return 0, fmt.Errorf("applyTxn: failed Put: %w", err) } @@ -396,7 +391,7 @@ func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt respi.(*pb.ResponseOp_ResponseDeleteRange).ResponseDeleteRange = resp case *pb.RequestOp_RequestTxn: resp := respi.(*pb.ResponseOp_ResponseTxn).ResponseTxn - applyTxns, err := executeTxn(ctx, lg, txnWrite, tv.RequestTxn, txnPath[1:], resp) + applyTxns, err := executeTxn(ctx, lg, txnWrite, lessor, tv.RequestTxn, txnPath[1:], resp) if err != nil { // don't wrap the error. It's a recursive call and err should be already wrapped return 0, err @@ -410,19 +405,16 @@ func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt return txns, nil } -func checkPut(rv mvcc.ReadView, lessor lease.Lessor, req *pb.PutRequest) error { - if req.IgnoreValue || req.IgnoreLease { - // expects previous key-value, error if not exist - rr, err := rv.Range(context.TODO(), req.Key, nil, mvcc.RangeOptions{}) - if err != nil { - return err - } - if rr == nil || len(rr.KVs) == 0 { +func checkPut(lessor lease.Lessor, req *pb.PutRequest, prevKV *mvccpb.KeyValue) error { + if req.IgnoreValue || req.IgnoreLease || req.PrevKv { + if (req.IgnoreValue || req.IgnoreLease) && prevKV == nil { + // ignore_{lease,value} flag expects previous key-value pair return errors.ErrKeyNotFound } } - if lease.LeaseID(req.Lease) != lease.NoLease { - if l := lessor.Lookup(lease.LeaseID(req.Lease)); l == nil { + leaseID := lease.LeaseID(req.Lease) + if !req.IgnoreLease && leaseID != lease.NoLease { + if l := lessor.Lookup(leaseID); l == nil { return lease.ErrLeaseNotFound } } @@ -454,7 +446,11 @@ func checkTxn(rv mvcc.ReadView, rt *pb.TxnRequest, lessor lease.Lessor, txnPath case *pb.RequestOp_RequestRange: err = checkRange(rv, tv.RequestRange) case *pb.RequestOp_RequestPut: - err = checkPut(rv, lessor, tv.RequestPut) + prevKV, err := prevKVIfNeeded(context.TODO(), rv, tv.RequestPut) + if err != nil { + return 0, err + } + err = checkPut(lessor, tv.RequestPut, prevKV) case *pb.RequestOp_RequestDeleteRange: case *pb.RequestOp_RequestTxn: txns, err = checkTxn(rv, tv.RequestTxn, lessor, txnPath[1:]) diff --git a/server/etcdserver/txn/txn_test.go b/server/etcdserver/txn/txn_test.go index 850c8a95b9b..01c511d4332 100644 --- a/server/etcdserver/txn/txn_test.go +++ b/server/etcdserver/txn/txn_test.go @@ -98,6 +98,16 @@ var putTestCases = []testCase{ }, }, }, + { + name: "Put withPrevKV should succeed", + op: &pb.RequestOp{ + Request: &pb.RequestOp_RequestPut{ + RequestPut: &pb.PutRequest{ + PrevKv: true, + }, + }, + }, + }, { name: "Put with non-existing lease should fail", op: &pb.RequestOp{