Skip to content

Commit 1c1f0f8

Browse files
committed
Improved the correctness of TCP (including both Reno and CUBIC variants) by refactoring the current congestion control codebase.
- Rebuilt the congestion-control base into LossBasedCongestionControl, adding shared RFC5681 slow-start/fast-recovery handling, cwnd-in-segments helpers, and loss/event hooks that Reno/CUBIC (and future variants) can override without duplicating bookkeeping (ns/flow/cc.py). TCPReno now simply supplies its congestion-avoidance increment so the standard behavior stays centralized. - Re-implemented CUBIC atop the new base with configurable beta/C, proper epoch resets, fast-convergence bookkeeping, TCP-friendly fallback, and persistent d_min tracking so its window evolution now matches RFC8312 (§4) semantics (ns/flow/cubic.py). - Added regression coverage for Reno slow start and fast recovery as well as CUBIC’s loss/timeout response and cubic growth, plus a tests/conftest.py shim so uv run pytest can import the in-tree ns package without needing an install step (tests/flow/test_tcp_congestion.py, tests/conftest.py).
1 parent 9a1232b commit 1c1f0f8

File tree

5 files changed

+639
-100
lines changed

5 files changed

+639
-100
lines changed

ns/flow/cc.py

Lines changed: 115 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,36 @@
11
"""
2-
The base class for congestion control algorithms, designed to supply the TCPPacketGenerator class
3-
with congestion control decisions.
2+
Shared congestion-control infrastructure for loss-based TCP variants.
3+
4+
The Simulator models cwnd in *bytes* so the helpers provided here enforce RFC 5681/8312
5+
requirements while letting individual algorithms focus on their window update rules.
46
"""
7+
from __future__ import annotations
8+
59
from abc import abstractmethod
10+
from enum import Enum, auto
11+
from typing import Final
12+
13+
14+
class LossEvent(Enum):
15+
"""Enumerates the two canonical loss signals in TCP."""
16+
17+
FAST_LOSS = auto()
18+
TIMEOUT = auto()
619

720

821
class CongestionControl:
922
"""
10-
The base class for congestion control algorithms, designed to supply the TCPPacketGenerator
11-
class with congestion control decisions.
23+
Base class for congestion control algorithms, designed to supply TCPPacketGenerator
24+
with congestion-control decisions.
1225
1326
Parameters
1427
----------
1528
mss: int
16-
the maximum segment size
29+
Maximum segment size in bytes.
1730
cwnd: int
18-
the size of the congestion window.
31+
Congestion window in bytes.
1932
ssthresh: int
20-
the slow start threshold.
33+
Slow-start threshold in bytes.
2134
debug: bool
2235
If True, prints more verbose debug information.
2336
"""
@@ -47,38 +60,113 @@ def ack_received(self, rtt: float = 0, current_time: float = 0):
4760

4861
def timer_expired(self, packet=None):
4962
"""Actions to be taken when a timer expired."""
50-
self.ssthresh = max(2 * self.mss, self.cwnd / 2)
51-
# setting the congestion window to 1 segment
52-
self.cwnd = self.mss
63+
raise NotImplementedError("timer_expired must be implemented by subclasses.")
5364

5465
def dupack_over(self):
5566
"""Actions to be taken when a new ack is received after previous dupacks."""
56-
# RFC 2001 and TCP Reno
57-
self.cwnd = self.ssthresh
67+
raise NotImplementedError("dupack_over must be implemented by subclasses.")
5868

5969
def consecutive_dupacks_received(self, packet=None):
6070
"""Actions to be taken when three consecutive dupacks are received."""
61-
# fast retransmit in RFC 2001 and TCP Reno
62-
self.ssthresh = max(2 * self.mss, self.cwnd / 2)
63-
self.cwnd = self.ssthresh + 3 * self.mss
71+
raise NotImplementedError(
72+
"consecutive_dupacks_received must be implemented by subclasses."
73+
)
6474

6575
def more_dupacks_received(self, packet=None):
6676
"""Actions to be taken when more than three consecutive dupacks are received."""
67-
# fast retransmit in RFC 2001 and TCP Reno
68-
self.cwnd += self.mss
77+
raise NotImplementedError(
78+
"more_dupacks_received must be implemented by subclasses."
79+
)
6980

70-
def set_before_control(self, current_time, packet_in_flight=0):
71-
pass
81+
def cwnd_in_segments(self) -> float:
82+
"""Return the congestion window expressed in number of MSS-sized segments."""
83+
return self.cwnd / self.mss
7284

85+
def min_ssthresh(self) -> float:
86+
"""The RFC 5681-compliant minimum slow-start threshold (2 MSS)."""
87+
return 2 * self.mss
7388

