Skip to content

Commit f50feb3

Browse files
authored
Merge pull request #486 from stratosphereips/ondra-add-agent-input-checking
Ondra add agent input checking
2 parents 187827e + 71216af commit f50feb3

3 files changed

Lines changed: 73 additions & 7 deletions

File tree

netsecgame/agents/base_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, host, port, role:str)->None:
2222
self._socket.connect((host, port))
2323
except socket.error as e:
2424
self._logger.error(f"Socket error: {e}")
25-
self.sock = None
25+
self._socket = None
2626
self._logger.info("Agent created")
2727

2828
def __del__(self):
@@ -32,7 +32,7 @@ def __del__(self):
3232
self._socket.close()
3333
self._logger.info("Socket closed")
3434
except socket.error as e:
35-
print(f"Error closing socket: {e}")
35+
self._logger.error(f"Error closing socket: {e}")
3636

3737
def terminate_connection(self)->None:
3838
"""Method for graceful termination of connection. Should be used by any class extending the BaseAgent."""
@@ -42,7 +42,7 @@ def terminate_connection(self)->None:
4242
self._socket = None
4343
self._logger.info("Socket closed")
4444
except socket.error as e:
45-
print(f"Error closing socket: {e}")
45+
self._logger.error(f"Error closing socket: {e}")
4646
@property
4747
def socket(self)->socket.socket | None:
4848
return self._socket

netsecgame/game/coordinator.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import Optional
66
import signal
77
import os
8+
import re
9+
import uuid
810

911
from netsecgame.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, AgentRole
1012
from netsecgame.game.global_defender import GlobalDefender
@@ -25,6 +27,16 @@ def convert_msg_dict_to_json(msg_dict: dict) -> str:
2527
raise TypeError(f"Error when converting msg to JSON:{e}") from e
2628
return output_message
2729

30+
def sanitize_agent_name(name:str)->str:
31+
"""
32+
Sanitizes the agent name to be used as a filename.
33+
"""
34+
safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', name)
35+
safe_name = re.sub(r'_+', '_', safe_name)
36+
safe_name = safe_name.strip('_')[:200]
37+
if not safe_name:
38+
return f"agent_{uuid.uuid4().hex[:8]}"
39+
return safe_name
2840

2941
class GameCoordinator:
3042
"""
@@ -312,10 +324,26 @@ async def run_game(self):
312324
self.logger.info(f"Coordinator received from agent {agent_addr}: {message}.")
313325

314326
action = self._parse_action_message(agent_addr, message)
315-
if action:
327+
if action is not None:
316328
self._dispatch_action(agent_addr, action)
329+
else:
330+
self._spawn_task(self._respond_on_bad_request, agent_addr, "Malformed Action")
317331
self.logger.info("\tAction processing task stopped.")
318-
332+
333+
async def _respond_on_bad_request(self, agent_addr: tuple, message: str)->None:
334+
"""
335+
Sends a response to the agent indicating that the request was bad.
336+
"""
337+
output_message_dict = {
338+
"to_agent": agent_addr,
339+
"status": str(GameStatus.BAD_REQUEST),
340+
"observation": None,
341+
"message": {
342+
"message": f"Bad request received: {message}",
343+
}
344+
}
345+
await self._agent_response_queues[agent_addr].put(convert_msg_dict_to_json(output_message_dict))
346+
319347
async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None:
320348
"""
321349
Method for processing Action of type ActionType.JoinGame
@@ -327,7 +355,7 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No
327355
try:
328356
self.logger.info(f"New Join request by {agent_addr}.")
329357
if agent_addr not in self.agents:
330-
agent_name = action.parameters["agent_info"].name
358+
agent_name = sanitize_agent_name(str(action.parameters["agent_info"].name))
331359
agent_role = action.parameters["agent_info"].role
332360
if agent_role in AgentRole:
333361
# add agent to the world

tests/game/test_coordinator_core.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,4 +480,42 @@ async def test_run_game_flow(self, mock_coordinator_core):
480480
await mock_coordinator_core.run_game()
481481

482482
mock_parse.assert_called_once_with(agent_addr, valid_json)
483-
mock_dispatch.assert_called_once_with(agent_addr, mock_action)
483+
mock_dispatch.assert_called_once_with(agent_addr, mock_action)
484+
485+
@pytest.mark.asyncio
486+
async def test_run_game_malformed_action(self, mock_coordinator_core):
487+
"""New test for refactored method: run_game flow with malformed action."""
488+
agent_addr = ("127.0.0.1", 12345)
489+
invalid_json = '{"invalid": "json"}'
490+
491+
# Setup queue
492+
mock_coordinator_core._agent_action_queue.get.return_value = (agent_addr, invalid_json)
493+
494+
with patch.object(mock_coordinator_core, '_parse_action_message') as mock_parse, \
495+
patch.object(mock_coordinator_core, '_spawn_task') as mock_spawn:
496+
497+
mock_parse.return_value = None
498+
499+
await mock_coordinator_core.run_game()
500+
501+
mock_parse.assert_called_once_with(agent_addr, invalid_json)
502+
mock_spawn.assert_called_once_with(mock_coordinator_core._respond_on_bad_request, agent_addr, "Malformed Action")
503+
504+
@pytest.mark.asyncio
505+
async def test_respond_on_bad_request(self, mock_coordinator_core):
506+
"""New test for _respond_on_bad_request."""
507+
mock_coordinator_core._respond_on_bad_request = GameCoordinator._respond_on_bad_request.__get__(mock_coordinator_core)
508+
agent_addr = ("127.0.0.1", 12345)
509+
mock_coordinator_core._agent_response_queues = {agent_addr: asyncio.Queue()}
510+
511+
await mock_coordinator_core._respond_on_bad_request(agent_addr, "Malformed Action")
512+
513+
# Ensure the response is in the queue
514+
assert not mock_coordinator_core._agent_response_queues[agent_addr].empty()
515+
516+
response_json = await mock_coordinator_core._agent_response_queues[agent_addr].get()
517+
response_data = json.loads(response_json)
518+
519+
assert response_data["status"] == str(GameStatus.BAD_REQUEST)
520+
assert response_data["observation"] is None
521+
assert "Malformed Action" in response_data["message"]["message"]

0 commit comments

Comments
 (0)