Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 67 additions & 88 deletions rslearn/data_sources/planetary_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import affine
import numpy.typing as npt
import planetary_computer
import pystac
import pystac_client
import rasterio
import requests
import shapely
Expand All @@ -21,6 +19,7 @@
from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources import DataSource, Item
from rslearn.data_sources.stac_utils import StacApiClient
from rslearn.data_sources.utils import match_candidate_items_to_window
from rslearn.dataset import Window
from rslearn.dataset.materialize import RasterMaterializer
Expand Down Expand Up @@ -81,11 +80,7 @@ class PlanetaryComputer(DataSource, TileStore):
but is not needed.
"""

STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"

# Default threshold for recreating the STAC client to prevent memory leaks
# from the pystac Catalog's resolved objects cache growing unbounded
DEFAULT_MAX_ITEMS_PER_CLIENT = 1000
STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1/search"

def __init__(
self,
Expand All @@ -97,7 +92,6 @@ def __init__(
timeout: timedelta = timedelta(seconds=10),
skip_items_missing_assets: bool = False,
cache_dir: UPath | None = None,
max_items_per_client: int | None = None,
):
"""Initialize a new PlanetaryComputer instance.

Expand All @@ -114,9 +108,6 @@ def __init__(
cache_dir: optional directory to cache items by name, including asset URLs.
If not set, there will be no cache and instead STAC requests will be
needed each time.
max_items_per_client: number of STAC items to process before recreating
the client to prevent memory leaks from the resolved objects cache.
Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
"""
self.collection_name = collection_name
self.asset_bands = asset_bands
Expand All @@ -126,15 +117,11 @@ def __init__(
self.timeout = timeout
self.skip_items_missing_assets = skip_items_missing_assets
self.cache_dir = cache_dir
self.max_items_per_client = (
max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
)

if self.cache_dir is not None:
self.cache_dir.mkdir(parents=True, exist_ok=True)

self.client: pystac_client.Client | None = None
self._client_item_count = 0
self.client: StacApiClient | None = None

@staticmethod
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "PlanetaryComputer":
Expand All @@ -157,74 +144,73 @@ def from_config(config: RasterLayerConfig, ds_path: UPath) -> "PlanetaryComputer
"query",
"sort_by",
"sort_ascending",
"max_items_per_client",
]
for k in simple_optionals:
if k in d:
kwargs[k] = d[k]

return PlanetaryComputer(**kwargs)

def _load_client(self) -> StacApiClient:
    """Return the STAC API client, creating it on first use.

    The client is not created when the data source is constructed because the
    caller may never call get_items. Creating it during the get_items call also
    enables leveraging the retry loop functionality in prepare_dataset_windows.

    Returns:
        the StacApiClient stored on this data source, created against
        STAC_ENDPOINT if it did not exist yet.
    """
    if self.client is None:
        logger.debug("Creating STAC client")
        self.client = StacApiClient(
            endpoint=self.STAC_ENDPOINT,
            # self.timeout is a timedelta; hand it over as a number of seconds.
            timeout=self.timeout.total_seconds(),
        )
    return self.client

def _stac_item_to_item(self, stac_item: dict[str, Any]) -> PlanetaryComputerItem:
    """Convert a STAC Item dict to a PlanetaryComputerItem.

    Args:
        stac_item: STAC Item as a dictionary. It must have "geometry" and "id"
            keys; "properties" and "assets" are treated as empty when absent.

    Returns:
        PlanetaryComputerItem instance whose geometry is in WGS84 and whose
        asset URLs map each asset key to that asset's "href".

    Raises:
        ValueError: if the item properties provide neither both start_datetime
            and end_datetime, nor datetime.
    """
    # Use dateutil.parser for robust parsing (same as pystac).
    from dateutil.parser import parse as parse_datetime

    shp = shapely.geometry.shape(stac_item["geometry"])

    # Get time range from properties.
    properties = stac_item.get("properties", {})
    start_dt_str = properties.get("start_datetime")
    end_dt_str = properties.get("end_datetime")

    if start_dt_str and end_dt_str:
        # An explicit range takes precedence over the single datetime.
        time_range = (parse_datetime(start_dt_str), parse_datetime(end_dt_str))
    elif properties.get("datetime"):
        # Single datetime: use it as both ends of the range.
        dt = parse_datetime(properties["datetime"])
        time_range = (dt, dt)
    else:
        raise ValueError(
            f"item {stac_item['id']} unexpectedly missing start_datetime, "
            "end_datetime, and datetime"
        )

    geom = STGeometry(WGS84_PROJECTION, shp, time_range)

    # Extract asset URLs.
    asset_urls = {
        asset_key: asset_obj["href"]
        for asset_key, asset_obj in stac_item.get("assets", {}).items()
    }

    return PlanetaryComputerItem(stac_item["id"], geom, asset_urls)

def get_item_by_name(self, name: str) -> PlanetaryComputerItem:
"""Gets an item by name.
Expand All @@ -245,27 +231,18 @@ def get_item_by_name(self, name: str) -> PlanetaryComputerItem:
return PlanetaryComputerItem.deserialize(json.load(f))

# No cache or not in cache, so we need to make the STAC request.
logger.debug("Getting STAC item {name}")
logger.debug(f"Getting STAC item {name}")
client = self._load_client()

search_result = client.search(ids=[name], collections=[self.collection_name])
stac_items = list(search_result.items())
stac_item = client.get_item(name, self.collection_name)

if not stac_items:
if stac_item is None:
raise ValueError(
f"Item {name} not found in collection {self.collection_name}"
)
if len(stac_items) > 1:
raise ValueError(
f"Multiple items found for ID {name} in collection {self.collection_name}"
)

stac_item = stac_items[0]
item = self._stac_item_to_item(stac_item)

# Track items processed for client recreation threshold (after deserialization)
self._client_item_count += 1

# Finally we cache it if cache_dir is set.
if cache_fname is not None:
with cache_fname.open("w") as f:
Expand Down Expand Up @@ -293,15 +270,16 @@ def get_items(
# for each requested geometry.
wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
logger.debug("performing STAC search for geometry %s", wgs84_geometry)
result = client.search(
collections=[self.collection_name],
intersects=shapely.to_geojson(wgs84_geometry.shp),
datetime=wgs84_geometry.time_range,
query=self.query,

# Collect items from search (client.search yields dict objects)
stac_items = list(
client.search(
collections=[self.collection_name],
intersects=shapely.to_geojson(wgs84_geometry.shp),
datetime=wgs84_geometry.time_range,
query=self.query,
)
)
stac_items = [item for item in result.items()]
# Track items processed for client recreation threshold (after deserialization)
self._client_item_count += len(stac_items)
logger.debug("STAC search yielded %d items", len(stac_items))

if self.skip_items_missing_assets:
Expand All @@ -310,7 +288,8 @@ def get_items(
for stac_item in stac_items:
good = True
for asset_key in self.asset_bands.keys():
if asset_key in stac_item.assets:
assets = stac_item.get("assets", {})
if asset_key in assets:
continue
good = False
break
Expand All @@ -325,7 +304,9 @@ def get_items(

if self.sort_by is not None:
stac_items.sort(
key=lambda stac_item: stac_item.properties[self.sort_by],
key=lambda stac_item: stac_item.get("properties", {}).get(
self.sort_by
),
reverse=not self.sort_ascending,
)

Expand Down Expand Up @@ -639,7 +620,6 @@ def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
"query",
"sort_by",
"sort_ascending",
"max_items_per_client",
]
for k in simple_optionals:
if k in d:
Expand Down Expand Up @@ -820,7 +800,6 @@ def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1":
"query",
"sort_by",
"sort_ascending",
"max_items_per_client",
]
for k in simple_optionals:
if k in d:
Expand Down
Loading
Loading