Datenaufbereitung mit PySpark¶
PySpark ist die Python-Schnittstelle für Apache Spark, ein leistungsfähiges Open-Source-System zur verteilten Verarbeitung großer Datenmengen. Mit PySpark können Sie Daten aufbereiten und analysieren, indem Sie die Funktionen und APIs von Spark in Python-Code nutzen.
PySpark ermöglicht es, Daten aus verschiedenen Quellen wie CSV-Dateien, Datenbanken oder verteilten Dateisystemen einzulesen und in Spark DataFrames zu transformieren. Mit den DataFrame-APIs oder per Spark SQL können Sie dann Datenbereinigung, Filterung, Aggregation und andere Transformationen durchführen, um die Daten für die Analyse vorzubereiten. Die Namen der DataFrame-API-Funktionen sind der gängigen SQL-Terminologie entnommen.
Typische Aufgaben bei der Datenaufbereitung
Arbeitsschritt | SQL | DataFrame API |
---|---|---|
Zusammenfügen von mehreren Tabellen | join | join() |
Auswahl von Spalten | select | select() |
Filtern von Zeilen | where | filter() oder where() |
Gruppieren | group by | groupBy() |
Summieren | sum() | sum() |
Im Folgenden werden diese Funktionen anhand eines einfachen Beispiels vorgestellt. In dem Beispiel werden drei Tabellen eines Stern-Schemas (Faktentabelle, Dimensionstabelle Kunde und Dimensionstabelle Zeit) zusammengeführt und dann ermittelt, welche Kundengruppe in einem bestimmten Quartal den höchsten Umsatz hatte.
Die DataFrame-API ist modularer und kann einfacher gekapselt und getestet werden. SQL führt zu einem kompakteren und leichter lesbaren Code. Außerdem ist SQL-Code portabler, da SQL in verschiedensten Umgebungen ausgeführt werden kann.
Weitere Informationen:
Zunächst importieren wir die notwendigen Bibliotheken und starten wir eine SparkSession.
from pyspark.sql import SparkSession
# Starte eine Spark Session
spark = SparkSession.builder.appName("Datenaufbereitung_PySpark").getOrCreate()
Mit der Funktion createDataFrame() können wir einen DataFrame beispielsweise aus einer Liste von Listen (=Zeilen) erstellen. Dies ist nützlich für kleine Test- oder Demodatensätze, normalerweise würden wir die Daten aus einer Datenquelle auslesen.
# Erstelle die Daten für die Faktentabelle
faktentabelle_data = [
["A001", "20240115", "K4716", 120],
["A002", "20240116", "K4712", 150],
["A003", "20240127", "K4713", 80],
["A004", "20240205", "K4714", 65],
["A005", "20240212", "K4711", 180],
["A006", "20240221", "K4714", 55],
["A007", "20240221", "K4715", 75],
["A008", "20240317", "K4711", 150],
["A009", "20240401", "K4711", 120]
]
faktentabelle_columns = ['auftrag_id', 'zeit_id', 'kunden_id', 'umsatz']
faktentabelle = spark.createDataFrame(faktentabelle_data, faktentabelle_columns)
# Erstelle die Daten für die Kundendimension
kunden_dim_data = [
["K4711", "Loyale"],
["K4712", "Loyale"],
["K4713", "Preisbewusste"],
["K4714", "Preisbewusste"],
["K4715", "Neukunde"],
["K4716", 'Neukunde']
]
kunden_dim_columns = ['kunden_id', 'kundengruppe']
kunden_dim = spark.createDataFrame(kunden_dim_data, kunden_dim_columns)
# Erstelle die Daten für die Zeitdimension
zeit_dim_data = [
["20240115", 1],
["20240116", 1],
["20240127", 1],
["20240205", 1],
["20240212", 1],
["20240221", 1],
["20240317", 1],
["20240401", 1]
]
zeit_dim_columns = ['zeit_id', 'quartal']
zeit_dim = spark.createDataFrame(zeit_dim_data, zeit_dim_columns)
Als erster Schritt der Datenaufbereitung fügen wir die Tabellen zusammen.
# Führe die notwendigen Joins durch, um die Informationen zu vereinen
denormalisierte_tabelle = faktentabelle.join(kunden_dim, on="kunden_id").join(zeit_dim, on="zeit_id")
denormalisierte_tabelle.sort("auftrag_id").show()
+--------+---------+----------+------+-------------+-------+ | zeit_id|kunden_id|auftrag_id|umsatz| kundengruppe|quartal| +--------+---------+----------+------+-------------+-------+ |20240115| K4716| A001| 120| Neukunde| 1| |20240116| K4712| A002| 150| Loyale| 1| |20240127| K4713| A003| 80|Preisbewusste| 1| |20240205| K4714| A004| 65|Preisbewusste| 1| |20240212| K4711| A005| 180| Loyale| 1| |20240221| K4714| A006| 55|Preisbewusste| 1| |20240221| K4715| A007| 75| Neukunde| 1| |20240317| K4711| A008| 150| Loyale| 1| |20240401| K4711| A009| 120| Loyale| 1| +--------+---------+----------+------+-------------+-------+
Im zweiten Schritt selektieren wir die notwendigen Spalten und Filtern die Zeilen, um nur die Einträge aus dem ersten Quartal zu analysieren.
# Wähle nur die benötigten Spalten aus und filtere nach dem gewünschten Quartal
usecase_tabelle = denormalisierte_tabelle.select("kundengruppe", "umsatz", "quartal").filter("quartal = 1")
usecase_tabelle.sort("kundengruppe").show()
+-------------+------+-------+ | kundengruppe|umsatz|quartal| +-------------+------+-------+ | Loyale| 180| 1| | Loyale| 120| 1| | Loyale| 150| 1| | Loyale| 150| 1| | Neukunde| 120| 1| | Neukunde| 75| 1| |Preisbewusste| 65| 1| |Preisbewusste| 80| 1| |Preisbewusste| 55| 1| +-------------+------+-------+
Mit Hilfe von groupBy() und sum() summeieren wir die Umsätze nach Kundengruppe. Das Ergebnis liefert eine Spalte mit dem Titel "sum(umsatz)". Wir nutzen withColumnRenamed um diese aggregierte Spalte in "gesamtumsatz" umzubenennen.
# Aggregierung: Berechnung Umsatz pro Kundengruppe
aggregierte_tabelle = usecase_tabelle.groupBy("kundengruppe").sum("umsatz")
aggregierte_tabelle = aggregierte_tabelle.withColumnRenamed('sum(umsatz)', 'gesamtumsatz')
aggregierte_tabelle.show()
+-------------+------------+ | kundengruppe|gesamtumsatz| +-------------+------------+ | Loyale| 600| |Preisbewusste| 200| | Neukunde| 195| +-------------+------------+
Um die Gruppe mit dem höchsten Umsatz zu ermitteln, sortieren wir die Tabelle mit dem Gesamtumsatz pro Kundengruppe absteigend nach dem Gesamtumsatz. In diesem Beispel wäre dies nicht unbedingt notwendig, aber wenn es mehr Kundengruppen gibt oder wir das Ergebnis extrahieren wollen, dann ist es nützlich das das Ergebnis jetzt in der ersten Zeile steht.
# Ergebnis: absteigend sortiert
# in diesem Beispiel eigentlich nicht mehr nötig
sortierte_tabelle = aggregierte_tabelle.orderBy(aggregierte_tabelle["gesamtumsatz"].desc())
sortierte_tabelle.show()
+-------------+------------+ | kundengruppe|gesamtumsatz| +-------------+------------+ | Loyale| 600| |Preisbewusste| 200| | Neukunde| 195| +-------------+------------+
Wenn wir mehrere Aggregierungen gleichzeitig vornehmen wollen, dann können wir mit der agg()-Funktion arbeiten. Um Namenskonflikte zu vermeiden, ist es best practice die Aggregierungsfunktionen, die wir in agg() verwenden, mit dem Verweis auf das Modul functions mit dem Alias "F" aufzurufen. Mit Hilfe von alias() können wir das Ergebnis gleich umbenennen.
# Alternative für mehrere Aggregierungen
from pyspark.sql import functions as F
aggregierte_tabelle_2 = usecase_tabelle.groupBy('kundengruppe').agg(F.sum('umsatz').alias('gesamtumsatz'), F.avg('umsatz').alias('durchschnittsumsatz'))
aggregierte_tabelle_2.show()
+-------------+------------+-------------------+ | kundengruppe|gesamtumsatz|durchschnittsumsatz| +-------------+------------+-------------------+ | Loyale| 600| 150.0| |Preisbewusste| 200| 66.66666666666667| | Neukunde| 195| 97.5| +-------------+------------+-------------------+
Alternativ können wir die Abfrage auch als eine einzige SQL-Abfrage schreiben. Dazu müssen wir die DataFrames erst mit Hilfe von createOrRplaceTempView() der sql()-Funktion als Tabellen zugänglich machen.
# Mit Spark SQL
# Erstelle oder registriere die DataFrames als temporäre Tabellen
faktentabelle.createOrReplaceTempView("faktentabelle")
kunden_dim.createOrReplaceTempView("kundengruppen_dim")
zeit_dim.createOrReplaceTempView("zeit_dim")
# Führe die gewünschte Analyse mit Spark SQL durch
ergebnis = spark.sql(
'''SELECT k.kundengruppe, SUM(f.umsatz) AS gesamtumsatz
FROM faktentabelle f
JOIN kundengruppen_dim k ON f.kunden_id = k.kunden_id
JOIN zeit_dim z ON f.zeit_id = z.zeit_id
WHERE z.quartal = 1
GROUP BY k.kundengruppe
ORDER BY gesamtumsatz DESC'''
)
# Zeige das Ergebnis
ergebnis.show()
+-------------+------------+ | kundengruppe|gesamtumsatz| +-------------+------------+ | Loyale| 600| |Preisbewusste| 200| | Neukunde| 195| +-------------+------------+