diff --git a/eosimutils/base.py b/eosimutils/base.py index 1b69c38..305d0e2 100644 --- a/eosimutils/base.py +++ b/eosimutils/base.py @@ -74,6 +74,8 @@ class ReferenceFrame(EnumBase): """ - ICRF_EC = "ICRF_EC" # Geocentric Celestial Reference Frame (ECI) + ICRF_EC = ( + "ICRF_EC" # Earth centered inertial frame aligned to the ICRF (ECI) + ) ITRF = "ITRF" # International Terrestrial Reference Frame (ECEF) # TEME = "TEME" # True Equator Mean Equinox diff --git a/eosimutils/state.py b/eosimutils/state.py index be24ada..c560bcc 100644 --- a/eosimutils/state.py +++ b/eosimutils/state.py @@ -79,8 +79,8 @@ def to_list(self) -> List[float]: """ return self.coords.tolist() - @staticmethod - def from_dict(dict_in: Dict[str, Any]) -> "Cartesian3DPosition": + @classmethod + def from_dict(cls, dict_in: Dict[str, Any]) -> "Cartesian3DPosition": """Construct a Cartesian3DPosition object from a dictionary. Args: @@ -99,9 +99,7 @@ def from_dict(dict_in: Dict[str, Any]) -> "Cartesian3DPosition": frame = ( ReferenceFrame.get(dict_in["frame"]) if "frame" in dict_in else None ) - return Cartesian3DPosition( - dict_in["x"], dict_in["y"], dict_in["z"], frame - ) + return cls(dict_in["x"], dict_in["y"], dict_in["z"], frame) def to_dict(self) -> Dict[str, Any]: """Convert the Cartesian3DPosition object to a dictionary. @@ -180,8 +178,8 @@ def to_list(self) -> List[float]: """ return self.coords.tolist() - @staticmethod - def from_dict(dict_in: Dict[str, Any]) -> "Cartesian3DVelocity": + @classmethod + def from_dict(cls, dict_in: Dict[str, Any]) -> "Cartesian3DVelocity": """Construct a Cartesian3DVelocity object from a dictionary. Args: @@ -200,9 +198,7 @@ def from_dict(dict_in: Dict[str, Any]) -> "Cartesian3DVelocity": frame = ( ReferenceFrame.get(dict_in["frame"]) if "frame" in dict_in else None ) - return Cartesian3DVelocity( - dict_in["vx"], dict_in["vy"], dict_in["vz"], frame - ) + return cls(dict_in["vx"], dict_in["vy"], dict_in["vz"], frame) def to_dict(self) -> Dict[str, Any]: """Convert the Cartesian3DVelocity object to a dictionary. @@ -220,7 +216,7 @@ def to_dict(self) -> Dict[str, Any]: class GeographicPosition: """Handles geographic position in the geodetic coordinate system. - The geodetic position is defined with respect to the + The geodetic position is defined with respect to the World Geodetic System 1984 Geoid as defined in Skyfield. Reference: https://rhodesmill.org/skyfield/api-topos.html """ @@ -337,8 +333,8 @@ def __init__( self.velocity: Cartesian3DVelocity = velocity self.frame: ReferenceFrame = frame - @staticmethod - def from_dict(dict_in: Dict[str, Any]) -> "CartesianState": + @classmethod + def from_dict(cls, dict_in: Dict[str, Any]) -> "CartesianState": """Construct a CartesianState object from a dictionary. Args: @@ -360,18 +356,22 @@ def from_dict(dict_in: Dict[str, Any]) -> "CartesianState": ) position = Cartesian3DPosition.from_array(dict_in["position"], frame) velocity = Cartesian3DVelocity.from_array(dict_in["velocity"], frame) - return CartesianState(time, position, velocity, frame) + return cls(time, position, velocity, frame) @staticmethod def from_array( - array_in: Union[List[float], np.ndarray, Tuple[float, float, float]], + array_in: Union[ + List[float], + np.ndarray, + Tuple[float, float, float, float, float, float], + ], time: AbsoluteDate, frame: Optional[Union[ReferenceFrame, str, None]] = None, ) -> "CartesianState": """Construct a CartesianState object from a list, tuple, or NumPy array. Args: - array_in (Union[List[float], np.ndarray, Tuple[float, float, float]]): + array_in (Union[List[float], np.ndarray, Tuple[float, float, float, float, float, float]]): # pylint: disable=line-too-long Position and velocity coordinates in kilometers and km-per-s. time (AbsoluteDate): Absolute date-time object. frame (Union[ReferenceFrame, str, None]): Reference-frame. diff --git a/eosimutils/time.py b/eosimutils/time.py index 67719bc..a9e165c 100644 --- a/eosimutils/time.py +++ b/eosimutils/time.py @@ -212,6 +212,21 @@ def __eq__(self, value): return False return self.ephemeris_time == value.ephemeris_time + def __add__(self, value): + """Add a number of seconds to the AbsoluteDate object. + + Args: + value (float): The number of seconds to add. + + Returns: + AbsoluteDate: A new AbsoluteDate object with the updated time. + """ + return AbsoluteDate(self.ephemeris_time + value) + + def __repr__(self): + """Return a string representation of the AbsoluteDate.""" + return f"AbsoluteDate({self.ephemeris_time})" + class AbsoluteDateArray: """ @@ -364,3 +379,41 @@ def to_dict( ): times_list, "time_scale": str(time_scale), } + + def __len__(self): + """Return the length of the AbsoluteDateArray.""" + return len(self.ephemeris_time) + + def __getitem__(self, index): + """Get an item or a slice from the AbsoluteDateArray. + + Args: + index (int or slice): Index or slice of the item(s) to retrieve. + + Returns: + AbsoluteDate or AbsoluteDateArray: Selected item(s) as AbsoluteDate + or AbsoluteDateArray. + """ + if isinstance(index, slice): + # Handle slicing + return AbsoluteDateArray(self.ephemeris_time[index]) + else: + # Handle single index + return AbsoluteDate(self.ephemeris_time[index]) + + def __eq__(self, value): + """Check equality of two AbsoluteDateArray objects. + + Args: + value (AbsoluteDateArray): The AbsoluteDateArray object to compare with. + + Returns: + bool: True if the objects are equal, False otherwise. + """ + if not isinstance(value, AbsoluteDateArray): + return False + return np.array_equal(self.ephemeris_time, value.ephemeris_time) + + def __repr__(self): + """Return a string representation of the AbsoluteDateArray.""" + return f"AbsoluteDateArray({self.ephemeris_time})" diff --git a/tests/test_time.py b/tests/test_time.py index 60ec115..af196c3 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -2,6 +2,7 @@ import unittest import numpy as np +import copy from astropy.time import Time as Astropy_Time @@ -186,6 +187,33 @@ def test_equality_operator(self): # Test comparison with a non-AbsoluteDate object self.assertFalse(date1 == "not an AbsoluteDate") + def test_add_operator(self): + """Test the __add__ operator for AbsoluteDate.""" + # Initialize an AbsoluteDate object + absolute_date = AbsoluteDate.from_dict( + { + "time_format": "Gregorian_Date", + "calendar_date": "2025-03-17T12:00:00", + "time_scale": "utc", + } + ) + + # Add 3600 seconds (1 hour) + new_date = absolute_date + 3600 + + # Convert the new date to Gregorian format + new_date_dict = new_date.to_dict("Gregorian_Date", "UTC") + + # Expected result after adding 1 hour + expected_date = { + "time_format": "GREGORIAN_DATE", + "calendar_date": "2025-03-17T13:00:00.000", + "time_scale": "UTC", + } + + # Assert the new date matches the expected result + self.assertEqual(new_date_dict, expected_date) + class TestAbsoluteDateArray(unittest.TestCase): """Test the AbsoluteDateArray class.""" @@ -235,6 +263,47 @@ def test_to_dict_and_from_dict(self): rtol=1e-6, ) + def test_length(self): + """Test the length of the AbsoluteDateArray.""" + et_array = np.random.uniform( + 553333629.0, 553333630.0, size=np.random.randint(1, 10) + ) + abs_dates = AbsoluteDateArray(et_array) + self.assertEqual(len(abs_dates), len(et_array)) + + def test_get_item(self): + """Test the __getitem__ method for AbsoluteDateArray.""" + et_array = np.random.uniform( + 553333629.0, 553333635.0, size=np.random.randint(3, 10) + ) + abs_dates = AbsoluteDateArray(et_array) + + # Test getting a single item + item = abs_dates[0] + self.assertIsInstance(item, AbsoluteDate) + self.assertAlmostEqual(item.ephemeris_time, et_array[0], places=6) + + # Test getting a slice + slice_items = abs_dates[1:3] + self.assertIsInstance(slice_items, AbsoluteDateArray) + self.assertEqual(len(slice_items), 2) + + def test_equality_operator(self): + """Test the equality operator for AbsoluteDateArray.""" + et_array1 = np.random.uniform( + 553333633.0, 553333635.0, size=np.random.randint(1, 10) + ) + et_array2 = copy.deepcopy(et_array1) + et_array3 = np.array([553333631.0, 553333632.0]) + + abs_dates1 = AbsoluteDateArray(et_array1) + abs_dates2 = AbsoluteDateArray(et_array2) + abs_dates3 = AbsoluteDateArray(et_array3) + + # Test equality + self.assertTrue(abs_dates1 == abs_dates2) + self.assertFalse(abs_dates1 == abs_dates3) + if __name__ == "__main__": unittest.main()