74-
class TCPReno(CongestionControl):
75-
"""TCP Reno, defined in RFC 2001."""
89+
def set_before_control(self, current_time, packet_in_flight: int = 0):
90+
"""Optional hook for controllers that need per-send context (used by BBR)."""
91+
_ = (current_time, packet_in_flight)
92+
93+
94+
class LossBasedCongestionControl(CongestionControl):
95+
"""
96+
Implements the shared bookkeeping for Reno-like algorithms that respond to loss.
97+
98+
The derived classes must provide the congestion-avoidance rule via
99+
:meth:`_congestion_avoidance_ack` and may override the loss hooks if additional
100+
state (e.g., CUBIC's epoch) needs to be updated.
101+
"""
102+
103+
beta: Final[float] = 0.5
104+
beta_timeout: Final[float] = 0.5
76105

77106
def ack_received(self, rtt: float = 0, current_time: float = 0):
78-
"""Actions to be taken when a new ack has been received."""
79-
if self.cwnd <= self.ssthresh:
80-
# slow start
81-
self.cwnd += self.mss
107+
"""RFC 5681 slow start followed by an algorithm-specific avoidance rule."""
108+
if self.cwnd < self.ssthresh:
109+
self._slow_start_ack()
82110
else:
83-
# congestion avoidance
84-
self.cwnd += self.mss * self.mss / self.cwnd
111+
self._congestion_avoidance_ack(rtt, current_time)
112+
113+
def timer_expired(self, packet=None):
114+
"""RFC 5681 timeout handling."""
115+
prev_cwnd = self.cwnd
116+
self.ssthresh = self._ssthresh_after_loss(prev_cwnd, LossEvent.TIMEOUT)
117+
self.cwnd = self.mss # reset to one MSS per RFC 5681 §3.1
118+
self._after_timeout(prev_cwnd, packet)
119+
120+
def dupack_over(self):
121+
"""Exit fast recovery once the lost data is cumulatively acknowledged."""
122+
self.cwnd = self.ssthresh
123+
self._after_fast_recovery_exit()
124+
125+
def consecutive_dupacks_received(self, packet=None):
126+
"""Standard fast retransmit / fast recovery entry."""
127+
prev_cwnd = self.cwnd
128+
self.ssthresh = self._ssthresh_after_loss(prev_cwnd, LossEvent.FAST_LOSS)
129+
# Per RFC 5681 §3.2, inflate the window by 3 segments to keep the ACK clock.
130+
self.cwnd = self.ssthresh + 3 * self.mss
131+
self._after_fast_loss(prev_cwnd, packet)
132+
133+
def more_dupacks_received(self, packet=None):
134+
"""Additional dupacks add one MSS so we clock out a replacement segment."""
135+
self.cwnd += self.mss
136+
self._during_fast_recovery(packet)
137+
138+
def _slow_start_ack(self):
139+
self.cwnd += self.mss
140+
141+
@abstractmethod
142+
def _congestion_avoidance_ack(self, rtt: float, current_time: float):
143+
"""Algorithm-specific congestion avoidance (one cwnd increase per RTT)."""
144+
145+
def _ssthresh_after_loss(self, prev_cwnd: float, event: LossEvent) -> float:
146+
factor = self.beta_timeout if event == LossEvent.TIMEOUT else self.beta
147+
target = prev_cwnd * (1 - factor)
148+
return max(self.min_ssthresh(), target)
149+
150+
def _after_fast_loss(self, prev_cwnd: float, packet=None):
151+
"""Hook for algorithms that maintain extra state on fast loss."""
152+
_ = (prev_cwnd, packet)
153+
154+
def _during_fast_recovery(self, packet=None):
155+
"""Hook invoked for each extra dupack while in fast recovery."""
156+
_ = packet
157+
158+
def _after_fast_recovery_exit(self):
159+
"""Hook invoked when fast recovery completes."""
160+
161+
def _after_timeout(self, prev_cwnd: float, packet=None):
162+
"""Hook invoked after the timeout logic resets cwnd."""
163+
_ = (prev_cwnd, packet)
164+
165+
166+
class TCPReno(LossBasedCongestionControl):
167+
"""TCP Reno as defined in RFC 5681."""
168+
169+
def _congestion_avoidance_ack(self, rtt: float = 0, current_time: float = 0):
170+
"""Additively increase cwnd by roughly one MSS per RTT."""
171+
del rtt, current_time
172+
self.cwnd += (self.mss * self.mss) / self.cwnd

ns/flow/cubic.py

