Changepoint detection on huge grouped dataframes with ruptures and PySpark
Here we describe a way to perform scalable changepoint detection on grouped time series data using PySpark and the ruptures library.
from typing import List
import numpy as np
import pandas as pd
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.types import *
import pyspark.sql.functions as F
import ruptures as rpt
def changepoint_detection(
  df: SparkDataFrame,
  time_col: str,
  value_col: str,
  group_cols: List[str],
  breakpoint_col:str="BREAKPOINTS",
  breakpoint_default:int=0,
  breakpoint_active:int=1,
  kernel_model:str="rbf",
  penalty:float=0.01,
) -> SparkDataFrame:
  """
  Performs a grouped changepoint detection on individual time series, denoted by the pair (time_col, value_col)

  The dataframe is split by group_cols; within each group the rows are ordered by time_col and
  the value_col series is segmented with the ruptures PELT (offline) changepoint detection algorithm.
  The last row of every detected segment, except the final segment, is flagged with breakpoint_active;
  every other row gets breakpoint_default. The end of the series itself is never reported as a changepoint.

  df: SparkDataFrame
    The dataframe to work on
  time_col: str
    The column representing the datetime, used to order the rows of each group
  value_col: str
    The name of the series to segment (cast to float before segmentation)
  group_cols: List[str]
    The grouping variables; segmentation is performed separately for each group
  breakpoint_col: str (default "BREAKPOINTS")
    Name of the integer column appended to the output
  breakpoint_default: int (default 0)
    Value when there is no breakpoint
  breakpoint_active: int (default 1)
    Value when there is a breakpoint
  kernel_model: str
    Cost model passed to rpt.Pelt(model=...) for the internal segmentation
  penalty: float
    Regularization penalty; the larger the penalty, the fewer breakpoints are detected

  Returns the input rows with the same columns plus the non-nullable integer column breakpoint_col.
  """
  # Output schema: every input field followed by the breakpoint flag column.
  new_schema = StructType(
    list(df.schema) + [StructField(name=breakpoint_col, dataType=IntegerType(), nullable=False)]
  )

  # NOTE: the GROUPED_MAP pandas_udf + groupBy().apply() form is deprecated in Spark 3.x in
  # favour of groupBy().applyInPandas(); it is kept here so older clusters keep working.
  @F.pandas_udf(returnType=new_schema, functionType=F.PandasUDFType.GROUPED_MAP)
  def changepoint_algorithm(pandas_dataframe: pd.DataFrame):
    # Only the two relevant columns are sorted; the sorted index is what maps the
    # flags back onto the (unsorted) group rows in the final assign().
    ordered = pandas_dataframe[[time_col, value_col]].sort_values(by=time_col)
    n_samples = len(ordered)
    flags = pd.Series(data=breakpoint_default, index=ordered.index)

    # A single observation cannot contain a changepoint, and ruptures may reject
    # signals shorter than its minimum segment size, so such groups are skipped.
    if n_samples >= 2:
      signal = ordered[value_col].astype(float).values.reshape(-1, 1)
      segment_ends = rpt.Pelt(model=kernel_model).fit(signal).predict(pen=penalty)
      # ruptures returns the (exclusive) end index of every segment, and the last
      # entry is always n_samples: dropping it avoids flagging the final row of
      # every group. The remaining ends are shifted by one to point at the last
      # row of each segment.
      positions = np.array(
        [end - 1 for end in segment_ends if end < n_samples],
        dtype=int,  # keeps an empty selection usable as a positional indexer
      )
      flags.iloc[positions] = breakpoint_active

    return pandas_dataframe.assign(**{breakpoint_col: flags})

  return (
    df
    .groupBy(*group_cols)
    .apply(changepoint_algorithm)
  )
# Example usage of the function: the result is ordered by group and date so the
# flagged rows can be checked in time order.
changepoint_detection(
  df=df_bonifici_mensile,
  group_cols=["GROUP"],
  time_col="DATE",
  value_col="MEAN",
  kernel_model="l2",
  penalty=5E5,
  breakpoint_active=500,
).orderBy("GROUP", "DATE").display()