Spark: Finding Cumulative Sum
April 10, 2023
Tags: Spark, Scala, PySpark
Let's see how to find the cumulative sum of a column in Spark, using both Scala and PySpark, with an example.
Dummy data
product | year | sale |
---|---|---|
A | 2018 | 10 |
B | 2018 | 8 |
A | 2019 | 13 |
B | 2019 | 14 |
A | 2020 | 12 |
B | 2020 | 11 |
A | 2021 | 15 |
B | 2021 | 13 |
A | 2022 | 17 |
B | 2022 | 18 |
A | 2023 | 15 |
B | 2023 | 16 |
Scala
Import dependencies
import org.apache.spark.sql.expressions.Window
Create a Scala class
/** One sales record: a product identifier, the year, and the sale amount.
  * Marked `final` — case classes should not be extended.
  */
final case class Sales(product: String, year: Int, sale: Int)
Dummy data rows
// Build the dummy sales DataFrame. Case classes have a companion `apply`,
// so `new` is unnecessary. NOTE(review): `.toDF()` requires
// `spark.implicits._` in scope — provided automatically in spark-shell;
// import it explicitly in a compiled application.
val df = Seq(
  Sales("A", 2018, 10), Sales("B", 2018, 8),
  Sales("A", 2019, 13), Sales("B", 2019, 14),
  Sales("A", 2020, 12), Sales("B", 2020, 11),
  Sales("A", 2021, 15), Sales("B", 2021, 13),
  Sales("A", 2022, 17), Sales("B", 2022, 18),
  Sales("A", 2023, 15), Sales("B", 2023, 16)
).toDF()
// `sum` lives in org.apache.spark.sql.functions and was never imported
// above (only Window was) — without it this snippet does not compile.
import org.apache.spark.sql.functions.sum

// Step 1: running total of `sale` per product, ordered by year
//         (the default window frame with orderBy is unbounded-preceding
//         to current row, i.e. a cumulative sum).
// Step 2: collapse the per-product running totals into one figure per year.
df
  .withColumn(
    "cumulativeSales",
    sum("sale").over(Window.partitionBy("product").orderBy("year"))
  )
  .select("year", "cumulativeSales")
  .groupBy("year")
  .agg(sum("cumulativeSales").alias("Sales"))
  .orderBy("year")
PySpark
Import dependencies
from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
Create a schema
# Explicit schema for the dummy data: (product: string, year: int, sale: int),
# all nullable. Backslash continuations are unnecessary (and invalid where
# they appear mid-line) inside a bracketed expression.
schema = StructType([
    StructField("product", StringType(), True),
    StructField("year", IntegerType(), True),
    StructField("sale", IntegerType(), True),
])
Dummy data rows
# Dummy rows as (product, year, sale) tuples: products A and B, 2018-2023.
data = [
    ("A", 2018, 10), ("B", 2018, 8),
    ("A", 2019, 13), ("B", 2019, 14),
    ("A", 2020, 12), ("B", 2020, 11),
    ("A", 2021, 15), ("B", 2021, 13),
    ("A", 2022, 17), ("B", 2022, 18),
    ("A", 2023, 15), ("B", 2023, 16),
]
Create DataFrame
# Build the DataFrame from the rows and the explicit schema.
# NOTE(review): `spark` is assumed to be an existing SparkSession provided by
# the notebook/shell environment — it is not defined in this snippet.
df = spark.createDataFrame(data,schema)
# Step 1: running total of `sale` per product, ordered by year (cumulative sum).
# Step 2: sum the per-product running totals per year, sorted chronologically.
# The original one-liner fused `.sort('year')` and `df2.display()` into a
# single statement, which is a syntax error — they are separated here.
df2 = df \
    .withColumn('cumulative_sales',
                F.sum('sale').over(Window.partitionBy('product').orderBy('year'))) \
    .select('year', 'cumulative_sales') \
    .groupby('year') \
    .agg(F.sum('cumulative_sales').alias("sales")) \
    .sort('year')

# display() is a Databricks notebook helper; use df2.show() elsewhere.
df2.display()