Lines changed: 96 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
11
"""
2-
The TCP CUBIC congestion control algorithm, used in the Linux kernel since 2.6.19.
2+
TCP CUBIC congestion control (RFC 8312) as used in Linux since v2.6.19.
33
4-
Reference:
5-
6-
Sangtae Ha; Injong Rhee; Lisong Xu. "CUBIC: A New TCP-Friendly High-Speed TCP Variant,"
7-
ACM SIGOPS Operating Systems Review. 42 (5): 64–74, July 2008.
4+
Reference: Sangtae Ha, Injong Rhee, Lisong Xu. "CUBIC: A New TCP-Friendly
5+
High-Speed TCP Variant," ACM SIGOPS OSR, 42(5):64–74, 2008.
86
"""
9-
from ns.flow.cc import CongestionControl
7+
from __future__ import annotations
8+
9+
from ns.flow.cc import LossBasedCongestionControl
1010

1111

12-
class TCPCubic(CongestionControl):
12+
class TCPCubic(LossBasedCongestionControl):
1313
"""
14-
The TCP CUBIC congestion control algorithm, used in the Linux kernel since 2.6.19.
14+
TCP CUBIC congestion control with RFC 8312 compliant window update rules.
1515
1616
Parameters
1717
----------
1818
mss: int
19-
the maximum segment size
19+
The maximum segment size (bytes).
2020
cwnd: int
21-
the size of the congestion window.
21+
The initial congestion window (bytes).
2222
ssthresh: int
23-
the slow start threshold.
23+
Initial slow-start threshold (bytes).
24+
beta: float
25+
Multiplicative decrease factor (default 0.2 per RFC 8312 §4.6).
26+
cubic_constant: float
27+
C in the cubic function W(t) = C(t-K)^3 + W_max (default 0.4).
2428
debug: bool
2529
If True, prints more verbose debug information.
2630
"""
@@ -30,88 +34,107 @@ def __init__(
3034
mss: int = 512,
3135
cwnd: int = 512,
3236
ssthresh: int = 65535,
37+
beta: float = 0.2,
38+
cubic_constant: float = 0.4,
3339
debug: bool = False,
3440
):
3541
super().__init__(mss, cwnd, ssthresh, debug)
36-
self.W_last_max: float = 0
37-
self.epoch_start = 0
38-
self.origin_point = 0
39-
self.d_min: float = 0
40-
self.W_tcp = 0
41-
self.K = 0
42-
self.ack_cnt = 0
42+
self.beta = beta
43+
self.beta_timeout = beta # RFC 8312 uses the same decrease on RTO.
44+
self.cubic_c = cubic_constant
45+
46+
# Internal state is tracked in packets (segments) to mirror RFC notation.
47+
self.W_last_max: float = 0.0
48+
self.epoch_start = 0.0
49+
self.origin_point = 0.0
50+
self.d_min: float = 0.0
51+
self.W_tcp = self.cwnd_in_segments()
52+
self.K = 0.0
53+
self.ack_cnt = 0.0
4354
self.tcp_friendliness = True
4455
self.fast_convergence = True
45-
self.beta = 0.2
46-
self.C = 0.4
47-
self.cwnd_cnt = 0
48-
self.cnt = 0
56+
self.cwnd_cnt = 0.0
57+
self.cnt = float("inf")
4958

5059
def __repr__(self):
5160
return f"cwnd: {self.cwnd}, ssthresh: {self.ssthresh}"
5261

53-
def cubic_reset(self):
54-
"""Resetting the states in CUBIC."""
55-
self.W_last_max = 0
56-
self.epoch_start = 0
57-
self.origin_point = 0
58-
self.d_min = 0
59-
self.W_tcp = 0
60-
self.K = 0
61-
self.ack_cnt = 0
62+
def ack_received(self, rtt: float = 0, current_time: float = 0):
63+
"""Record the minimum RTT and defer to LossBased's slow start logic."""
64+
if rtt > 0:
65+
self.d_min = rtt if self.d_min == 0 else min(self.d_min, rtt)
66+
super().ack_received(rtt, current_time)
67+
68+
def _congestion_avoidance_ack(self, rtt: float, current_time: float):
69+
# Without an RTT sample we fall back to Reno-style additive increase.
70+
if self.d_min <= 0:
71+
self.cwnd += (self.mss * self.mss) / self.cwnd
72+
return
6273

