diff --git a/server/etcdserver/txn/txn.go b/server/etcdserver/txn/txn.go index bd2c1eb5fe00..571e2e929d6b 100644 --- a/server/etcdserver/txn/txn.go +++ b/server/etcdserver/txn/txn.go @@ -42,53 +42,24 @@ func Put(ctx context.Context, lg *zap.Logger, lessor lease.Lessor, kv mvcc.KV, p ) ctx = context.WithValue(ctx, traceutil.TraceKey, trace) } - leaseID := lease.LeaseID(p.Lease) - if leaseID != lease.NoLease { - if l := lessor.Lookup(leaseID); l == nil { - return nil, nil, lease.ErrLeaseNotFound - } - } txnWrite := kv.Write(trace) defer txnWrite.End() - resp, err = put(ctx, txnWrite, p) + resp, err = put(ctx, txnWrite, lessor, p) return resp, trace, err } -func put(ctx context.Context, txnWrite mvcc.TxnWrite, p *pb.PutRequest) (resp *pb.PutResponse, err error) { +func put(ctx context.Context, txnWrite mvcc.TxnWrite, lessor lease.Lessor, req *pb.PutRequest) (resp *pb.PutResponse, err error) { trace := traceutil.Get(ctx) resp = &pb.PutResponse{} resp.Header = &pb.ResponseHeader{} - val, leaseID := p.Value, lease.LeaseID(p.Lease) - - var rr *mvcc.RangeResult - if p.IgnoreValue || p.IgnoreLease || p.PrevKv { - trace.StepWithFunction(func() { - rr, err = txnWrite.Range(context.TODO(), p.Key, nil, mvcc.RangeOptions{}) - }, "get previous kv pair") - - if err != nil { - return nil, err - } - } - if p.IgnoreValue || p.IgnoreLease { - if rr == nil || len(rr.KVs) == 0 { - // ignore_{lease,value} flag expects previous key-value pair - return nil, errors.ErrKeyNotFound - } - } - if p.IgnoreValue { - val = rr.KVs[0].Value - } - if p.IgnoreLease { - leaseID = lease.LeaseID(rr.KVs[0].Lease) + val, prevKV, leaseID, err := checkPut(txnWrite, lessor, req) + if err != nil { + return nil, err } - if p.PrevKv { - if rr != nil && len(rr.KVs) != 0 { - resp.PrevKv = &rr.KVs[0] - } + if req.PrevKv { + resp.PrevKv = prevKV } - - resp.Header.Revision = txnWrite.Put(p.Key, val, leaseID) + resp.Header.Revision = txnWrite.Put(req.Key, val, leaseID) trace.AddField(traceutil.Field{Key: "response_revision", Value: resp.Header.Revision}) return resp, nil } @@ -287,7 +258,7 @@ func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWit } else { txnWrite = mvcc.NewReadOnlyTxnWrite(txnRead) } - txnResp, err := txn(ctx, lg, txnWrite, rt, isWrite, txnPath) + txnResp, err := txn(ctx, lg, txnWrite, lessor, rt, isWrite, txnPath) txnWrite.End() trace.AddField( @@ -297,9 +268,9 @@ func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWit return txnResp, trace, err } -func txn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt *pb.TxnRequest, isWrite bool, txnPath []bool) (*pb.TxnResponse, error) { +func txn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, lessor lease.Lessor, rt *pb.TxnRequest, isWrite bool, txnPath []bool) (*pb.TxnResponse, error) { txnResp, _ := newTxnResp(rt, txnPath) - _, err := executeTxn(ctx, lg, txnWrite, rt, txnPath, txnResp) + _, err := executeTxn(ctx, lg, txnWrite, lessor, rt, txnPath, txnResp) if err != nil { if isWrite { // end txn to release locks before panic @@ -350,7 +321,7 @@ func newTxnResp(rt *pb.TxnRequest, txnPath []bool) (txnResp *pb.TxnResponse, txn return txnResp, txnCount } -func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int, err error) { +func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, lessor lease.Lessor, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int, err error) { trace := traceutil.Get(ctx) reqs := rt.Success if !txnPath[0] { @@ -376,7 +347,7 @@ func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt traceutil.Field{Key: "req_type", Value: "put"}, traceutil.Field{Key: "key", Value: string(tv.RequestPut.Key)}, traceutil.Field{Key: "req_size", Value: tv.RequestPut.Size()}) - resp, err := put(ctx, txnWrite, tv.RequestPut) + resp, err := put(ctx, txnWrite, lessor, tv.RequestPut) if err != nil { return 0, fmt.Errorf("applyTxn: failed Put: %w", err) } @@ -390,7 +361,7 @@ func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt respi.(*pb.ResponseOp_ResponseDeleteRange).ResponseDeleteRange = resp case *pb.RequestOp_RequestTxn: resp := respi.(*pb.ResponseOp_ResponseTxn).ResponseTxn - applyTxns, err := executeTxn(ctx, lg, txnWrite, tv.RequestTxn, txnPath[1:], resp) + applyTxns, err := executeTxn(ctx, lg, txnWrite, lessor, tv.RequestTxn, txnPath[1:], resp) if err != nil { // don't wrap the error. It's a recursive call and err should be already wrapped return 0, err @@ -404,23 +375,33 @@ func executeTxn(ctx context.Context, lg *zap.Logger, txnWrite mvcc.TxnWrite, rt return txns, nil } -func checkPut(rv mvcc.ReadView, lessor lease.Lessor, req *pb.PutRequest) error { - if req.IgnoreValue || req.IgnoreLease { - // expects previous key-value, error if not exist - rr, err := rv.Range(context.TODO(), req.Key, nil, mvcc.RangeOptions{}) +func checkPut(rv mvcc.ReadView, lessor lease.Lessor, req *pb.PutRequest) (val []byte, prevKV *mvccpb.KeyValue, leaseID lease.LeaseID, err error) { + val, leaseID = req.Value, lease.LeaseID(req.Lease) + + if req.IgnoreValue || req.IgnoreLease || req.PrevKv { + resp, err := rv.Range(context.TODO(), req.Key, nil, mvcc.RangeOptions{}) if err != nil { - return err + return nil, nil, 0, err + } + if resp != nil && len(resp.KVs) != 0 { + prevKV = &resp.KVs[0] } - if rr == nil || len(rr.KVs) == 0 { - return errors.ErrKeyNotFound + if (req.IgnoreValue || req.IgnoreLease) && prevKV == nil { + // ignore_{lease,value} flag expects previous key-value pair + return nil, nil, 0, errors.ErrKeyNotFound } } - if lease.LeaseID(req.Lease) != lease.NoLease { - if l := lessor.Lookup(lease.LeaseID(req.Lease)); l == nil { - return lease.ErrLeaseNotFound + if req.IgnoreValue { + val = prevKV.Value + } + if req.IgnoreLease { + leaseID = lease.LeaseID(prevKV.Lease) + } else if leaseID != lease.NoLease { + if l := lessor.Lookup(leaseID); l == nil { + return nil, nil, 0, lease.ErrLeaseNotFound } } - return nil + return val, prevKV, leaseID, nil } func checkRange(rv mvcc.ReadView, req *pb.RangeRequest) error { @@ -448,7 +429,7 @@ func checkTxn(rv mvcc.ReadView, rt *pb.TxnRequest, lessor lease.Lessor, txnPath case *pb.RequestOp_RequestRange: err = checkRange(rv, tv.RequestRange) case *pb.RequestOp_RequestPut: - err = checkPut(rv, lessor, tv.RequestPut) + _, _, _, err = checkPut(rv, lessor, tv.RequestPut) case *pb.RequestOp_RequestDeleteRange: case *pb.RequestOp_RequestTxn: txns, err = checkTxn(rv, tv.RequestTxn, lessor, txnPath[1:]) diff --git a/server/etcdserver/txn/txn_test.go b/server/etcdserver/txn/txn_test.go index 2e0ad45534b6..79251cc3d52f 100644 --- a/server/etcdserver/txn/txn_test.go +++ b/server/etcdserver/txn/txn_test.go @@ -97,6 +97,16 @@ var putTestCases = []testCase{ }, }, }, + { + name: "Put withPrevKV should succeed", + op: &pb.RequestOp{ + Request: &pb.RequestOp_RequestPut{ + RequestPut: &pb.PutRequest{ + PrevKv: true, + }, + }, + }, + }, { name: "Put with non-existing lease should fail", op: &pb.RequestOp{