Datenaufbereitung mit PySpark¶

PySpark ist die Python-Schnittstelle für Apache Spark, ein leistungsfähiges Open-Source-System zur verteilten Verarbeitung großer Datenmengen. Mit PySpark können Sie Daten aufbereiten und analysieren, indem Sie die Funktionen und APIs von Spark in Python-Code nutzen.

PySpark ermöglicht es, Daten aus verschiedenen Quellen wie CSV-Dateien, Datenbanken oder verteilten Dateisystemen einzulesen und in Spark DataFrames zu transformieren. Mit den DataFrame-APIs oder per Spark SQL können Sie dann Datenbereinigung, Filterung, Aggregation und andere Transformationen durchführen, um die Daten für die Analyse vorzubereiten. Die Namen der DataFrame-API-Funktionen sind der gängigen SQL-Terminologie entnommen.

Typische Aufgaben bei der Datenaufbereitung

Arbeitsschritt SQL DataFrame API
Zusammenfügen von mehreren Tabellen join join()
Auswahl von Spalten select select()
Filtern von Zeilen where filter() oder where()
Gruppieren group by groupBy()
Summieren sum() sum()

Im Folgenden werden diese Funktionen anhand eines einfachen Beispiels vorgestellt. In dem Beispiel werden drei Tabellen eines Stern-Schemas (Faktentabelle, Dimensionstabelle Kunde und Dimensionstabelle Zeit) zusammengeführt und dann ermittelt, welche Kundengruppe in einem bestimmten Quartal den höchsten Umsatz hatte.

Die DataFrame-API ist modularer und kann einfacher gekapselt und getestet werden. SQL führt zu einem kompakteren und leichter lesbaren Code. Außerdem ist SQL-Code portabler, da SQL in verschiedensten Umgebungen ausgeführt werden kann.

Weitere Informationen:

DataFrames vs. Spark SQL

Tutorial: Spark by Examples - DataFrame API

Tutorial: Spark by Examples - SQL

Zunächst importieren wir die notwendigen Bibliotheken und starten wir eine SparkSession.

In [3]:
from pyspark.sql import SparkSession
In [4]:
# Starte eine Spark Session
spark = SparkSession.builder.appName("Datenaufbereitung_PySpark").getOrCreate()

Mit der Funktion createDataFrame() können wir einen DataFrame beispielsweise aus einer Liste von Listen (=Zeilen) erstellen. Dies ist nützlich für kleine Test- oder Demodatensätze, normalerweise würden wir die Daten aus einer Datenquelle auslesen.

In [5]:
# Erstelle die Daten für die Faktentabelle
faktentabelle_data = [
    ["A001", "20240115", "K4716", 120],
    ["A002", "20240116", "K4712", 150],
    ["A003", "20240127", "K4713", 80],
    ["A004", "20240205", "K4714", 65],
    ["A005", "20240212", "K4711", 180],
    ["A006", "20240221", "K4714", 55],
    ["A007", "20240221", "K4715", 75],
    ["A008", "20240317", "K4711", 150],
    ["A009", "20240401", "K4711", 120]
]
faktentabelle_columns = ['auftrag_id', 'zeit_id', 'kunden_id', 'umsatz']
faktentabelle = spark.createDataFrame(faktentabelle_data, faktentabelle_columns)
In [6]:
# Erstelle die Daten für die Kundendimension
kunden_dim_data = [
    ["K4711", "Loyale"],
    ["K4712", "Loyale"],
    ["K4713", "Preisbewusste"],
    ["K4714", "Preisbewusste"],
    ["K4715", "Neukunde"],
    ["K4716", 'Neukunde']

]
kunden_dim_columns = ['kunden_id', 'kundengruppe']
kunden_dim = spark.createDataFrame(kunden_dim_data, kunden_dim_columns)
In [7]:
# Erstelle die Daten für die Zeitdimension
zeit_dim_data = [
    ["20240115", 1],
    ["20240116", 1],
    ["20240127", 1],
    ["20240205", 1],
    ["20240212", 1],
    ["20240221", 1],
    ["20240317", 1],
    ["20240401", 1]

]
zeit_dim_columns = ['zeit_id', 'quartal']
zeit_dim = spark.createDataFrame(zeit_dim_data, zeit_dim_columns)

Als erster Schritt der Datenaufbereitung fügen wir die Tabellen zusammen.

In [8]:
# Führe die notwendigen Joins durch, um die Informationen zu vereinen
denormalisierte_tabelle = faktentabelle.join(kunden_dim, on="kunden_id").join(zeit_dim, on="zeit_id")
denormalisierte_tabelle.sort("auftrag_id").show()
+--------+---------+----------+------+-------------+-------+
| zeit_id|kunden_id|auftrag_id|umsatz| kundengruppe|quartal|
+--------+---------+----------+------+-------------+-------+
|20240115|    K4716|      A001|   120|     Neukunde|      1|
|20240116|    K4712|      A002|   150|       Loyale|      1|
|20240127|    K4713|      A003|    80|Preisbewusste|      1|
|20240205|    K4714|      A004|    65|Preisbewusste|      1|
|20240212|    K4711|      A005|   180|       Loyale|      1|
|20240221|    K4714|      A006|    55|Preisbewusste|      1|
|20240221|    K4715|      A007|    75|     Neukunde|      1|
|20240317|    K4711|      A008|   150|       Loyale|      1|
|20240401|    K4711|      A009|   120|       Loyale|      1|
+--------+---------+----------+------+-------------+-------+