63-
def cubic_update(self, current_time):
64-
"""Updating CUBIC parameters upon the arrival of a new ack."""
74+
self.cubic_update(current_time)
75+
ack_threshold = max(1.0, self.cnt)
76+
self.cwnd_cnt += 1
77+
if self.cwnd_cnt >= ack_threshold:
78+
self.cwnd += self.mss
79+
self.cwnd_cnt = 0
80+
81+
def cubic_update(self, current_time: float):
82+
"""Update the cubic window target (RFC 8312 §4.1)."""
83+
cwnd_packets = self.cwnd_in_segments()
6584
self.ack_cnt += 1
85+
6686
if self.epoch_start <= 0:
6787
self.epoch_start = current_time
68-
if self.cwnd < self.W_last_max:
69-
self.K = ((self.W_last_max - self.cwnd) / self.C) ** (1.0 / 3)
70-
else:
71-
self.K = 0
72-
self.origin_point = self.cwnd
7388
self.ack_cnt = 1
74-
self.W_tcp = self.cwnd
89+
if cwnd_packets < self.W_last_max:
90+
self.K = ((self.W_last_max - cwnd_packets) / self.cubic_c) ** (1.0 / 3.0)
91+
self.origin_point = self.W_last_max
92+
else:
93+
self.K = 0.0
94+
self.W_last_max = cwnd_packets
95+
self.origin_point = cwnd_packets
96+
self.W_tcp = cwnd_packets
97+
7598
t = current_time + self.d_min - self.epoch_start
76-
target = self.origin_point + self.C * (t - self.K) ** 3
77-
if target > self.cwnd:
78-
self.cnt = self.cwnd / (target - self.cwnd)
99+
target = self.origin_point + self.cubic_c * (t - self.K) ** 3
100+
if target > cwnd_packets:
101+
self.cnt = cwnd_packets / max(target - cwnd_packets, 1e-6)
79102
else:
80-
self.cnt = 100 * self.cwnd
103+
self.cnt = 100.0 * max(1.0, cwnd_packets)
104+
81105
if self.tcp_friendliness:
82-
self.cubic_tcp_friendliness()
106+
self.cubic_tcp_friendliness(cwnd_packets)
83107

84-
def cubic_tcp_friendliness(self):
85-
"""CUBIC actions in TCP mode."""
86-
self.W_tcp += 3 * self.beta / (2 - self.beta) * (self.ack_cnt / self.cwnd)
108+
def cubic_tcp_friendliness(self, cwnd_packets: float):
109+
"""TCP-friendly mode keeps pace with Reno while probing above W_max."""
110+
if cwnd_packets <= 0:
111+
return
112+
self.W_tcp += 3 * self.beta / (2 - self.beta) * (self.ack_cnt / cwnd_packets)
87113
self.ack_cnt = 0
88-
if self.W_tcp > self.cwnd:
89-
max_cnt = self.cwnd / (self.W_tcp - self.cwnd)
90-
if self.cnt > max_cnt:
91-
self.cnt = max_cnt
114+
if self.W_tcp > cwnd_packets:
115+
max_cnt = cwnd_packets / max(self.W_tcp - cwnd_packets, 1e-6)
116+
self.cnt = min(self.cnt, max_cnt)
92117

93-
def timer_expired(self, packet=None):
94-
"""Actions to be taken when a timer expired."""
95-
# setting the congestion window to 1 segment
96-
self.cwnd = self.mss
97-
self.cubic_reset()
118+
def _reset_epoch(self):
119+
"""Drop epoch-specific state so the next ACK starts a fresh cubic phase."""
120+
self.epoch_start = 0.0
121+
self.K = 0.0
122+
self.ack_cnt = 0.0
123+
self.cnt = float("inf")
124+
self.cwnd_cnt = 0.0
125+
self.W_tcp = self.cwnd_in_segments()
98126

99-
def ack_received(self, rtt: float = 0, current_time: float = 0):
100-
"""Actions to be taken when a new ack has been received."""
101-
if self.d_min > 0:
102-
self.d_min = min(self.d_min, rtt)
127+
def _after_fast_loss(self, prev_cwnd: float, packet=None):
128+
"""RFC 8312 §4.6 fast convergence bookkeeping."""
129+
cwnd_packets = prev_cwnd / self.mss
130+
if self.fast_convergence and cwnd_packets < self.W_last_max:
131+
self.W_last_max = cwnd_packets * (1 + self.beta) / 2.0
103132
else:
104-
self.d_min = rtt
133+
self.W_last_max = cwnd_packets
134+
self._reset_epoch()
105135

106-
if self.cwnd <= self.ssthresh:
107-
# slow start
108-
self.cwnd += self.mss
109-
else:
110-
# congestion avoidance
111-
self.cubic_update(current_time)
112-
113-
if self.cwnd_cnt > self.cnt:
114-
self.cwnd += self.mss
115-
self.cwnd_cnt = 0
116-
else:
117-
self.cwnd_cnt += 1
136+
def _after_timeout(self, prev_cwnd: float, packet=None):
137+
"""Timeouts force a new epoch but keep the most recent W_max."""
138+
del packet
139+
self.W_last_max = prev_cwnd / self.mss
140+
self._reset_epoch()

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Ensure the project root is importable when tests run from a built wheel."""
2+
from __future__ import annotations
3+
4+
import sys
5+
from pathlib import Path
6+
7+
ROOT = Path(__file__).resolve().parents[1]
8+
if str(ROOT) not in sys.path:
9+
sys.path.insert(0, str(ROOT))

0 commit comments

Comments
 (0)