From b71216b71f686c5424070bfdc1c1dc5308425e87 Mon Sep 17 00:00:00 2001
From: vineela03
Date: Wed, 14 Dec 2022 06:17:19 +0100
Subject: [PATCH] CDP-1507 : Modified spark configs to support delta lake,
 added delta read write utils package

---
 Makefile                              |  2 +-
 requirements.txt                      |  1 +
 src/pyspark_core_utils/apps.py        |  3 +
 src/pyspark_core_utils/delta_utils.py | 86 +++++++++++++++++++++++++++
 4 files changed, 91 insertions(+), 1 deletion(-)
 create mode 100644 src/pyspark_core_utils/delta_utils.py

diff --git a/Makefile b/Makefile
index 1f35a2c..b31382d 100644
--- a/Makefile
+++ b/Makefile
@@ -9,7 +9,7 @@ setup-environment:
 	$(PYTHON) -m virtualenv env
 	. env/bin/activate; \
 	pip3 install -r requirements.txt; \
-	pip3 install pyspark==3.1.2
+	pip3 install pyspark==3.2.0
 
 test: setup-environment
 	. env/bin/activate; \
diff --git a/requirements.txt b/requirements.txt
index 0a8e336..b9df301 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ testfixtures==6.18.3
 PyYAML==6.0
 importlib-resources==5.4.0
 dotmap==1.3.25
+delta-spark==2.0.0
diff --git a/src/pyspark_core_utils/apps.py b/src/pyspark_core_utils/apps.py
index 7ac4d75..d965573 100644
--- a/src/pyspark_core_utils/apps.py
+++ b/src/pyspark_core_utils/apps.py
@@ -36,6 +36,9 @@ def _init_spark(self):
 
         return SparkSession \
             .builder \
+            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
+            .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
+            .config("spark.sql.warehouse.dir", "s3://is24-data-hive-warehouse/") \
             .config(conf=spark_conf) \
             .enableHiveSupport() \
             .getOrCreate()
diff --git a/src/pyspark_core_utils/delta_utils.py b/src/pyspark_core_utils/delta_utils.py
new file mode 100644
index 0000000..9122eff
--- /dev/null
+++ b/src/pyspark_core_utils/delta_utils.py
@@ -0,0 +1,86 @@
+from delta.tables import DeltaTable
+import re
+
+
+def write_partitioned_data_delta(self, dataframe, partition_name, partition_dates_to_override, write_mode,
+                                 target_base_path):
+    return dataframe \
+        .write.partitionBy(partition_name) \
+        .format("delta") \
+        .option("mergeSchema", "true") \
+        .option("__partition_columns", partition_name) \
+        .option("replaceWhere", "{} in ({})".format(partition_name, ', '.join(
+            map(lambda x: "'{}'".format(x), partition_dates_to_override)))) \
+        .mode(write_mode) \
+        .save(target_base_path)
+
+
+def write_nonpartitioned_data_delta(self, dataframe, write_mode, target_base_path):
+    return dataframe \
+        .write.format("delta") \
+        .option("mergeSchema", "true") \
+        .mode(write_mode) \
+        .save(target_base_path)
+
+
+def compact_delta_table_partitions(self, sparkSession, base_path, partition_name, dates, num_files):
+    return sparkSession.read \
+        .format("delta") \
+        .load(base_path) \
+        .where("{} in ({})".format(partition_name, ', '.join(map(lambda x: "'{}'".format(x), dates)))) \
+        .repartition(num_files) \
+        .write \
+        .option("dataChange", "false") \
+        .format("delta") \
+        .mode("overwrite") \
+        .option("replaceWhere", "{} in ({})".format(partition_name, ', '.join(map(lambda x: "'{}'".format(x), dates)))) \
+        .save(base_path)
+
+
+def generate_delta_table(self, sparkSession, schema_name, table_name, s3location):
+    sparkSession.sql("create database if not exists {}".format(schema_name))
+    qualified_table_name = f"""{schema_name}.{table_name}"""
+    DeltaTable.createIfNotExists(sparkSession) \
+        .tableName(qualified_table_name) \
+        .location(s3location) \
+        .execute()
+    print(f"Delta table {qualified_table_name} generated")
+
+
+def extract_delta_info_from_path(self, paths):
+    path = paths[0]
+    path_reg_exp = """(.*)/(.*)=(.*)"""
+    match_pattern_to_path = re.match(path_reg_exp, path)
+    if match_pattern_to_path is None:
+        raise Exception(
+            "Cannot read {}: base path cannot be extracted".format(", ".join(paths)))
+
+    base_path = match_pattern_to_path.group(1)
+    partition_name = match_pattern_to_path.group(2)
+    dates = [re.match(path_reg_exp, p).group(3) for p in paths]
+    print(base_path)
+    print(partition_name)
+    print(dates)
+    return (base_path, partition_name, dates)
+
+
+def read_delta_from_s3(self, sparkSession, paths):
+    (base_path, partition_name, dates) = extract_delta_info_from_path(self, paths)
+    df = sparkSession.read \
+        .format("delta") \
+        .load(base_path) \
+        .where("{} in ({})".format(partition_name, ', '.join(map(lambda x: "'{}'".format(x), dates))))
+    print(df.count())
+    return df
+
+
+def delta_read_from_basepath(self, sparkSession, base_path):
+    return sparkSession.read \
+        .format("delta") \
+        .load(base_path)
+
+
+def read_delta_table(self, sparkSession, schema_name, table_name, partition_name, partition_dates):
+    qualified_table_name = f"""{schema_name}.{table_name}"""
+    return sparkSession.read.table(qualified_table_name) \
+        .where("{} in ({})".format(partition_name, ', '.join(map(lambda x: "'{}'".format(x), partition_dates))))
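
A minimal usage sketch of the new helpers, assuming a SparkSession built with the Delta
configs added in apps.py. The S3 paths, column names, and partition values below are
placeholders, and None is passed for the unused self parameter of the module-level
functions; this is an illustration, not part of the patch itself.

    from pyspark.sql import SparkSession
    from pyspark_core_utils import delta_utils

    # SparkSession with the Delta extension and catalog configured (as in apps.py).
    spark = SparkSession.builder.getOrCreate()

    # Placeholder dataframe partitioned by a date column.
    df = spark.createDataFrame([("2022-12-01", 1)], ["date_partition", "value"])

    # Overwrite only the listed partition dates at a placeholder S3 location.
    delta_utils.write_partitioned_data_delta(
        None, df, "date_partition", ["2022-12-01"], "overwrite",
        "s3://example-bucket/example-table")

    # Read the same partitions back from their partition paths.
    paths = ["s3://example-bucket/example-table/date_partition=2022-12-01"]
    result = delta_utils.read_delta_from_s3(None, spark, paths)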