Im zweiten Schritt selektieren wir die notwendigen Spalten und Filtern die Zeilen, um nur die Einträge aus dem ersten Quartal zu analysieren.

In [9]:
# Wähle nur die benötigten Spalten aus und filtere nach dem gewünschten Quartal
usecase_tabelle = denormalisierte_tabelle.select("kundengruppe", "umsatz", "quartal").filter("quartal = 1")
usecase_tabelle.sort("kundengruppe").show()
+-------------+------+-------+
| kundengruppe|umsatz|quartal|
+-------------+------+-------+
|       Loyale|   180|      1|
|       Loyale|   120|      1|
|       Loyale|   150|      1|
|       Loyale|   150|      1|
|     Neukunde|   120|      1|
|     Neukunde|    75|      1|
|Preisbewusste|    65|      1|
|Preisbewusste|    80|      1|
|Preisbewusste|    55|      1|
+-------------+------+-------+

Mit Hilfe von groupBy() und sum() summeieren wir die Umsätze nach Kundengruppe. Das Ergebnis liefert eine Spalte mit dem Titel "sum(umsatz)". Wir nutzen withColumnRenamed um diese aggregierte Spalte in "gesamtumsatz" umzubenennen.

In [10]:
# Aggregierung: Berechnung Umsatz pro Kundengruppe
aggregierte_tabelle = usecase_tabelle.groupBy("kundengruppe").sum("umsatz")
aggregierte_tabelle = aggregierte_tabelle.withColumnRenamed('sum(umsatz)', 'gesamtumsatz')
aggregierte_tabelle.show()
+-------------+------------+
| kundengruppe|gesamtumsatz|
+-------------+------------+
|       Loyale|         600|
|Preisbewusste|         200|
|     Neukunde|         195|
+-------------+------------+

Um die Gruppe mit dem höchsten Umsatz zu ermitteln, sortieren wir die Tabelle mit dem Gesamtumsatz pro Kundengruppe absteigend nach dem Gesamtumsatz. In diesem Beispel wäre dies nicht unbedingt notwendig, aber wenn es mehr Kundengruppen gibt oder wir das Ergebnis extrahieren wollen, dann ist es nützlich das das Ergebnis jetzt in der ersten Zeile steht.

In [11]:
# Ergebnis: absteigend sortiert 
# in diesem Beispiel eigentlich nicht mehr nötig
sortierte_tabelle = aggregierte_tabelle.orderBy(aggregierte_tabelle["gesamtumsatz"].desc())
sortierte_tabelle.show()
+-------------+------------+
| kundengruppe|gesamtumsatz|
+-------------+------------+
|       Loyale|         600|
|Preisbewusste|         200|
|     Neukunde|         195|
+-------------+------------+

Wenn wir mehrere Aggregierungen gleichzeitig vornehmen wollen, dann können wir mit der agg()-Funktion arbeiten. Um Namenskonflikte zu vermeiden, ist es best practice die Aggregierungsfunktionen, die wir in agg() verwenden, mit dem Verweis auf das Modul functions mit dem Alias "F" aufzurufen. Mit Hilfe von alias() können wir das Ergebnis gleich umbenennen.

In [12]:
# Alternative für mehrere Aggregierungen
from pyspark.sql import functions as F
aggregierte_tabelle_2 = usecase_tabelle.groupBy('kundengruppe').agg(F.sum('umsatz').alias('gesamtumsatz'), F.avg('umsatz').alias('durchschnittsumsatz'))
aggregierte_tabelle_2.show()
+-------------+------------+-------------------+
| kundengruppe|gesamtumsatz|durchschnittsumsatz|
+-------------+------------+-------------------+
|       Loyale|         600|              150.0|
|Preisbewusste|         200|  66.66666666666667|
|     Neukunde|         195|               97.5|
+-------------+------------+-------------------+

Alternativ können wir die Abfrage auch als eine einzige SQL-Abfrage schreiben. Dazu müssen wir die DataFrames erst mit Hilfe von createOrRplaceTempView() der sql()-Funktion als Tabellen zugänglich machen.

In [13]:
# Mit Spark SQL
# Erstelle oder registriere die DataFrames als temporäre Tabellen
faktentabelle.createOrReplaceTempView("faktentabelle")
kunden_dim.createOrReplaceTempView("kundengruppen_dim")
zeit_dim.createOrReplaceTempView("zeit_dim")

# Führe die gewünschte Analyse mit Spark SQL durch
ergebnis = spark.sql(
    '''SELECT k.kundengruppe, SUM(f.umsatz) AS gesamtumsatz
       FROM faktentabelle f
       JOIN kundengruppen_dim k ON f.kunden_id = k.kunden_id
       JOIN zeit_dim z ON f.zeit_id = z.zeit_id
       WHERE z.quartal = 1
       GROUP BY k.kundengruppe
       ORDER BY gesamtumsatz DESC'''
)

# Zeige das Ergebnis
ergebnis.show()
+-------------+------------+
| kundengruppe|gesamtumsatz|
+-------------+------------+
|       Loyale|         600|
|Preisbewusste|         200|
|     Neukunde|         195|
+-------------+------------+