Have you ever wondered how to compute rolling averages in PySpark? This snippet walks you through the process.

from typing import Callable, List

from pyspark.sql import Column
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import functions as F
from pyspark.sql.window import Window

def create_rolling_feature(
  df: SparkDataFrame,
  id_cols: List[str] = ["GROUP1", "GROUP2"],
  value_col: str = "VALUE",
  time_col: str = "DATE",
  window_size: int = 3,
  agg_func: Callable[[Column], Column] = F.mean,
) -> SparkDataFrame:
  """
  Creates a moving window average
  df: SparkDataFrame
    Spark dataframe to work on
  id_cols: List[str]
    The list of partitionBy columns over which to group the rolling function
  value_col: List[str]
    The name of the columns we want to compute rolling operations over
  time_col: str
    Name of the column representing time. Here we assume we have datetime columns with the possibility of casting to long.
  window_size: int
    The number of rows to consider in the rolling aggregation, by default 3 means that the moving operations is done on the aggregation function over the [current-3, current-2, current-1, current] rows.
  agg_func:
    A PySpark aggregation function. Can be any function that takes a column and returns a scalar, for example `F.mean`, `F.min`, `F.max`
  """
  rolling_col = f"ROLLING_{agg_func.__name__.upper()}_{value_col}_W{window_size}"
  # Order rows within each group by the time column (cast to long via timestamp so the
  # ordering key is numeric) and keep the current row plus the window_size preceding rows.
  window = (
    Window
    .partitionBy(*id_cols)
    .orderBy(F.col(time_col).cast("timestamp").cast("long").asc())
    .rowsBetween(-window_size, 0)
  )
  return df.withColumn(rolling_col, agg_func(F.col(value_col)).over(window))

Example usage (assuming `X` is a Spark dataframe with the default column names; `display()` is available in Databricks notebooks, use `.show()` elsewhere):

create_rolling_feature(X, agg_func=F.max).display()
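
If you want to try this without an existing dataframe, here is a minimal, self-contained sketch. It assumes a running SparkSession named `spark` and builds a toy dataframe whose column names match the defaults above (`GROUP1`, `GROUP2`, `VALUE`, `DATE`); the values are made up purely for illustration.

from datetime import date

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

# Hypothetical toy data matching the default column names of create_rolling_feature.
sample = spark.createDataFrame(
  [
    ("a", "x", 1.0, date(2024, 1, 1)),
    ("a", "x", 2.0, date(2024, 1, 2)),
    ("a", "x", 4.0, date(2024, 1, 3)),
    ("b", "y", 10.0, date(2024, 1, 1)),
    ("b", "y", 20.0, date(2024, 1, 2)),
  ],
  ["GROUP1", "GROUP2", "VALUE", "DATE"],
)

# Rolling mean (the default) and rolling max over the current row and the 3 preceding rows.
with_mean = create_rolling_feature(sample)
with_max = create_rolling_feature(sample, agg_func=F.max)

with_mean.show()
with_max.show()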