diff --git a/CHANGELOG.md b/CHANGELOG.md index f0c55c8..d2356c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,16 @@ Changes for the upcoming release can be found (and will be added on merging feat Changes from previous releases are listed below. ## Upcoming Release +- Set all band values to None if any of them is None _(see #104)_ +- Remove entries from shapefile without label before preprocessing _(see #109)_ +- Read only label files for class namt to id dict _(see #110)_ +- Fix numpy/torch version conflict _(see #95)_ +- Add option to filter countries for preprocessing _(see #108)_ +- Add NUTS region identifier _(see #107)_ +- Update collection of Sentinel-2 data to Collection 1 _(see #75)_ +- Fix API URL _(see #94)_ +- Change extension of pickle files _(see #105)_ +- Add data access via S3 bucket _(see #103)_ - Fix country polygons _(see #99)_ - Update to EuroCrops V11 _(see #63)_ - Update country polygons _(see #89)_ diff --git a/eurocropsml/acquisition/builder.py b/eurocropsml/acquisition/builder.py index 68a80a6..bc48757 100644 --- a/eurocropsml/acquisition/builder.py +++ b/eurocropsml/acquisition/builder.py @@ -2,11 +2,12 @@ import logging from pathlib import Path -from typing import cast +from typing import Literal, cast from eurocropsml.acquisition import collector, copier, region from eurocropsml.acquisition.clipping import clipper from eurocropsml.acquisition.config import AcquisitionConfig +from eurocropsml.acquisition.s3 import _set_s3_env_variables from eurocropsml.settings import Settings logger = logging.getLogger(__name__) @@ -49,6 +50,13 @@ def build_dataset( logger.info(f"Processing year {ct_config.year} for {country}.") + if config.eodata_dir is None: + # local_dir = None + _set_s3_env_variables() + source: Literal["eodata", "s3"] = "s3" + else: + source = "eodata" + collector.acquire_sentinel_tiles( ct_config, satellite_output_dir.joinpath("collector"), @@ -58,11 +66,13 @@ def build_dataset( config.workers, ) logger.info("Finished step 1: Acquiring list of necessary .SAFE files.") + copier.merge_safe_files( ct_config.satellite, cast(list[str], ct_config.bands), satellite_output_dir, config.workers, + source, local_dir, ) if local_dir is not None: @@ -70,6 +80,8 @@ def build_dataset( "Finished step 2: Copying .SAFE files to local disk and " "acquiring list of individual band image paths." ) + source = "eodata" + logger.info("Tiles will now be accessed via local storage. 
Setting `source` to 'eodata'.") else: logger.info("Finished step 2: Acquiring list of individual band image paths.") @@ -80,6 +92,7 @@ def build_dataset( config.workers, config.chunk_size, config.multiplier, + source, local_dir, config.rebuild, ) diff --git a/eurocropsml/acquisition/clipping/clipper.py b/eurocropsml/acquisition/clipping/clipper.py index c1f7ad9..d515a5f 100644 --- a/eurocropsml/acquisition/clipping/clipper.py +++ b/eurocropsml/acquisition/clipping/clipper.py @@ -6,16 +6,21 @@ import logging import multiprocessing as mp_orig import pickle +import sys from functools import partial from pathlib import Path -from typing import cast +from typing import Callable, Literal, cast import geopandas as gpd import pandas as pd import pyogrio from tqdm import tqdm -from eurocropsml.acquisition.clipping.utils import _merge_clipper, mask_polygon_raster +from eurocropsml.acquisition.clipping.utils import ( + _merge_clipper, + mask_polygon_raster, + mask_polygon_raster_s3, +) from eurocropsml.acquisition.config import CollectorConfig logger = logging.getLogger(__name__) @@ -101,15 +106,15 @@ def _get_arguments( clipping_path = output_dir.joinpath("clipper", f"{month}") clipping_path.mkdir(exist_ok=True, parents=True) - if clipping_path.joinpath("args.pkg").exists(): + if clipping_path.joinpath("args.pkl").exists(): logger.info("Loading argument list for parallel raster clipping.") - with open(clipping_path.joinpath("args.pkg"), "rb") as file: + with open(clipping_path.joinpath("args.pkl"), "rb") as file: args: list[tuple[pd.DataFrame, list]] = pickle.load(file) - shapefile: gpd.GeoDataFrame = pd.read_pickle(clipping_path.joinpath("empty_polygon_df.pkg")) + shapefile: gpd.GeoDataFrame = pd.read_pickle(clipping_path.joinpath("empty_polygon_df.pkl")) else: logger.info("No argument list found. 
Will create it.") # DataFrame of raster file/parcel matches - full_images_paths: Path = output_dir.joinpath("collector", "full_parcel_list.pkg") + full_images_paths: Path = output_dir.joinpath("collector", "full_parcel_list.pkl") full_images = pd.read_pickle(full_images_paths) full_images["completionDate"] = pd.to_datetime(full_images["completionDate"]).dt.date @@ -118,11 +123,13 @@ def _get_arguments( ] if local_dir is not None: - full_images["productIdentifier"] = str(local_dir) + full_images[ - "productIdentifier" - ].astype(str) + full_images["productIdentifier"] = ( + full_images["productIdentifier"] + .astype(str) + .apply(lambda x: str(local_dir.joinpath(x))) + ) - band_image_path: Path = output_dir.joinpath("copier", "band_images.pkg") + band_image_path: Path = output_dir.joinpath("copier", "band_images.pkl") band_images: pd.DataFrame = pd.read_pickle(band_image_path) # filter out month @@ -144,10 +151,12 @@ def _get_arguments( ti.update(n=1) ti.close() - with open(clipping_path.joinpath("args.pkg"), "wb") as fp: + with open(clipping_path.joinpath("args.pkl"), "wb") as fp: pickle.dump(args, fp) logger.info("Saved argument list.") + if sys.stdout is not None: + sys.stdout.flush() date_list = list(full_images["completionDate"].unique()) cols = [parcel_id_name, "geometry"] + date_list @@ -157,7 +166,7 @@ def _get_arguments( shapefile = shapefile.reindex(columns=cols) - shapefile.to_pickle(clipping_path.joinpath("empty_polygon_df.pkg")) + shapefile.to_pickle(clipping_path.joinpath("empty_polygon_df.pkl")) shapefile[parcel_id_name] = shapefile[parcel_id_name].astype(int) @@ -177,6 +186,7 @@ def _filter_args( def _process_raster_parallel( + masking_fct: Callable, polygon_df: pd.DataFrame, parcel_id_name: str, filtered_images: gpd.GeoDataFrame, @@ -185,6 +195,7 @@ def _process_raster_parallel( """Processing one raster file. Args: + masking_fct: Function to use for clipping. Either via S3 or local access. polygon_df: Dataframe containing all parcel ids. Will be merged with the clipped values. parcel_id_name: The country's parcel ID name (varies from country to country). filtered_images: Dataframe containing all parcel ids that lie in this raster tile. @@ -203,7 +214,7 @@ def _process_raster_parallel( # geometry information of all parcels filtered_geom = polygon_df[polygon_df[parcel_id_name].isin(parcel_ids)] - result = mask_polygon_raster(band_tiles, filtered_geom, parcel_id_name, product_date) + result = masking_fct(band_tiles, filtered_geom, parcel_id_name, product_date) result.set_index(parcel_id_name, inplace=True) result.index = result.index.astype(int) # make sure index is integer @@ -219,6 +230,7 @@ def clipping( workers: int, chunk_size: int, multiplier: int, + source: Literal["eodata", "s3"] = "s3", local_dir: Path | None = None, rebuild: bool = False, ) -> None: @@ -231,10 +243,14 @@ def clipping( workers: Maximum number of workers used for multiprocessing. chunk_size: Chunk size used for multiprocessed raster clipping. multiplier: Intermediate results will be saved every multiplier steps. + source: Source of the Sentinel tiles. Either directory ('eodata') or S3 bucket ('s3'). + If files have been copied to a local directory, this was set to 'eodata'. local_dir: Local directory where the .SAFE files were copied to. rebuild: Whether to re-build the clipped parquet files for each month. This will overwrite the existing ones. 
""" + + masking_fct = mask_polygon_raster_s3 if source == "s3" else mask_polygon_raster for month in tqdm( range(config.months[0], config.months[1] + 1), desc="Clipping rasters on monthly basis" ): @@ -258,7 +274,7 @@ def clipping( clipped_dir.mkdir(exist_ok=True, parents=True) # Process data in smaller chunks - file_counts = len(list(clipped_dir.rglob("Final_*.pkg"))) + file_counts = len(list(clipped_dir.rglob("Final_*.pkl"))) processed = file_counts * multiplier * chunk_size save_files = multiplier * chunk_size @@ -269,6 +285,7 @@ def clipping( ) func = partial( _process_raster_parallel, + masking_fct, polygon_df_month, cast(str, config.parcel_id_name), ) @@ -292,7 +309,10 @@ def clipping( ] results: list[pd.DataFrame] = [] - with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: + with concurrent.futures.ProcessPoolExecutor( + max_workers=max_workers, mp_context=mp_orig.get_context("spawn") + ) as executor: + futures = [executor.submit(func, *arg) for arg in chunk_args] for future in concurrent.futures.as_completed(futures): @@ -310,7 +330,7 @@ def clipping( processed += len(chunk_args) if processed % save_files == 0: - df_final_month.to_pickle(clipped_dir.joinpath(f"Final_{file_counts}.pkg")) + df_final_month.to_pickle(clipped_dir.joinpath(f"Final_{file_counts}.pkl")) del df_final_month df_final_month = polygon_df_month.copy() file_counts += 1 @@ -318,7 +338,7 @@ def clipping( del chunk_args, futures gc.collect() - df_final_month.to_pickle(clipped_dir.joinpath(f"Final_{file_counts}.pkg")) + df_final_month.to_pickle(clipped_dir.joinpath(f"Final_{file_counts}.pkl")) te.close() _merge_dataframe( diff --git a/eurocropsml/acquisition/clipping/utils.py b/eurocropsml/acquisition/clipping/utils.py index 47f12e0..c4174ec 100644 --- a/eurocropsml/acquisition/clipping/utils.py +++ b/eurocropsml/acquisition/clipping/utils.py @@ -1,6 +1,7 @@ """Utilities for clipping polygons from raster tiles.""" import logging +import os from pathlib import Path from typing import cast @@ -14,10 +15,68 @@ from tqdm import tqdm logger = logging.getLogger(__name__) +logging.getLogger("botocore.credentials").setLevel(logging.WARNING) pd.options.mode.chained_assignment = None +def mask_polygon_raster_s3( + tilepaths: list[Path], + polygon_df: pd.DataFrame, + parcel_id_name: str, + product_date: str, +) -> pd.DataFrame: + """Clipping parcels from raster files (per band) and calculating median pixel value per band. + + Args: + tilepaths: Paths to the raster's band tiles. + polygon_df: GeoDataFrame of all parcels to be clipped. + parcel_id_name: The country's parcel ID name (varies from country to country). + product_date: Date on which the raster tile was obtained. + + Returns: + Dataframe with clipped values. + + Raises: + FileNotFoundError: If the raster file cannot be found. 
+ """ + + parcels_dict: dict[int, list[float | None]] = { + parcel_id: [] for parcel_id in polygon_df[parcel_id_name].unique() + } + + # removing any self-intersections or inconsistencies in geometries + polygon_df["geometry"] = polygon_df["geometry"].buffer(0) + polygon_df = polygon_df.reset_index(drop=True) + + s3_container = os.environ.get("S3_CONTAINER_NAME") + for b, band_path in enumerate(tilepaths): + s3_uri = f"/vsis3/{s3_container}/{band_path}" + with rasterio.open(s3_uri, "r") as raster_tile: + if b == 0 and polygon_df.crs.srs != raster_tile.crs: + # transforming shapefile into CRS of raster tile + polygon_df = polygon_df.to_crs(raster_tile.crs) + + # clippping geometry out of raster tile and saving in dictionary + polygon_df.apply( + lambda row: _process_row(row, raster_tile, parcels_dict, parcel_id_name), + axis=1, + ) + + # if any value is None, set all to None + parcels_dict = { + parcel_id: ( + [None] * len(clipped_list) + if any(item is None for item in clipped_list) + else clipped_list + ) + for parcel_id, clipped_list in parcels_dict.items() + } + parcels_df = pd.DataFrame(list(parcels_dict.items()), columns=[parcel_id_name, product_date]) + + return parcels_df + + def mask_polygon_raster( tilepaths: list[Path], polygon_df: pd.DataFrame, @@ -59,6 +118,15 @@ def mask_polygon_raster( axis=1, ) + # if any value is None, set all to None + parcels_dict = { + parcel_id: ( + [None] * len(clipped_list) + if any(item is None for item in clipped_list) + else clipped_list + ) + for parcel_id, clipped_list in parcels_dict.items() + } parcels_df = pd.DataFrame(list(parcels_dict.items()), columns=[parcel_id_name, product_date]) return parcels_df @@ -72,6 +140,11 @@ def _process_row( ) -> None: """Masking geometry from raster tiles and calculating median pixel value.""" parcel_id: int = row[parcel_id_name] + if any(item is None for item in parcels_dict[parcel_id]): + # skip clipping for this parcel_id if any band already produced None + parcels_dict[parcel_id].append(None) + return + geom = row["geometry"] try: diff --git a/eurocropsml/acquisition/collector.py b/eurocropsml/acquisition/collector.py index 4fd17fa..0ea384c 100644 --- a/eurocropsml/acquisition/collector.py +++ b/eurocropsml/acquisition/collector.py @@ -6,7 +6,7 @@ # Email: david.gackstetter@tum.de ##################################################################### # Script majorly revised for EuroCrops by Joana Reuss -# Copyright: Copyright 2024, Technical University of Munich +# Copyright: Copyright 2025, Technical University of Munich # Email: joana.reuss@tum.de ##################################################################### @@ -27,12 +27,18 @@ import pandas as pd import pyogrio import requests +from botocore.client import BaseClient from pyproj import CRS from shapely.geometry.polygon import Polygon from tqdm import tqdm from eurocropsml.acquisition.config import CollectorConfig -from eurocropsml.acquisition.utils import _get_dict_value_by_name, _load_pkg +from eurocropsml.acquisition.s3 import ( + _establish_s3_client, + _get_s3_subfolders, + _parse_s3_xml, +) +from eurocropsml.acquisition.utils import _get_dict_value_by_name, _load_pkl logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -55,7 +61,7 @@ def _eolab_finder( _, num_days = calendar.monthrange(year, months[1]) months_list: list[str] = ["0{0}".format(m) if m < 10 else "{0}".format(m) for m in months] - request_url = """https://datahub.eo-lab.org/odata/v1/Products?$filter=({0}(ContentDate/Start \ + request_url = 
"""https://datahub.creodias.eu/odata/v1/Products?$filter=({0}(ContentDate/Start \ ge {1}-{2}-01T00:00:00.000Z and ContentDate/Start le {1}-{3}-{4}T23:59:59.999Z) and (Online eq \ true) and (OData.CSC.Intersects(Footprint=geography'SRID=4326;{5}')) and (((((Collection/Name eq \ '{6}'){8} and (((Attributes/Odata.CSC.StringAttribute/any(i0:i0/Name eq 'productType' and \ @@ -100,7 +106,7 @@ def acquire_sentinel_tiles( shape_dir: File path of EuroCrops shapefile. shape_dir_clean: Directory where the cleaned shapefile will be stored. eodata_dir: Directory where Sentinel-1 or Sentinel-2 data is stored. - If None, `eodata` is used since this will be returned by the API call. + If None, Sentinel tiles will be accessed via S3 bucket. workers: Maximum number of workers used for multiprocessing. batch_size: Batch size used for multiprocessed merging of .SAFE files and parcels. @@ -154,7 +160,7 @@ def _downloader( ) -> None: request_path = output_dir.joinpath("requests") request_path.mkdir(exist_ok=True, parents=True) - eofinder_request: Path = request_path.joinpath(f"{country}_{year}.json") + eofinder_request: Path = request_path.joinpath(f"{country.replace(' ', '_')}_{year}.json") if not eofinder_request.exists(): # if not eofinder_request_new.exists(): @@ -211,15 +217,26 @@ def _downloader( operational_mode_str, max_requested_products, ) - run_loop = False - logger.info("API-request was successful!") + if requests.get("value") in (None, []): + if collection_name == "SENTINEL-2": + logger.info( + "No products for Collection-1 found. Rerunning request for \ + non-Collection-1 products." + ) + filter_collection = "" + run_loop = True + else: + raise ValueError("API-request was not successful!") + else: + run_loop = False + logger.info("API-request was successful!") except ConnectionError: time.sleep(2000) # Extra loop in case that available products exceed number of maximum requests # These requestes are executed on a monthly basis. if len(requests["value"]) == max_requested_products: - logger.info("Too many requested product. Executing monthly requests.") + logger.info("Too many requested products. Executing monthly requests.") all_months: list = list(range(months[0], months[1] + 1)) for idx, month in enumerate(all_months): run_loop = True @@ -262,12 +279,13 @@ def _downloader( products = requests["value"] - request_files = output_dir.joinpath("requests", "request_safe_files.pkg") + request_files = output_dir.joinpath("requests", "request_safe_files.pkl") max_workers = min(mp_orig.cpu_count(), max(1, min(len(products), workers))) if not request_files.exists(): # creating GeoDataFrame from .SAFE files results: list[list] = [] + with mp_orig.Pool(processes=max_workers) as p: func = partial(_get_tiles, satellite, eodata_dir) process_iter = p.imap(func, products, chunksize=1000) @@ -278,7 +296,17 @@ def _downloader( results.append(result) ti.update(n=1) ti.close() - + if not results: + if eodata_dir is None: + raise AssertionError( + "None of the tiles could be processed. Access to S3 bucket \ + might have failed. Exiting process." + ) + else: + raise AssertionError( + "None of the tiles could be processed. Access to eodata \ + repository might have failed. Exiting process." + ) request_df: pd.DataFrame if satellite == "S2": request_df = pd.DataFrame( @@ -322,15 +350,17 @@ def _downloader( ) for crs in unique_crs ] - + if not request_df_list: + raise AssertionError("None of the tiles could be processed. 
Exiting process.") if not ( - output_dir.joinpath("full_safe_file_list.pkg").exists() - and output_dir.joinpath("full_parcel_list.pkg").exists() + output_dir.joinpath("full_safe_file_list.pkl").exists() + and output_dir.joinpath("full_parcel_list.pkl").exists() ): if not shape_dir_clean.exists(): # Cleaning up country's shapefile # Load in SHP-File shapefile: gpd.GeoDataFrame = pyogrio.read_dataframe(shape_dir) + shapefile = shapefile[~shapefile["EC_hcat_c"].isna()] if "EC_NUTS3" in shapefile.columns.tolist(): shapefile.drop(["EC_NUTS3"], axis=1) # sort shapefile s.t. NULL classes are at the end @@ -378,13 +408,16 @@ def _downloader( ti.update(n=1) ti.close() - del parcel_df - del shapefile - del request_df_list + if "parcel_df" in locals(): + del parcel_df + if "shapefile" in locals(): + del shapefile + if "request_df_list" in locals(): + del request_df_list with multiprocessing.Pool(processes=max_workers) as p: result_list = list(parcel_path.iterdir()) - process_iter = p.imap(_load_pkg, result_list) + process_iter = p.imap(_load_pkl, result_list) ti = tqdm(total=len(args_list), desc="Loading DataFrames.") for result in process_iter: results.append(result) # type: ignore[arg-type] @@ -413,13 +446,13 @@ def _downloader( combined_result = combined_result.drop_duplicates(subset=subset_cols, keep="first") # saving DataFrame that matches .SAFE files with parcels - combined_result.to_pickle(output_dir.joinpath("full_parcel_list.pkg")) + combined_result.to_pickle(output_dir.joinpath("full_parcel_list.pkl")) unique_safe_files = combined_result["productIdentifier"].unique() # DataFrame of unique .SAFE files safefiles_df = pd.DataFrame({"productIdentifier": unique_safe_files}) - safefiles_df.to_pickle(output_dir.joinpath("full_safe_file_list.pkg")) + safefiles_df.to_pickle(output_dir.joinpath("full_safe_file_list.pkl")) logger.info(f"Finished merging .SAFE files and parcels for {country} for {year}.") @@ -427,12 +460,12 @@ def _downloader( def _process_batch(args: tuple[int, int, gpd.GeoDataFrame, gpd.GeoDataFrame, Path]) -> None: """Checking for intersections between raster tiles and parcel polygons.""" i, batch_size, parcel_df, request_df, parcel_path = args - if not parcel_path.joinpath(f"parcel_list_{i}.pkg").exists(): + if not parcel_path.joinpath(f"parcel_list_{i}.pkl").exists(): batch_parcel_df = parcel_df[i : i + batch_size] result = gpd.sjoin(batch_parcel_df, request_df, how="left", predicate="intersects") result = result[result["index_right"].notna()] result = result.drop(["index_right", "crs", "geometry"], axis=1) - result.to_pickle(parcel_path.joinpath(f"parcel_list_{i}.pkg")) + result.to_pickle(parcel_path.joinpath(f"parcel_list_{i}.pkl")) def _get_tiles( @@ -441,6 +474,7 @@ def _get_tiles( tile: dict, ) -> list | None: """Getting information from raster .SAFE files.""" + safe_file: str = tile["S3Path"] # product Identifier request: list | None @@ -461,11 +495,43 @@ def _get_tiles( if eodata_dir is not None: safe_file = safe_file.replace("eodata", eodata_dir) + safe_file = safe_file.replace("codede", eodata_dir) + try: + granule_path = Path(safe_file).joinpath("GRANULE") + folder: list = list(granule_path.iterdir()) + + tree = ElementTree.parse(folder[0].joinpath("MTD_TL.xml")) + root = tree.getroot() + except Exception: + logger.warning( + "Could not access metadata via eodata directory. \ + This .SAFE file is being skipped." 
+ ) + return None + else: + s3_client: BaseClient | None = _establish_s3_client() + safe_file = safe_file.replace("/eodata/", "") + safe_file = safe_file.replace("/codede/", "") + + try: + granule_path = Path(safe_file).joinpath("GRANULE") + granule_sub_folder: str | None = cast( + list, + _get_s3_subfolders( + s3_client, str(granule_path) + "/", selectionkey="CommonPrefixes" + ), + )[0]["Prefix"] + if granule_sub_folder is not None: + root = _parse_s3_xml(s3_client, granule_sub_folder + "MTD_TL.xml") + else: + return None + except Exception: + logger.warning( + "Could not access metadata via S3 bucket. This .SAFE file is being skipped." + ) + return None + try: - granule_path = Path(safe_file).joinpath("GRANULE") - folder: list = list(granule_path.iterdir()) - tree = ElementTree.parse(folder[0].joinpath("MTD_TL.xml")) - root = tree.getroot() spatial_ref_element: ElementTree.Element = cast( ElementTree.Element, root.find(".//HORIZONTAL_CS_NAME") ) @@ -484,9 +550,10 @@ def _get_tiles( f"The geometry of {safe_file} could not be transformed into a" " shapely Polygon correctly. This .SAFE file is being skipped." ) - request = None + return None else: + # TODO: Sentinel-1 access via S3 bucket try: folder = list(Path(safe_file).iterdir()) @@ -523,6 +590,6 @@ def _get_tiles( f"The geometry of {safe_file} could not be transformed into a" " shapely Polygon correctly. This .SAFE file is being skipped." ) - request = None + return None return request diff --git a/eurocropsml/acquisition/config.py b/eurocropsml/acquisition/config.py index 00ee725..fd8636c 100644 --- a/eurocropsml/acquisition/config.py +++ b/eurocropsml/acquisition/config.py @@ -27,7 +27,23 @@ "12", ] # order is important -S1_BANDS = ["VV", "VH"] # order is important +S1_BANDS = ["VV", "VH"] # order is important{ + +S2_RESOLUTION = { + "01": 60, + "02": 10, + "03": 10, + "04": 10, + "05": 20, + "06": 20, + "07": 20, + "08": 10, + "8A": 20, + "09": 60, + "10": 60, + "11": 20, + "12": 20, +} class CollectorConfig(BaseModel): @@ -74,6 +90,7 @@ class CollectorConfig(BaseModel): shapefile: Path | None = None polygon: str | None = None parcel_id_name: str | None = None + nuts_identifier: str | None = None def post_init(self, vector_data_dir: Path) -> None: """Make dynamic config based on initialized params.""" @@ -161,6 +178,8 @@ def post_init(self, vector_data_dir: Path) -> None: ) self.country_code = cast(str, eurocrops_countries[self.country]["country_code"]) self.ec_filename = cast(str, eurocrops_countries[self.country]["ec_zipfolder"]) + if "nuts" in eurocrops_countries[self.country]: + self.nuts_identifier = cast(str, eurocrops_countries[self.country]["nuts"]) if self.country_code == "ES" and self.year == 2021: filename = f"{self.ec_filename}_2020" @@ -216,7 +235,11 @@ def post_init(self, vector_data_dir: Path) -> None: class AcquisitionConfig(BaseModel): - """Configuration for acquiring EuroCrops reflectance data.""" + """Configuration for acquiring EuroCrops reflectance data. + + If eodata_dir is None, Sentinel tiles will be accessed via S3 bucket. 
+ + """ raw_data_dir: Path output_dir: Path @@ -249,18 +272,42 @@ class EuroCropsCountryConfig(BaseModel): countries: dict[str, dict[str, str | list[str] | list[int]]] = { "Austria": {"country_code": "AT", "ec_zipfolder": "AT", "years": [2021]}, - "Belgium VLG": {"country_code": "BE", "ec_zipfolder": "BE_VLG", "years": [2021]}, - "Belgium WAL": {"country_code": "BE", "ec_zipfolder": "BE_WAL", "years": [2021]}, + "Belgium VLG": { + "country_code": "BE", + "ec_zipfolder": "BE_VLG", + "years": [2021], + "nuts": "BE2", + }, + "Belgium WAL": { + "country_code": "BE", + "ec_zipfolder": "BE_WAL", + "years": [2021], + "nuts": "BE3", + }, "Croatia": {"country_code": "HR", "ec_zipfolder": "HR", "years": [2020]}, "Czechia": {"country_code": "CZ", "ec_zipfolder": "CZ", "years": [2023]}, "Denmark": {"country_code": "DK", "ec_zipfolder": "DK", "years": [2019]}, "Estonia": {"country_code": "EE", "ec_zipfolder": "EE", "years": [2021]}, "Finland": {"country_code": "FI", "ec_zipfolder": "FI", "years": [2020]}, "France": {"country_code": "FR", "ec_zipfolder": "FR", "years": [2018]}, - "Germany LS": {"country_code": "DE", "ec_zipfolder": "DE_LS", "years": [2021]}, - "Germany NRW": {"country_code": "DE", "ec_zipfolder": "DE_NRW", "years": [2021]}, - "Germany BB": {"country_code": "DE", "ec_zipfolder": "DE_BB", "years": [2023]}, - "Ireland": {"country_code": "IE", "ec_zipfolder": "IE", "years": [2023]}, + "Germany LS": { + "country_code": "DE", + "ec_zipfolder": "DE_LS", + "years": [2021], + "nuts": "DE9", + }, + "Germany NRW": { + "country_code": "DE", + "ec_zipfolder": "DE_NRW", + "years": [2021], + "nuts": "DEA", + }, + "Germany BB": { + "country_code": "DE", + "ec_zipfolder": "DE_BB", + "years": [2023], + "nuts": "DE4", + }, "Latvia": {"country_code": "LV", "ec_zipfolder": "LV", "years": [2021]}, "Lithuania": {"country_code": "LT", "ec_zipfolder": "LT", "years": [2021]}, "Netherlands": {"country_code": "NL", "ec_zipfolder": "NL", "years": [2020]}, @@ -285,7 +332,7 @@ class EuroCropsCountryConfig(BaseModel): "France": "ID_PARCEL", "Germany LS": "", # no unique identifier "Germany NRW": "ID", - "Germany BB": "", + "Germany BB": "", # no unique identifier "Ireland": "", # no unique identifier "Latvia": "PARCEL_ID", "Lithuania": "parcel_id", diff --git a/eurocropsml/acquisition/copier.py b/eurocropsml/acquisition/copier.py index 47d9d33..0439b99 100644 --- a/eurocropsml/acquisition/copier.py +++ b/eurocropsml/acquisition/copier.py @@ -9,32 +9,66 @@ from typing import Literal, cast import pandas as pd +from botocore.client import BaseClient from tqdm import tqdm +from eurocropsml.acquisition.config import S2_RESOLUTION +from eurocropsml.acquisition.s3 import ( + _download_s3_prefix, + _establish_s3_client, + _get_s3_subfolders, +) + logger = logging.getLogger(__name__) -def _copy_to_local_dir(local_dir: Path, safe_file: str) -> None: +def _copy_to_local_dir( + source: Literal["eodata", "s3"], local_dir: Path, safe_file: pd.Series +) -> None: """Copying files to local directory. Args: - local_dir: Directory to copy the file to. + source: Source of the Sentinel tiles. Either directory ('eodata') or S3 bucket ('s3'). + local_dir: Local directory where the .SAFE files are copied to. safe_file: File to copy to local directory. 
""" # Copy all image files from network storage over to local storage - filename = safe_file[1] - local_product: Path = local_dir.joinpath(filename[1:]) - local_parent_dir: Path = local_product.parents[0] - if not local_parent_dir.exists(): - local_parent_dir.mkdir(exist_ok=True, parents=True) + safe_file_name = safe_file[1] + + if source == "eodata": + # TODO: check for correctne + local_product: Path = local_dir.joinpath(safe_file_name.lstrip("/")) + granule_folder = Path(safe_file_name) / "GRANULE" + granule_sub_folder = list(granule_folder.iterdir())[0] + img_data_path = granule_sub_folder / "IMG_DATA" + local_parent_dir: Path = local_product / "GRANULE" / granule_sub_folder.name / "IMG_DATA" + if not local_parent_dir.exists(): + local_parent_dir.mkdir(exist_ok=True, parents=True) + + for jp2_file in img_data_path.glob(".jp2"): + if not jp2_file.exists(): + shutil.copy2(safe_file_name, jp2_file) + else: + s3_client: BaseClient = _establish_s3_client() + granule_prefix = f"{safe_file_name}/GRANULE/" + granule_subfolders: list = cast( + list, _get_s3_subfolders(s3_client, granule_prefix, selectionkey="CommonPrefixes") + )[0]["Prefix"] + img_data_folder: str = f"{granule_subfolders}IMG_DATA/" + local_product = local_dir.joinpath(img_data_folder) - if not local_product.exists(): - shutil.copytree(filename, local_product) + if not local_product.exists(): + local_product.mkdir(parents=True, exist_ok=True) + _download_s3_prefix(s3_client, img_data_folder, local_product, file_extension=".jp2") def _get_image_files( - full_safe_files: pd.DataFrame, satellite: Literal["S1", "S2"], bands: list[str] + full_safe_files: pd.DataFrame, + satellite: Literal["S1", "S2"], + bands: list[str], + source: Literal["eodata", "s3"] = "s3", + local_dir: str = "", ) -> pd.DataFrame: """Getting paths for each spectral band. @@ -42,36 +76,40 @@ def _get_image_files( full_safe_files: DataFrame with .SAFE file paths for which to get the band paths. satellite: S1 for Sentinel-1 and S2 for Sentinel-2. bands: (Sub-)set of Sentinel-1 (radar) or Sentinel-2 (spectral) bands. + source: Source of the Sentinel tiles. Either directory ('eodata') or S3 bucket ('s3') + local_dir: Local directory where the .SAFE files are copied to. + If None, .SAFE files will not be stored on local disk. Returns: DataFrame with band paths as columns. 
""" - image_update = pd.DataFrame(columns=full_safe_files.columns.values) - for idx, row in tqdm( - full_safe_files.iterrows(), - total=len(full_safe_files), - desc="Collecting paths for spectral bands.", - ): - if Path(row["productIdentifier"]).exists(): - filename_list: list - files: list[str] - - if satellite == "S2": - filename_list = os.listdir(os.path.join(row["productIdentifier"], "GRANULE")) - filename: str = filename_list[0] - sub_path: str = os.path.join("GRANULE", filename, "IMG_DATA") - path_files: str = os.path.join(row["productIdentifier"], sub_path) - resolutions: list[int] = [10, 20, 60] - - files = os.listdir(path_files) - - for band in bands: - if "R10m" in files: - for res in resolutions: + image_update = pd.DataFrame() + + if source == "eodata": + for _, row in tqdm( + full_safe_files.iterrows(), + total=len(full_safe_files), + desc="Collecting paths for spectral bands.", + ): + if Path(row["productIdentifier"]).exists(): + filename_list: list + files: list[str] + + if satellite == "S2": + filename_list = os.listdir(os.path.join(row["productIdentifier"], "GRANULE")) + filename: str = filename_list[0] + sub_path: str = os.path.join("GRANULE", filename, "IMG_DATA") + path_file: str = os.path.join(row["productIdentifier"], sub_path) + + files = os.listdir(path_file) + + for band in bands: + if "R10m" in files: + res = S2_RESOLUTION[band] r1_files: list[str] = os.listdir( - os.path.join(path_files, "R{0}m".format(res)) + os.path.join(path_file, "R{0}m".format(res)) ) image_found: list[str] = [ file for file in r1_files if f"_B{band}" in file @@ -81,33 +119,112 @@ def _get_image_files( sub_path, "R{0}m".format(res), image_found[0] ) image_found = [] - break - else: - image_found = [file for file in files if f"_B{band}" in file] - row["bandImage_{0}".format(band)] = os.path.join(path_files, image_found[0]) + else: + image_found = [file for file in files if f"_B{band}" in file] + row["bandImage_{0}".format(band)] = os.path.join( + path_file, image_found[0] + ) + image_found = [] + + else: + path_file = os.path.join(row["productIdentifier"], "measurement") + + files = os.listdir(path_file) + + for i in range(len(bands)): + image_found = [file for file in files if f"{bands[i].lower()}" in file] + if image_found: + row["bandImage_{0}".format(bands[i])] = os.path.join( + path_file, image_found[0] + ) image_found = [] + row_df: pd.DataFrame = row.to_frame().T + if image_update.empty: + image_update = row_df + else: + image_update = pd.concat([image_update, row_df], ignore_index=True) + + else: + s3_client: BaseClient = _establish_s3_client() + for _, row in tqdm( + full_safe_files.iterrows(), + total=len(full_safe_files), + desc="Collecting paths for spectral bands from S3 bucket.", + ): + + prefix = row["productIdentifier"] + prefix = prefix.replace(local_dir, "") + granule_prefix = f"{prefix.lstrip('/')}/GRANULE/" + + granule_folder = _get_s3_subfolders( + s3_client, granule_prefix, selectionkey="CommonPrefixes" + ) + if not granule_folder: + logger.info(f"Access to {granule_prefix} failed. Skipping .SAFE file.") + continue + + sub_folder = cast(list[dict], granule_folder)[0]["Prefix"] + img_data_prefix = f"{sub_folder}IMG_DATA/" + + img_data_prefixes = _get_s3_subfolders( + s3_client, img_data_prefix, selectionkey="Contents" + ) + if not img_data_prefixes: + logger.info(f"Access to {img_data_prefix} failed. 
Skipping .SAFE file.") + continue else: - path_files = os.path.join(row["productIdentifier"], "measurement") + img_data_prefixes = cast(list[dict], img_data_prefixes) + # check if 'R10m/' is present in the list of sub-prefixes + is_structured = any("R10m/" in p["Key"] for p in img_data_prefixes) + + for band in bands: + res = S2_RESOLUTION[band] + if is_structured: + # structured format (R10m, R20m, R60m) + res_dir = f"R{res}m/" + # search path includes the resolution folder + search_prefix = f"{img_data_prefix}{res_dir}" + + # list the files inside the resolution folder + band_files = _get_s3_subfolders( + s3_client, search_prefix, selectionkey="Contents" + ) + if not band_files: + logger.info(f"Access to {search_prefix} failed. Skipping .SAFE file.") + continue + else: + band_files = cast(list[dict], band_files) + # filter for specific band file (T..._B04.jp2) + image_found = [ + file["Key"] for file in band_files if f"_B{band}.jp2" in file["Key"] + ] + if image_found: + row["bandImage_{0}".format(band)] = image_found[0] + image_found = [] - files = os.listdir(path_files) + else: + # FLAT FORMAT (Older products, or error fallback) + search_prefix = img_data_prefix - for i in range(len(bands)): - image_found = [file for file in files if f"{bands[i].lower()}" in file] - if image_found: - row["bandImage_{0}".format(bands[i])] = os.path.join( - path_files, image_found[0] - ) - image_found = [] + # List files in IMG_DATA + band_files = cast( + list, _get_s3_subfolders(s3_client, search_prefix, selectionkey="Contents") + ) - if idx == 0: - image_update = row.to_frame() - image_update = image_update.T + image_found = [ + file["Key"] for file in band_files if f"_B{band}.jp2" in file["Key"] + ] + if image_found: + row["bandImage_{0}".format(band)] = image_found[0] + image_found = [] + + row_df = row.to_frame().T + if image_update.empty: + image_update = row_df else: - row_df: pd.DataFrame = row.to_frame() - row_df = row_df.T image_update = pd.concat([image_update, row_df], ignore_index=True) return image_update @@ -118,6 +235,7 @@ def merge_safe_files( bands: list[str], output_dir: Path, workers: int, + source: Literal["eodata", "s3"] = "s3", local_dir: Path | None = None, ) -> None: """Copy all relevant .SAFE files to local directory and acquire spectral band paths. @@ -128,25 +246,29 @@ def merge_safe_files( output_dir: Directory where lists of required .SAFE files (per parcel id) are stored and where to save the output files to. workers: Maximum number of workers to use for multiprocessing. + source: Source of the Sentinel tiles. Either directory ('eodata') or S3 bucket ('s3') local_dir: Local directory where the .SAFE files are copied to. If None, .SAFE files will not be stored on local disk. """ - safe_df = pd.read_pickle(output_dir.joinpath("collector", "full_safe_file_list.pkg")) + safe_df = pd.read_pickle(output_dir.joinpath("collector", "full_safe_file_list.pkl")) + if local_dir is not None: + local_dir = cast(Path, local_dir) # list of unique .SAFE files identifiers full_safe_files: pd.Series = safe_df["productIdentifier"] if local_dir is not None: # Copying the .SAFE files to a local directory massively fastens up the process of opening # them later on. Furthermore, opening them directly on the external directory sometimes led - # to the directory disconnecting from the VM. - local_dir = cast(Path, local_dir) + # to the directory disconnecting from the VM. The same happend for the S3 connection. 
logger.info("Copying files to local storage.") max_workers = min(mp_orig.cpu_count(), max(1, min(len(full_safe_files), workers))) + # for row in full_safe_files.items(): + # _copy_to_local_dir(source, local_dir, row) with mp_orig.Pool(processes=max_workers) as p: - func = partial(_copy_to_local_dir, local_dir) + func = partial(_copy_to_local_dir, source, local_dir) process_iter = p.imap(func, full_safe_files.items()) ti = tqdm(total=len(full_safe_files), desc="Copying .SAFE files to local disk.") _ = [ti.update(n=1) for _ in process_iter] @@ -155,14 +277,15 @@ def merge_safe_files( logger.info(f"Finished copying all files to local directory {local_dir}.") local_safe_files: list[str] = [ - str(local_dir.joinpath(file[1:])) for file in full_safe_files + str(local_dir.joinpath(file.lstrip("/"))) for file in full_safe_files ] safe_files_df: pd.DataFrame = pd.DataFrame(local_safe_files, columns=["productIdentifier"]) + source = "eodata" else: safe_files_df = pd.DataFrame(full_safe_files.tolist(), columns=["productIdentifier"]) copier_path: Path = output_dir.joinpath("copier") - band_path: Path = copier_path.joinpath("band_images.pkg") + band_path: Path = copier_path.joinpath("band_images.pkl") # Collecting all .jp2-paths for each .SAFE file. if band_path.is_file(): @@ -172,11 +295,22 @@ def merge_safe_files( band_images_exist["productIdentifier"] ) safe_files_df.drop(safe_files_df[remove_rows].index, inplace=True) - new_band_images: pd.DataFrame = _get_image_files(safe_files_df, satellite, bands) - band_images: pd.DataFrame = pd.concat([band_images_exist, new_band_images]) + if not safe_files_df.empty: + new_band_images: pd.DataFrame = _get_image_files( + safe_files_df, + satellite, + bands, + source, + str(local_dir) if local_dir is not None else "", + ) + band_images: pd.DataFrame = pd.concat([band_images_exist, new_band_images]) + band_images.to_pickle(copier_path.joinpath("band_images.pkl")) + logger.info(f"Saved band images to {copier_path.joinpath('band_images.pkl')}.") else: copier_path.mkdir(exist_ok=True, parents=True) - band_images = _get_image_files(safe_files_df, satellite, bands) + band_images = _get_image_files( + safe_files_df, satellite, bands, source, str(local_dir) if local_dir is not None else "" + ) - band_images.to_pickle(copier_path.joinpath("band_images.pkg")) - logger.info(f"Saved band images to {copier_path.joinpath('band_images.pkg')}.") + band_images.to_pickle(copier_path.joinpath("band_images.pkl")) + logger.info(f"Saved band images to {copier_path.joinpath('band_images.pkl')}.") diff --git a/eurocropsml/acquisition/region.py b/eurocropsml/acquisition/region.py index fef1962..6a11a33 100644 --- a/eurocropsml/acquisition/region.py +++ b/eurocropsml/acquisition/region.py @@ -92,6 +92,9 @@ def add_nuts_regions( nuts_df = nuts[nuts["CNTR_CODE"] == config.country_code] + if config.nuts_identifier is not None: + nuts_df = nuts[nuts["NUTS_ID"].str.startswith(config.nuts_identifier)] + parcel_id_name: str = cast(str, config.parcel_id_name) cols_shapefile = [parcel_id_name, "geometry", "EC_hcat_n", "EC_hcat_c", "nuts1"] @@ -190,7 +193,7 @@ def add_nuts_regions( if ( x is None or (isinstance(x, float) and pd.isna(x)) - or (isinstance(x, list) and all(pd.isna(val) for val in x)) + or (isinstance(x, list) and any(pd.isna(val) for val in x)) ) else x ) diff --git a/eurocropsml/acquisition/s3.py b/eurocropsml/acquisition/s3.py new file mode 100644 index 0000000..ce50f5b --- /dev/null +++ b/eurocropsml/acquisition/s3.py @@ -0,0 +1,134 @@ +"""Acquiring data from S3 bucket.""" + 
+import io +import logging +import os +import xml.etree.ElementTree as ElementTree +from pathlib import Path +from typing import Literal, cast + +import boto3 +from botocore.client import BaseClient + +from eurocropsml.settings import CONTAINER, ENDPOINT_URL, REGION_NAME, Settings + +logger = logging.getLogger(__name__) + + +def _set_s3_env_variables() -> None: + credentials = _get_s3_credentials() + + os.environ["AWS_ACCESS_KEY_ID"] = credentials[0] + os.environ["AWS_SECRET_ACCESS_KEY"] = credentials[1] + os.environ["AWS_S3_ENDPOINT"] = ENDPOINT_URL.replace("https://", "").replace("http://", "") + os.environ["AWS_VIRTUAL_HOSTING"] = "false" + os.environ["AWS_REGION"] = REGION_NAME + os.environ["S3_CONTAINER_NAME"] = CONTAINER + os.environ["CPL_VSIS3_READ_TIMEOUT"] = "60" + + +def _get_s3_credentials(file_name: str = "eodata-access") -> tuple[str, str]: + + cfg_dir: Path = Settings().cfg_dir + file_path: Path = cfg_dir / file_name + if not file_path.exists(): + raise FileNotFoundError( + f"{file_path} was not found. \ + Please first create a file with your EC2 credentials with 'ACCESS_KEY:SECURITY_KEY'." + ) + + with open(file_path, "r") as f: + credentials = [line.split(":") for line in f][0] + + return credentials[0], credentials[1].strip("\n") + + +def _establish_s3_client() -> BaseClient: + access_key = os.environ.get("AWS_ACCESS_KEY_ID") + secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") + region = os.environ.get("AWS_REGION") + + s3_client = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=region, + endpoint_url=ENDPOINT_URL, + ) + + return s3_client + + +def _get_s3_subfolders( + s3_client: BaseClient, prefix: str, selectionkey: str | None = None +) -> str | list[dict] | None: + + try: + response: str | dict = s3_client.list_objects_v2( + Bucket=os.environ.get("S3_CONTAINER_NAME"), Prefix=prefix, Delimiter="/" + ) + if selectionkey is not None and isinstance(response, dict): + return cast(list, response[selectionkey]) + return cast(str | list, response) + except ValueError: + return None + + +def _parse_s3_xml(s3_client: BaseClient, key: str) -> ElementTree.Element: + """Retrieves an XML file from S3 and parses it with ElementTree.""" + + # get object from S3 + s3_response = s3_client.get_object(Bucket=os.environ.get("S3_CONTAINER_NAME"), Key=key) + + # read the body content (bytes) + xml_bytes = s3_response["Body"].read() + + # create a BytesIO buffer (in-memory file) from the bytes + xml_stream = io.BytesIO(xml_bytes) + + # ElementTree.parse() to read from the in-memory file stream + tree = ElementTree.parse(xml_stream) + + return tree.getroot() + + +def _download_s3_prefix( + s3_client: BaseClient, + s3_prefix: str, + local_base_dir: Path, + file_extension: Literal[".jp2"] = ".jp2", +) -> None: + """Recursively downloads all files under a given S3 prefix. + + Files are downloaded to a local directory, maintaining the relative path structure. + """ + + # Iterate over all pages of results + response = _get_s3_subfolders(s3_client, s3_prefix, selectionkey="Contents") + + if isinstance(response, list): + # Iterate over every object (file) found under the prefix + for obj in response: + s3_key = obj["Key"] + if s3_key.endswith(file_extension): + # Construct the local file path by appending the relative key + # to the local base directory. + # Example: If local_base_dir is /cache/S2B... 
and s3_prefix is S2B.../IMG_DATA/ + # and s3_key is S2B.../IMG_DATA/B01.jp2, local_file will be + # /cache/S2B.../IMG_DATA/B01.jp2 + file_name = Path(s3_key).relative_to(Path(s3_prefix)) + local_file = local_base_dir.joinpath(file_name) + + # Ensure the parent directories exist for the local file + local_file.parent.mkdir(parents=True, exist_ok=True) + + if not local_file.exists(): + + # Download the individual file + s3_client.download_file( + Bucket=os.environ.get("S3_CONTAINER_NAME"), + Key=s3_key, + Filename=str(local_file), + ) + else: + logger.info(f"Did not found any files. Skipping copying of {s3_prefix}.") diff --git a/eurocropsml/acquisition/utils.py b/eurocropsml/acquisition/utils.py index d49742f..b8f789f 100644 --- a/eurocropsml/acquisition/utils.py +++ b/eurocropsml/acquisition/utils.py @@ -237,5 +237,5 @@ def _get_dict_value_by_name( return None -def _load_pkg(file: Path) -> pd.DataFrame: +def _load_pkl(file: Path) -> pd.DataFrame: return pd.read_pickle(file) diff --git a/eurocropsml/dataset/config.py b/eurocropsml/dataset/config.py index fe25e0d..f9c4164 100644 --- a/eurocropsml/dataset/config.py +++ b/eurocropsml/dataset/config.py @@ -30,6 +30,8 @@ class EuroCropsDatasetPreprocessConfig(BaseModel): a couple of classes are relevant. In that case, it massively speeds up the pre- processing. satellite: Preprocess Sentinel-1 or Sentinel-2. + country_list: List of country identifiers to preprocess. + If empty, all countries will be attempted to be preprocessed. bands: If this is None, the default bands stated in the global variables will be used. These are also the ones available in the ready-to-use EuroCropsML dataset. If during your own data acquisition not all bands or different bands were acquired, @@ -49,6 +51,7 @@ class EuroCropsDatasetPreprocessConfig(BaseModel): excl_classes: list[int] = [] keep_classes: list[int] = [] satellite: Literal["S1", "S2"] = "S2" + country_list: list[str] | None = None bands: list[str] | None = None year: int = 2021 diff --git a/eurocropsml/dataset/preprocess.py b/eurocropsml/dataset/preprocess.py index abdd49d..3e14875 100644 --- a/eurocropsml/dataset/preprocess.py +++ b/eurocropsml/dataset/preprocess.py @@ -5,7 +5,7 @@ import shutil import sys from collections import defaultdict -from concurrent.futures import ProcessPoolExecutor +from concurrent.futures import ProcessPoolExecutor, as_completed from functools import cache, partial from multiprocessing import Pool from pathlib import Path @@ -104,7 +104,7 @@ def _get_lonlats(metadata_dir: Path, country: str) -> dict[int, np.ndarray]: @cache def get_class_ids_to_names(raw_data_dir: Path) -> dict[str, str]: """Get a dictionary mapping between class identifiers and readable names.""" - labels_df: pd.DataFrame = read_metadata(raw_data_dir) + labels_df: pd.DataFrame = read_metadata(raw_data_dir.joinpath("labels")) unique_labels_df = labels_df.drop_duplicates() ids_to_names_dict = unique_labels_df.set_index("EC_hcat_c").to_dict()["EC_hcat_n"] return {str(k): v for k, v in ids_to_names_dict.items()} @@ -266,85 +266,113 @@ def preprocess( month_preprocess_dir.mkdir(exist_ok=True, parents=True) for file_path in month_data_dir.glob("*.parquet"): - country_file: pd.DataFrame = pd.read_parquet(file_path).set_index("parcel_id") - cols = country_file.columns.tolist() - cols = cols[5:] - # filter nan-values - country_file = country_file[~country_file[f"nuts{nuts_level}"].isna()] - points = _get_lonlats( - raw_data_dir.joinpath("geometries", str(preprocess_config.year)), file_path.stem - ) - labels = 
_get_labels( - raw_data_dir.joinpath("labels", str(preprocess_config.year)), - file_path.stem, - preprocess_config, - ) - - regions = country_file[f"nuts{nuts_level}"].unique() - te = tqdm( - total=len(regions), - desc=f"Processing {file_path.stem}", - ) - for region in regions: - if any( - f.name.startswith(region) - for f in month_preprocess_dir.iterdir() - if f.is_file() - ): - logger.info( - f"There is already existing data for NUTS region {region} for " - f"{month_name}. Skipping pre-processing." - ) - continue - region_data = country_file[country_file[f"nuts{nuts_level}"] == region] - - # remove parcels that do not appear in the labels dictionary as keys - region_data = region_data[region_data.index.isin(labels.keys())] - region_data = region_data[cols] - # removing empty columns - region_data = region_data.dropna(axis=1, how="all") - # removing empty parcels - region_data = region_data.dropna(how="all") - # replacing single empty timesteps - - region_data = region_data.apply( - lambda x, b=len(bands): x.map( - lambda y: np.array([-999] * b) if y is None else y - ) + if ( + preprocess_config.country_list + and file_path.stem not in preprocess_config.country_list + ): + logger.info(f"Skipping {file_path.stem}. Not in country list.") + continue + else: + country_file: pd.DataFrame = pd.read_parquet(file_path) + + cols = country_file.columns.tolist() + cols = cols[5:] + if f"nuts{nuts_level}" in cols: + cols.remove(f"nuts{nuts_level}") + # filter nan-values + country_file = country_file[~country_file[f"nuts{nuts_level}"].isna()] + if "parcel_id" in country_file.columns: + country_file.set_index("parcel_id", inplace=True) + + points = _get_lonlats( + raw_data_dir.joinpath("geometries", str(preprocess_config.year)), + file_path.stem, + ) + labels = _get_labels( + raw_data_dir.joinpath("labels", str(preprocess_config.year)), + file_path.stem, + preprocess_config, ) - with Pool(processes=num_workers) as p: - func = partial( - _save_row, - preprocess_config, - month_preprocess_dir, - labels, - points, - region, - len(bands), + + regions = country_file[f"nuts{nuts_level}"].unique() + te = tqdm( + total=len(regions), + desc=f"Processing {file_path.stem}", + ) + for region in regions: + if any( + f.name.startswith(f"{region}_") + for f in month_preprocess_dir.iterdir() + if f.is_file() + ): + logger.info( + f"There is already existing data for NUTS region {region} for " + f"{month_name}. Skipping pre-processing." 
+ ) + continue + + region_data = country_file[country_file[f"nuts{nuts_level}"] == region] + # remove parcels that do not appear in the labels dictionary as keys + region_data = region_data[region_data.index.isin(labels.keys())] + region_data = region_data[cols] + # removing empty columns + region_data = region_data.dropna(axis=1, how="all") + # removing empty parcels + region_data = region_data.dropna(how="all") + # replacing single empty timesteps + region_data = region_data.apply( + lambda x, b=len(bands): x.map( + lambda y: np.array([-999] * b) if y is None else y + ) ) - process_iter = p.imap(func, region_data.iterrows(), chunksize=1000) - ti = tqdm(total=len(region_data), desc=f"Processing {region}") - _ = [ti.update(n=1) for _ in process_iter] - ti.close() - te.update(n=1) + for row in region_data.iterrows(): + _save_row( + preprocess_config, + month_preprocess_dir, + labels, + points, + region, + len(bands), + row, + ) + with Pool(processes=num_workers) as p: + func = partial( + _save_row, + preprocess_config, + month_preprocess_dir, + labels, + points, + region, + len(bands), + ) + process_iter = p.imap(func, region_data.iterrows(), chunksize=1000) + ti = tqdm(total=len(region_data), desc=f"Processing {region}") + _ = [ti.update(n=1) for _ in process_iter] + ti.close() + + te.update(n=1) te.close() monthly_groups = defaultdict(list) - for folder in tqdm(preprocess_dir.iterdir(), desc="Merging time series..."): + # No tqdm needed for fast file listing + for folder in preprocess_dir.iterdir(): if folder.is_dir(): for npz_file in folder.glob("*.npz"): - monthly_groups[npz_file.name].append(npz_file) + monthly_groups[npz_file.name].append(str(npz_file)) - te = tqdm(total=len(monthly_groups), desc="Merging time series...") + # Use as_completed for simpler, robust progress bar updates with ProcessPoolExecutor(max_workers=num_workers) as executor: - futures = [ + futures = { executor.submit(_merge_npz_files, file_name, file_paths, preprocess_dir) for file_name, file_paths in monthly_groups.items() - ] + } - for _ in futures: - te.update(n=1) + # as_completed yields futures as they complete, simplifying the update loop + for _ in tqdm(as_completed(futures), total=len(futures), desc="Merging time series..."): + # You can access the result or check for exceptions here if needed + # result = future.result() + pass for folder in preprocess_dir.iterdir(): if folder.is_dir(): diff --git a/eurocropsml/settings.py b/eurocropsml/settings.py index 151582d..b85df7e 100644 --- a/eurocropsml/settings.py +++ b/eurocropsml/settings.py @@ -7,6 +7,10 @@ ROOT_DIR = Path(__file__).parents[1] +REGION_NAME = "eu-central-1" # Default fallback region +ENDPOINT_URL = "https://eodata.cloudferro.com" +CONTAINER = "EODATA" + class Settings(BaseSettings): """Global settings.""" diff --git a/requirements/requirements-dev.in b/requirements/requirements-dev.in index 86c431e..6cac1d2 100644 --- a/requirements/requirements-dev.in +++ b/requirements/requirements-dev.in @@ -1,4 +1,4 @@ # Updating the torch version might require changes to the Cuda Version in the Makefile as well -torch==2.2.0 +torch==2.3.1 tox tox-ignore-env-name-mismatch \ No newline at end of file diff --git a/requirements/requirements.in b/requirements/requirements.in index dd516ab..ad0940b 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -9,10 +9,11 @@ pyogrio rasterio requests scikit-learn -torch>=2.0 +torch>=2.3 typer[all] types-requests bs4 selenium webdriver_manager xmlschema +boto3
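
For context on the new S3 access path introduced above: the patch configures GDAL's `/vsis3/` virtual filesystem via `_set_s3_env_variables()` and then clips parcels directly from the remote JP2 tiles in `mask_polygon_raster_s3`, while `_establish_s3_client`/`_get_s3_subfolders` use `boto3` to locate the band files inside each `.SAFE` product. Below is a minimal sketch of the clipping side, assuming placeholder credentials, a hypothetical band key, and a hypothetical `parcels.shp` with a `parcel_id` column (the real endpoint, region, and container come from `eurocropsml.settings`, and the credentials from the `eodata-access` file in `Settings().cfg_dir`):

```python
import os

import geopandas as gpd
import numpy as np
import rasterio
import rasterio.mask

# GDAL /vsis3/ configuration, mirroring _set_s3_env_variables().
os.environ["AWS_ACCESS_KEY_ID"] = "<ACCESS_KEY>"         # placeholder
os.environ["AWS_SECRET_ACCESS_KEY"] = "<SECRET_KEY>"     # placeholder
os.environ["AWS_S3_ENDPOINT"] = "eodata.cloudferro.com"  # endpoint without scheme
os.environ["AWS_VIRTUAL_HOSTING"] = "false"              # path-style addressing
os.environ["AWS_REGION"] = "eu-central-1"
os.environ["CPL_VSIS3_READ_TIMEOUT"] = "60"

# Hypothetical object key of one Sentinel-2 band image inside the EODATA container.
band_key = "Sentinel-2/MSI/L2A/.../GRANULE/.../IMG_DATA/R10m/T32UNE_..._B04_10m.jp2"

parcels = gpd.read_file("parcels.shp")  # placeholder parcel polygons

with rasterio.open(f"/vsis3/EODATA/{band_key}") as tile:
    parcels = parcels.to_crs(tile.crs)  # clip in the raster's CRS
    medians = {}
    for parcel_id, geom in zip(parcels["parcel_id"], parcels.geometry):
        clipped, _ = rasterio.mask.mask(tile, [geom], crop=True, nodata=0)
        values = clipped[clipped > 0]
        # None if the parcel covers no valid pixels; the patch later propagates a
        # single None across all bands of that parcel (see #104).
        medians[parcel_id] = float(np.median(values)) if values.size else None
```

The listing side used by the collector and copier can be sketched the same way, again with hypothetical prefixes: `list_objects_v2` with a `/` delimiter yields the single `GRANULE/` sub-folder under `CommonPrefixes`, and a plain prefix listing yields the band files under `Contents`.

```python
import boto3

# Placeholder credentials; the patch reads them from Settings().cfg_dir / "eodata-access".
s3 = boto3.client(
    "s3",
    aws_access_key_id="<ACCESS_KEY>",
    aws_secret_access_key="<SECRET_KEY>",
    region_name="eu-central-1",
    endpoint_url="https://eodata.cloudferro.com",
)

safe_prefix = "Sentinel-2/MSI/L2A/2021/07/06/S2B_MSIL2A_..._T32UNE_....SAFE"  # hypothetical

# The single GRANULE/ sub-folder of the product.
granule = s3.list_objects_v2(
    Bucket="EODATA", Prefix=f"{safe_prefix}/GRANULE/", Delimiter="/"
)["CommonPrefixes"][0]["Prefix"]

# All .jp2 band images below IMG_DATA/ (flat layout or R10m/R20m/R60m sub-folders).
objects = s3.list_objects_v2(Bucket="EODATA", Prefix=f"{granule}IMG_DATA/")
jp2_keys = [o["Key"] for o in objects.get("Contents", []) if o["Key"].endswith(".jp2")]
```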