Code
import os
import sys

os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
Code
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

conf = (
    SparkConf()
        .setAppName("Spark SQL Course")
)

sc = SparkContext(conf=conf)  # no need for Spark 3...

spark = (
    SparkSession
        .builder
        .appName("Spark SQL Course")
        .getOrCreate()
)
WARNING: Using incubator modules: jdk.incubator.vector
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/02/08 11:21:30 WARN Utils: Your hostname, boucheron-Precision-5480, resolves to a loopback address: 127.0.1.1; using 192.168.0.36 instead (on interface wlp0s20f3)
26/02/08 11:21:30 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/02/08 11:21:31 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Code
from pyspark.sql import Row

row1 = Row(name="John", age=21)
row2 = Row(name="James", age=32)
row3 = Row(name="Jane", age=18)
row1['name']
'John'
Code
df = spark.createDataFrame([row1, row2, row3])
Code
df.printSchema()
root
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)

How does printSchema compare with Pandas info()?
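
As a hint, a minimal sketch for comparison (it assumes pandas is installed and the data is small enough to collect on the driver): Pandas info() reports dtypes, non-null counts and memory usage, while printSchema() only describes the schema of the distributed DataFrame.

Code
# collect to the driver as a Pandas DataFrame, then inspect it
pdf = df.toPandas()
pdf.info()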

Code
df.show()
+-----+---+
| name|age|
+-----+---+
| John| 21|
|James| 32|
| Jane| 18|
+-----+---+

How does show compare with Pandas head()?

Code
print(df.rdd.toDebugString().decode("utf-8"))
(20) MapPartitionsRDD[10] at javaToPython at NativeMethodAccessorImpl.java:0 []
 |   MapPartitionsRDD[9] at javaToPython at NativeMethodAccessorImpl.java:0 []
 |   SQLExecutionRDD[8] at javaToPython at NativeMethodAccessorImpl.java:0 []
 |   MapPartitionsRDD[7] at javaToPython at NativeMethodAccessorImpl.java:0 []
 |   MapPartitionsRDD[4] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0 []
 |   MapPartitionsRDD[3] at map at SerDeUtil.scala:71 []
 |   MapPartitionsRDD[2] at mapPartitions at SerDeUtil.scala:119 []
 |   PythonRDD[1] at RDD at PythonRDD.scala:58 []
 |   ParallelCollectionRDD[0] at readRDDFromFile at PythonRDD.scala:299 []
Code
df.rdd.getNumPartitions()
20

Creating dataframes

Code
rows = [
    Row(name="John", age=21, gender="male"),
    Row(name="James", age=25, gender="female"),
    Row(name="Albert", age=46, gender="male")
]

df = spark.createDataFrame(rows)
Code
df.show()
+------+---+------+
|  name|age|gender|
+------+---+------+
|  John| 21|  male|
| James| 25|female|
|Albert| 46|  male|
+------+---+------+
Code
help(Row)
Help on class Row in module pyspark.sql.types:

class Row(builtins.tuple)
 |  Row(*args: Optional[str], **kwargs: Optional[Any]) -> 'Row'
 |
 |  A row in :class:`DataFrame`.
 |  The fields in it can be accessed:
 |
 |  * like attributes (``row.key``)
 |  * like dictionary values (``row[key]``)
 |
 |  ``key in row`` will search through row keys.
 |
 |  Row can be used to create a row object by using named arguments.
 |  It is not allowed to omit a named argument to represent that the value is
 |  None or missing. This should be explicitly set to None in this case.
 |
 |  .. versionchanged:: 3.0.0
 |      Rows created from named arguments no longer have
 |      field names sorted alphabetically and will be ordered in the position as
 |      entered.
 |
 |  Examples
 |  --------
 |  >>> from pyspark.sql import Row
 |  >>> row = Row(name="Alice", age=11)
 |  >>> row
 |  Row(name='Alice', age=11)
 |  >>> row['name'], row['age']
 |  ('Alice', 11)
 |  >>> row.name, row.age
 |  ('Alice', 11)
 |  >>> 'name' in row
 |  True
 |  >>> 'wrong_key' in row
 |  False
 |
 |  Row also can be used to create another Row like class, then it
 |  could be used to create Row objects, such as
 |
 |  >>> Person = Row("name", "age")
 |  >>> Person
 |  <Row('name', 'age')>
 |  >>> 'name' in Person
 |  True
 |  >>> 'wrong_key' in Person
 |  False
 |  >>> Person("Alice", 11)
 |  Row(name='Alice', age=11)
 |
 |  This form can also be used to create rows as tuple values, i.e. with unnamed
 |  fields.
 |
 |  >>> row1 = Row("Alice", 11)
 |  >>> row2 = Row(name="Alice", age=11)
 |  >>> row1 == row2
 |  True
 |
 |  Method resolution order:
 |      Row
 |      builtins.tuple
 |      builtins.object
 |
 |  Methods defined here:
 |
 |  __call__(self, *args: Any) -> 'Row'
 |      create new Row object
 |
 |  __contains__(self, item: Any) -> bool
 |      Return bool(key in self).
 |
 |  __getattr__(self, item: str) -> Any
 |
 |  __getitem__(self, item: Any) -> Any
 |      Return self[key].
 |
 |  __reduce__(self) -> Union[str, Tuple[Any, ...]]
 |      Returns a tuple so Python knows how to pickle Row.
 |
 |  __repr__(self) -> str
 |      Printable representation of Row used in Python REPL.
 |
 |  __setattr__(self, key: Any, value: Any) -> None
 |      Implement setattr(self, name, value).
 |
 |  asDict(self, recursive: bool = False) -> Dict[str, Any]
 |      Return as a dict
 |
 |      Parameters
 |      ----------
 |      recursive : bool, optional
 |          turns the nested Rows to dict (default: False).
 |
 |      Notes
 |      -----
 |      If a row contains duplicate field names, e.g., the rows of a join
 |      between two :class:`DataFrame` that both have the fields of same names,
 |      one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
 |      will also return one of the duplicate fields, however returned value might
 |      be different to ``asDict``.
 |
 |      Examples
 |      --------
 |      >>> from pyspark.sql import Row
 |      >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
 |      True
 |      >>> row = Row(key=1, value=Row(name='a', age=2))
 |      >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}
 |      True
 |      >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
 |      True
 |
 |  ----------------------------------------------------------------------
 |  Static methods defined here:
 |
 |  __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> 'Row'
 |      Create and return a new object.  See help(type) for accurate signature.
 |
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |
 |  __dict__
 |      dictionary for instance variables
 |
 |  ----------------------------------------------------------------------
 |  Methods inherited from builtins.tuple:
 |
 |  __add__(self, value, /)
 |      Return self+value.
 |
 |  __eq__(self, value, /)
 |      Return self==value.
 |
 |  __ge__(self, value, /)
 |      Return self>=value.
 |
 |  __getattribute__(self, name, /)
 |      Return getattr(self, name).
 |
 |  __getnewargs__(self, /)
 |
 |  __gt__(self, value, /)
 |      Return self>value.
 |
 |  __hash__(self, /)
 |      Return hash(self).
 |
 |  __iter__(self, /)
 |      Implement iter(self).
 |
 |  __le__(self, value, /)
 |      Return self<=value.
 |
 |  __len__(self, /)
 |      Return len(self).
 |
 |  __lt__(self, value, /)
 |      Return self<value.
 |
 |  __mul__(self, value, /)
 |      Return self*value.
 |
 |  __ne__(self, value, /)
 |      Return self!=value.
 |
 |  __rmul__(self, value, /)
 |      Return value*self.
 |
 |  count(self, value, /)
 |      Return number of occurrences of value.
 |
 |  index(self, value, start=0, stop=9223372036854775807, /)
 |      Return first index of value.
 |
 |      Raises ValueError if the value is not present.
 |
 |  ----------------------------------------------------------------------
 |  Class methods inherited from builtins.tuple:
 |
 |  __class_getitem__(...)
 |      See PEP 585
Code
column_names = ["name", "age", "gender"]
rows = [
    ["John", 21, "male"],
    ["James", 25, "female"],
    ["Albert", 46, "male"]
]

df = spark.createDataFrame(
    rows, 
    column_names
)

df.show()
+------+---+------+
|  name|age|gender|
+------+---+------+
|  John| 21|  male|
| James| 25|female|
|Albert| 46|  male|
+------+---+------+
Code
df.printSchema()
root
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)
 |-- gender: string (nullable = true)
Code
# sc = SparkContext(conf=conf)  # no need for Spark 3...

column_names = ["name", "age", "gender"]
rdd = sc.parallelize([
    ("John", 21, "male"),
    ("James", 25, "female"),
    ("Albert", 46, "male")
])
df = spark.createDataFrame(rdd, column_names)
df.show()
+------+---+------+
|  name|age|gender|
+------+---+------+
|  John| 21|  male|
| James| 25|female|
|Albert| 46|  male|
+------+---+------+

Schema

There is a special type for schemas. An object of class StructType is made of a list of objects of type StructField.

Code
df.schema
StructType([StructField('name', StringType(), True), StructField('age', LongType(), True), StructField('gender', StringType(), True)])

How does df.schema relate to df.printSchema()? Where would you use the outputs of df.schema and df.printSchema()?
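
One possible answer, as a minimal sketch: df.schema returns a StructType object that can be reused programmatically, for instance to build another DataFrame with the same schema, whereas printSchema() only prints a human-readable tree (and returns None).

Code
# reuse the schema of df to create an (empty) DataFrame with identical columns
empty_df = spark.createDataFrame([], schema=df.schema)
empty_df.printSchema()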

Code
type(df.schema)
pyspark.sql.types.StructType

An object of type StructField has a name like gender, a PySpark type like StringType(), and a boolean parameter.

What does the boolean parameter stand for?

Code
from pyspark.sql.types import *

schema = StructType(
    [
        StructField("name", StringType(), False),
        StructField("age", IntegerType(), True),
        StructField("gender", StringType(), True)
    ]
)

rows = [("Jane", 21, "female")]
df = spark.createDataFrame(rows, schema)
df.printSchema()
df.show()
root
 |-- name: string (nullable = false)
 |-- age: integer (nullable = true)
 |-- gender: string (nullable = true)

+----+---+------+
|name|age|gender|
+----+---+------+
|Jane| 21|female|
+----+---+------+
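
Coming back to the question above: the boolean indicates whether the field accepts NULL values. As a minimal sketch (assuming the default schema verification), trying to put None in the non-nullable name field should be rejected:

Code
# violating nullable=False on "name" should raise an error
try:
    spark.createDataFrame([(None, 21, "female")], schema).collect()
except Exception as e:
    print(type(e).__name__)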

Queries (single table \(σ\), \(π\))

PySpark offers two ways to query a dataframe:

  • An ad hoc API with methods of the DataFrame class.
  • The possibility to run SQL queries (provided a temporary view has been created).
Code
column_names = ["name", "age", "gender"]
rows = [
    ["John", 21, "male"],
    ["Jane", 25, "female"]
]
# 
df = spark.createDataFrame(rows, column_names)

# Create a temporary view from the DataFrame
df.createOrReplaceTempView("new_view")

# Apply the query
query = """
    SELECT 
        name, age 
    FROM 
        new_view 
    WHERE 
        gender='male'
"""

men_df = spark.sql(query)
men_df.show()
+----+---+
|name|age|
+----+---+
|John| 21|
+----+---+
Important: New! (with Spark 4)

With Spark 4, we can now write SQL queries in a piped style, the way we have been writing tidyverse (R) pipelines, using the SQL pipe operator |>.

Code
new_age_query = """
    FROM new_view
    |> WHERE gender = 'male'
    |> SELECT name, age
"""

to be compared with

new_df |> 
    dplyr::filter(gender=='male') |>
    dplyr::select(name, age)
Code
(
    spark
        .sql(new_age_query)
        .show()
)
+----+---+
|name|age|
+----+---+
|John| 21|
+----+---+

SELECT (projection \(π\))

Code
df.createOrReplaceTempView("table")    

query = """
    SELECT 
        name, age 
    FROM 
        table
"""

spark.sql(query).show()
+----+---+
|name|age|
+----+---+
|John| 21|
|Jane| 25|
+----+---+

Using the API:

Code
(
    df
        .select("name", "age")
        .show()
)
+----+---+
|name|age|
+----+---+
|John| 21|
|Jane| 25|
+----+---+

π(df, "name", "age")

WHERE (filter, selection, \(σ\))

Code
df.createOrReplaceTempView("table")

query = """
    SELECT 
        * 
    FROM 
        table
    WHERE 
        age > 21
"""

query = """
    FROM table 
    |> WHERE age > 21 
    |> SELECT *  
"""

spark.sql(query).show()
+----+---+------+
|name|age|gender|
+----+---+------+
|Jane| 25|female|
+----+---+------+

Note that you can get rid of the trailing |> SELECT *.
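
For instance, this shorter piped query (a sketch reusing the same temporary view) returns all columns of the matching rows:

Code
query = """
    FROM table
    |> WHERE age > 21
"""

spark.sql(query).show()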

Using the API

Code
( 
    df
        .where("age > 21")
        .show()
)
+----+---+------+
|name|age|gender|
+----+---+------+
|Jane| 25|female|
+----+---+------+

This implements σ(df, "age > 21")

The where() method accepts different types of arguments: strings that can be interpreted as SQL conditions, but also boolean Column expressions (masks).

Code
# Alternatively:
( 
    df
      .where(df['age'] > 21)
      .show()
)
+----+---+------+
|name|age|gender|
+----+---+------+
|Jane| 25|female|
+----+---+------+

or

Code
( 
    df
      .where(df.age > 21)
      .show()
)
+----+---+------+
|name|age|gender|
+----+---+------+
|Jane| 25|female|
+----+---+------+

Where (and how) is the boolean mask built?
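
A hint, as a minimal sketch: the comparison df.age > 21 builds a Column expression (a symbolic, unevaluated expression), not a materialized boolean array; Spark evaluates it lazily when the query runs.

Code
mask = df.age > 21
print(type(mask))  # a Column expression
print(mask)        # its symbolic representation, e.g. Column<'(age > 21)'>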

Method chaining allows us to construct complex queries.

Code
( 
    df
      .where("age > 21")
      .select(["name", "age"])
      .show()
)
+----+---+
|name|age|
+----+---+
|Jane| 25|
+----+---+

This implements

    σ(df, "age > 21") |>
    π(["name", "age"])

LIMIT

Code
df.createOrReplaceTempView("table")

query = """
    SELECT 
        * 
    FROM 
        table 
    LIMIT 1
"""

query = """
    FROM table
    |> LIMIT 1
"""

spark.sql(query).show()
+----+---+------+
|name|age|gender|
+----+---+------+
|John| 21|  male|
+----+---+------+
Code
df.limit(1).show()
+----+---+------+
|name|age|gender|
+----+---+------+
|John| 21|  male|
+----+---+------+
Code
df.select("*").limit(1).show()
+----+---+------+
|name|age|gender|
+----+---+------+
|John| 21|  male|
+----+---+------+

ORDER BY

Code
df.createOrReplaceTempView("table")

query = """
    SELECT 
        * 
    FROM 
        table
    ORDER BY 
        name ASC
"""

query = """
    FROM table
    |> ORDER BY name ASC
"""

spark.sql(query).show()
+----+---+------+
|name|age|gender|
+----+---+------+
|Jane| 25|female|
|John| 21|  male|
+----+---+------+

With the API

Code
df.orderBy(df.name.asc()).show()
+----+---+------+
|name|age|gender|
+----+---+------+
|Jane| 25|female|
|John| 21|  male|
+----+---+------+

ALIAS (rename)

Code
df.createOrReplaceTempView("table")

query = """
    SELECT 
        name, age, gender  AS sex 
    FROM 
        table
"""

query = """
    FROM table
    |> SELECT name, age, gender  AS sex
"""

spark.sql(query).show()
+----+---+------+
|name|age|   sex|
+----+---+------+
|John| 21|  male|
|Jane| 25|female|
+----+---+------+
Code
type(df.age)
pyspark.sql.classic.column.Column
Code
(
    df.select(
        df.name, 
        df.age, 
        df.gender.alias('sex'))
      .show()
)
+----+---+------+
|name|age|   sex|
+----+---+------+
|John| 21|  male|
|Jane| 25|female|
+----+---+------+

CAST

Code
df.createOrReplaceTempView("table")

query = """
    SELECT 
        name, 
        cast(age AS float) AS age_f 
    FROM 
        table
"""

query = """
    FROM table |>  
    SELECT 
        name, 
        cast(age AS float) AS age_f 
"""

spark.sql(query).show()
+----+-----+
|name|age_f|
+----+-----+
|John| 21.0|
|Jane| 25.0|
+----+-----+
Code
(
    df
        .select(
            df.name, 
            df.age
                .cast("float")
                .alias("age_f"))
        .show()
)
+----+-----+
|name|age_f|
+----+-----+
|John| 21.0|
|Jane| 25.0|
+----+-----+
Code
new_age_col = df.age.cast("float").alias("age_f")
type(new_age_col), type(df.age)
(pyspark.sql.classic.column.Column, pyspark.sql.classic.column.Column)
Code
df.select(df.name, new_age_col).show()
+----+-----+
|name|age_f|
+----+-----+
|John| 21.0|
|Jane| 25.0|
+----+-----+

Adding new columns

Code
df.createOrReplaceTempView("table")

query = """
    FROM table |>
    SELECT 
        *, 
        12*age AS age_months 
"""

spark.sql(query).show()
+----+---+------+----------+
|name|age|gender|age_months|
+----+---+------+----------+
|John| 21|  male|       252|
|Jane| 25|female|       300|
+----+---+------+----------+
Code
( 
    df
        .withColumn("age_months", df.age * 12)
        .show()
)
+----+---+------+----------+
|name|age|gender|age_months|
+----+---+------+----------+
|John| 21|  male|       252|
|Jane| 25|female|       300|
+----+---+------+----------+
Code
(
    df
        .select("*", 
                (df.age * 12).alias("age_months"))
        .show()
)
+----+---+------+----------+
|name|age|gender|age_months|
+----+---+------+----------+
|John| 21|  male|       252|
|Jane| 25|female|       300|
+----+---+------+----------+
Code
import datetime

hui = datetime.date.today()

str(hui)
'2026-02-08'
Code
( 
    df
     .withColumn("yob", datetime.date.today().year - df.age)
     .show()
)
+----+---+------+----+
|name|age|gender| yob|
+----+---+------+----+
|John| 21|  male|2005|
|Jane| 25|female|2001|
+----+---+------+----+

Column functions

Numeric functions examples

Code
from pyspark.sql import functions as fn

columns = ["brand", "cost"]

df = spark.createDataFrame(
    [("garnier", 3.49),
     ("elseve", 2.71)], 
    columns
)

round_cost = fn.round(df.cost, 1)
floor_cost = fn.floor(df.cost)
ceil_cost = fn.ceil(df.cost)

(
    df
    .withColumn('round', round_cost)
    .withColumn('floor', floor_cost)
    .withColumn('ceil', ceil_cost)
    .show()
)
+-------+----+-----+-----+----+
|  brand|cost|round|floor|ceil|
+-------+----+-----+-----+----+
|garnier|3.49|  3.5|    3|   4|
| elseve|2.71|  2.7|    2|   3|
+-------+----+-----+-----+----+

String functions examples

Code
from pyspark.sql import functions as fn

columns = ["first_name", "last_name"]

df = spark.createDataFrame([
    ("John", "Doe"),
    ("Mary", "Jane")
], columns)

last_name_initial = fn.substring(df.last_name, 0, 1)
# last_name_initial_dotted = fn.concat(last_name_initial, fn.lit("."))

name = fn.concat_ws(" ", df.first_name, last_name_initial)

df.withColumn("name", name).show()
+----------+---------+------+
|first_name|last_name|  name|
+----------+---------+------+
|      John|      Doe|John D|
|      Mary|     Jane|Mary J|
+----------+---------+------+
Code
( 
    df.selectExpr("*", "substring(last_name, 0, 1) as lni")
      .selectExpr("first_name", "last_name", "concat(first_name, ' ', lni, '.') as nname")
      .show()
)
+----------+---------+-------+
|first_name|last_name|  nname|
+----------+---------+-------+
|      John|      Doe|John D.|
|      Mary|     Jane|Mary J.|
+----------+---------+-------+

As an SQL query

Code
df.createOrReplaceTempView("table")

query = """
    FROM table |>
    SELECT *, substring(last_name, 0, 1) AS lni |>
    SELECT first_name, last_name, concat(first_name, ' ', lni, '.') AS nname
"""
spark.sql(query).show()
+----------+---------+-------+
|first_name|last_name|  nname|
+----------+---------+-------+
|      John|      Doe|John D.|
|      Mary|     Jane|Mary J.|
+----------+---------+-------+

Spark SQL offers a large subset of standard SQL functions.
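
A couple of further examples, as a minimal sketch (fn.upper and fn.length are standard column functions):

Code
(
    df
        .select(
            fn.upper(df.first_name).alias("upper"),
            fn.length(df.last_name).alias("len"))
        .show()
)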

Date functions examples

Code
from datetime import date
from pyspark.sql import functions as fn

df = spark.createDataFrame([
    (date(2015, 1, 1), date(2015, 1, 15)),
    (date(2015, 2, 21), date(2015, 3, 8)),
], ["start_date", "end_date"])

days_between = fn.datediff(df.end_date, df.start_date)
start_month = fn.month(df.start_date)

(
    df
        .withColumn('days_between', days_between)
        .withColumn('start_month', start_month)
        .show()
)
+----------+----------+------------+-----------+
|start_date|  end_date|days_between|start_month|
+----------+----------+------------+-----------+
|2015-01-01|2015-01-15|          14|          1|
|2015-02-21|2015-03-08|          15|          2|
+----------+----------+------------+-----------+

Note that days_between is an instance of Column, the Spark type for columns in Spark dataframes.

Code
type(days_between)
pyspark.sql.classic.column.Column

Recall the datetime arithmetic available in database systems and in many programming frameworks.

Code
str(date(2015, 1, 1) - date(2015, 1, 15))
'-14 days, 0:00:00'
Code
from datetime import timedelta

date(2023, 2 , 14) + timedelta(days=3)
datetime.date(2023, 2, 17)

Conditional transformations

Code
df = spark.createDataFrame([
    ("John", 21, "male"),
    ("Jane", 25, "female"),
    ("Albert", 46, "male"),
    ("Brad", 49, "super-hero")
], ["name", "age", "gender"])
Code
supervisor = ( 
    fn.when(df.gender == 'male', 'Mr. Smith')
      .when(df.gender == 'female', 'Miss Jones')
      .otherwise('NA')
)

type(supervisor), type(fn.when)
(pyspark.sql.classic.column.Column, function)
Code
(
    df
        .withColumn("supervisor", supervisor)
        .show()
)
+------+---+----------+----------+
|  name|age|    gender|supervisor|
+------+---+----------+----------+
|  John| 21|      male| Mr. Smith|
|  Jane| 25|    female|Miss Jones|
|Albert| 46|      male| Mr. Smith|
|  Brad| 49|super-hero|        NA|
+------+---+----------+----------+

User-defined functions

Code
from pyspark.sql import functions as fn
from pyspark.sql.types import StringType

df = spark.createDataFrame([(1, 3), (4, 2)], ["first", "second"])

def my_func(col_1, col_2):
    if (col_1 > col_2):
        return "{} is bigger than {}".format(col_1, col_2)
    else:
        return "{} is bigger than {}".format(col_2, col_1)

# registration
my_udf = fn.udf(my_func, StringType())

# fn.udf is a decorator
# it can be used in a more explicit way

@fn.udf(returnType=StringType())
def the_same_func(col_1, col_2):
    if (col_1 > col_2):
        return "{} is bigger than {}".format(col_1, col_2)
    else:
        return "{} is bigger than {}".format(col_2, col_1)

# at work        
(
    df
        .withColumn("udf", my_udf(df['first'], df['second']))
        .withColumn("udf_too", the_same_func(df['first'], df['second']))
        .show()
)
+-----+------+------------------+------------------+
|first|second|               udf|           udf_too|
+-----+------+------------------+------------------+
|    1|     3|3 is bigger than 1|3 is bigger than 1|
|    4|     2|4 is bigger than 2|4 is bigger than 2|
+-----+------+------------------+------------------+
Tip: Using UDF in SQL queries

The user-defined function (UDF) can also be used in SQL queries, provided the decorated function is registered.

Code
df.createOrReplaceTempView("table")

spark.udf.register("the_same_func", the_same_func)  # the_same_func from @udf above
spark.sql("SELECT *, the_same_func(first, second) AS udf FROM table").show()
+-----+------+------------------+
|first|second|               udf|
+-----+------+------------------+
|    1|     3|3 is bigger than 1|
|    4|     2|4 is bigger than 2|
+-----+------+------------------+

Beware. spark.udf.register is not a decorator.
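
It can nevertheless register a plain, undecorated Python function, provided the return type is passed explicitly. A minimal sketch, reusing my_func from above (the SQL name my_func_sql is arbitrary):

Code
# register the undecorated function under a (hypothetical) SQL name
spark.udf.register("my_func_sql", my_func, StringType())
spark.sql("SELECT *, my_func_sql(first, second) AS udf FROM table").show()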

Important: More on UDF in Spark 4
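
One direction worth exploring, as a sketch (it assumes pandas and pyarrow are installed): vectorized pandas UDFs, which receive whole batches as pandas Series instead of individual values and usually run much faster than row-at-a-time UDFs.

Code
import pandas as pd
from pyspark.sql import functions as fn

# vectorized UDF: operates on pandas Series, returns a pandas Series of strings
@fn.pandas_udf("string")
def the_same_func_vec(col_1: pd.Series, col_2: pd.Series) -> pd.Series:
    big = col_1.combine(col_2, max)
    small = col_1.combine(col_2, min)
    return big.astype(str) + " is bigger than " + small.astype(str)

df.withColumn("udf_vec", the_same_func_vec(df["first"], df["second"])).show()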

Joins (\(⋈\))

Using the DataFrame API

Code
from datetime import date

products = spark.createDataFrame(
    [
        ('1', 'mouse', 'microsoft', 39.99),
        ('2', 'keyboard', 'logitech', 59.99),
    ], 
    ['prod_id', 'prod_cat', 'prod_brand', 'prod_value']
)

purchases = spark.createDataFrame([
    (date(2017, 11, 1), 2, '1'),
    (date(2017, 11, 2), 1, '1'),
    (date(2017, 11, 5), 1, '2'),
], ['date', 'quantity', 'prod_id'])

# The default join type is the "INNER" join
purchases.join(products, 'prod_id').show()
+-------+----------+--------+--------+----------+----------+
|prod_id|      date|quantity|prod_cat|prod_brand|prod_value|
+-------+----------+--------+--------+----------+----------+
|      1|2017-11-01|       2|   mouse| microsoft|     39.99|
|      1|2017-11-02|       1|   mouse| microsoft|     39.99|
|      2|2017-11-05|       1|keyboard|  logitech|     59.99|
+-------+----------+--------+--------+----------+----------+

Just as with RDBMSs, we can ask for explanations:

Code
purchases.join(products, 'prod_id').explain()
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [prod_id#423, date#421, quantity#422L, prod_cat#418, prod_brand#419, prod_value#420]
   +- SortMergeJoin [prod_id#423], [prod_id#417], Inner
      :- Sort [prod_id#423 ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(prod_id#423, 200), ENSURE_REQUIREMENTS, [plan_id=648]
      :     +- Filter isnotnull(prod_id#423)
      :        +- Scan ExistingRDD[date#421,quantity#422L,prod_id#423]
      +- Sort [prod_id#417 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(prod_id#417, 200), ENSURE_REQUIREMENTS, [plan_id=649]
            +- Filter isnotnull(prod_id#417)
               +- Scan ExistingRDD[prod_id#417,prod_cat#418,prod_brand#419,prod_value#420]

Have a look at the PostgreSQL documentation. What are the different join methods? When is one favored?

Using a SQL query

Code
products.createOrReplaceTempView("products")
purchases.createOrReplaceTempView("purchases")

query = """
    SELECT * 
    FROM purchases AS prc INNER JOIN 
        products AS prd 
    ON prc.prod_id = prd.prod_id
"""

query = """
    FROM purchases |>
    INNER JOIN products USING(prod_id)    
"""

spark.sql(query).show()
+-------+----------+--------+--------+----------+----------+
|prod_id|      date|quantity|prod_cat|prod_brand|prod_value|
+-------+----------+--------+--------+----------+----------+
|      1|2017-11-01|       2|   mouse| microsoft|     39.99|
|      1|2017-11-02|       1|   mouse| microsoft|     39.99|
|      2|2017-11-05|       1|keyboard|  logitech|     59.99|
+-------+----------+--------+--------+----------+----------+
Code
spark.sql(query).explain()
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [prod_id#423, date#421, quantity#422L, prod_cat#418, prod_brand#419, prod_value#420]
   +- SortMergeJoin [prod_id#423], [prod_id#417], Inner
      :- Sort [prod_id#423 ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(prod_id#423, 200), ENSURE_REQUIREMENTS, [plan_id=798]
      :     +- Filter isnotnull(prod_id#423)
      :        +- Scan ExistingRDD[date#421,quantity#422L,prod_id#423]
      +- Sort [prod_id#417 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(prod_id#417, 200), ENSURE_REQUIREMENTS, [plan_id=799]
            +- Filter isnotnull(prod_id#417)
               +- Scan ExistingRDD[prod_id#417,prod_cat#418,prod_brand#419,prod_value#420]

Code
new_purchases = spark.createDataFrame([
    (date(2017, 11, 1), 2, '1'),
    (date(2017, 11, 2), 1, '3'),
], ['date', 'quantity', 'prod_id_x'])

# This time we ask for a LEFT (outer) join, with an explicit join rule
join_rule = new_purchases.prod_id_x == products.prod_id

print(type(join_rule))

new_purchases.join(products, join_rule, 'left').show()
<class 'pyspark.sql.classic.column.Column'>
+----------+--------+---------+-------+--------+----------+----------+
|      date|quantity|prod_id_x|prod_id|prod_cat|prod_brand|prod_value|
+----------+--------+---------+-------+--------+----------+----------+
|2017-11-01|       2|        1|      1|   mouse| microsoft|     39.99|
|2017-11-02|       1|        3|   NULL|    NULL|      NULL|      NULL|
+----------+--------+---------+-------+--------+----------+----------+

What is the type of join_rule.info?

Code
join_rule.info
Column<'=(prod_id_x, prod_id)['info']'>
Code
new_purchases = spark.createDataFrame([
    (date(2017, 11, 1), 2, '1'),
    (date(2017, 11, 2), 1, '3'),
], ['date', 'quantity', 'prod_id_x'])

# Again a LEFT (outer) join with an explicit join rule
join_rule = new_purchases.prod_id_x == products.prod_id

new_purchases.join(products, join_rule, 'left').show()
+----------+--------+---------+-------+--------+----------+----------+
|      date|quantity|prod_id_x|prod_id|prod_cat|prod_brand|prod_value|
+----------+--------+---------+-------+--------+----------+----------+
|2017-11-01|       2|        1|      1|   mouse| microsoft|     39.99|
|2017-11-02|       1|        3|   NULL|    NULL|      NULL|      NULL|
+----------+--------+---------+-------+--------+----------+----------+
Note: Different kinds of joins

Just as in RDBMS, there are different kinds of joins.

They differ in the way tuples from the left and right tables are matched (is it a natural join, an equi-join, a \(\theta\)-join?). They also differ in the way they handle non-matching tuples (inner or outer joins).

The join method has three parameters:

  • other: the right table
  • on: the join rule that defines how tuples are matched.
  • how: defines the way non-matching tuples are handled.
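
For instance, spelling out the three parameters by name (a minimal sketch, equivalent to the inner join shown above):

Code
purchases.join(other=products, on='prod_id', how='inner').show()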

Various types of joins

Code
left = spark.createDataFrame([
    (1, "A1"), (2, "A2"), (3, "A3"), (4, "A4")], 
    ["id", "value"])

right = spark.createDataFrame([
    (3, "A3"), (4, "A4"), (4, "A4_1"), (5, "A5"), (6, "A6")], 
    ["id", "value"])

join_types = [
    "inner", "outer", "left", "right",
    "leftsemi", "leftanti"
]
Code
for join_type in join_types:
    print(join_type)
    left.join(right, on="id", how=join_type)\
        .orderBy("id")\
        .show()
inner
+---+-----+-----+
| id|value|value|
+---+-----+-----+
|  3|   A3|   A3|
|  4|   A4|   A4|
|  4|   A4| A4_1|
+---+-----+-----+

outer
+---+-----+-----+
| id|value|value|
+---+-----+-----+
|  1|   A1| NULL|
|  2|   A2| NULL|
|  3|   A3|   A3|
|  4|   A4|   A4|
|  4|   A4| A4_1|
|  5| NULL|   A5|
|  6| NULL|   A6|
+---+-----+-----+

left
+---+-----+-----+
| id|value|value|
+---+-----+-----+
|  1|   A1| NULL|
|  2|   A2| NULL|
|  3|   A3|   A3|
|  4|   A4| A4_1|
|  4|   A4|   A4|
+---+-----+-----+

right
+---+-----+-----+
| id|value|value|
+---+-----+-----+
|  3|   A3|   A3|
|  4|   A4|   A4|
|  4|   A4| A4_1|
|  5| NULL|   A5|
|  6| NULL|   A6|
+---+-----+-----+

leftsemi
+---+-----+
| id|value|
+---+-----+
|  3|   A3|
|  4|   A4|
+---+-----+

leftanti
+---+-----+
| id|value|
+---+-----+
|  1|   A1|
|  2|   A2|
+---+-----+

Aggregations (summarize)

Examples using the API

Code
from pyspark.sql import functions as fn

products = spark.createDataFrame([
    ('1', 'mouse', 'microsoft', 39.99),
    ('2', 'mouse', 'microsoft', 59.99),
    ('3', 'keyboard', 'microsoft', 59.99),
    ('4', 'keyboard', 'logitech', 59.99),
    ('5', 'mouse', 'logitech', 29.99),
], ['prod_id', 'prod_cat', 'prod_brand', 'prod_value'])

( 
    products
        .groupBy('prod_cat')
        .avg('prod_value')
        .show()
)
+--------+-----------------+
|prod_cat|  avg(prod_value)|
+--------+-----------------+
|   mouse|43.32333333333333|
|keyboard|            59.99|
+--------+-----------------+

What is the type of products.groupBy('prod_cat')?
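
A quick check, as a sketch (the exact module path may vary across Spark versions, but the result is a GroupedData object, not a DataFrame):

Code
grouped = products.groupBy('prod_cat')
print(type(grouped))  # e.g. <class 'pyspark.sql.group.GroupedData'>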

Code
(
    products
        .groupBy('prod_cat')
        .agg(fn.avg('prod_value'))
        .show()
)
+--------+-----------------+
|prod_cat|  avg(prod_value)|
+--------+-----------------+
|   mouse|43.32333333333333|
|keyboard|            59.99|
+--------+-----------------+
Code
(
    products
        .groupBy('prod_cat')
        .agg(
            fn.mean('prod_value'), 
            fn.stddev('prod_value')
        )
        .show()
)
+--------+-----------------+------------------+
|prod_cat|  avg(prod_value)|stddev(prod_value)|
+--------+-----------------+------------------+
|   mouse|43.32333333333333|15.275252316519468|
|keyboard|            59.99|               0.0|
+--------+-----------------+------------------+
Code
from pyspark.sql import functions as fn

(
    products
        .groupBy('prod_brand', 'prod_cat')\
        .agg(
            fn.avg('prod_value')
        )
        .show()
)
+----------+--------+---------------+
|prod_brand|prod_cat|avg(prod_value)|
+----------+--------+---------------+
| microsoft|   mouse|          49.99|
| microsoft|keyboard|          59.99|
|  logitech|keyboard|          59.99|
|  logitech|   mouse|          29.99|
+----------+--------+---------------+
Code
from pyspark.sql import functions as fn

(
    products
        .groupBy('prod_brand')
        .agg(
            fn.round(
                fn.avg('prod_value'), 1)
                .alias('average'),
            fn.ceil(
                fn.sum('prod_value'))
                .alias('sum'),
            fn.min('prod_value')
                .alias('min')
        )
        .show()
)
+----------+-------+---+-----+
|prod_brand|average|sum|  min|
+----------+-------+---+-----+
| microsoft|   53.3|160|39.99|
|  logitech|   45.0| 90|29.99|
+----------+-------+---+-----+

Example using a query

Code
products.createOrReplaceTempView("products")
Code
query = """
SELECT
    prod_brand,
    round(avg(prod_value), 1) AS average,
    min(prod_value) AS min
FROM 
    products
GROUP BY 
    prod_brand
"""

query = """
    FROM products 
    |> AGGREGATE 
        round(avg(prod_value), 1) AS average,
        min(prod_value) AS min GROUP BY prod_brand 
"""

spark.sql(query).show()
+----------+-------+-----+
|prod_brand|average|  min|
+----------+-------+-----+
| microsoft|   53.3|39.99|
|  logitech|   45.0|29.99|
+----------+-------+-----+
Code
spark.sql(query).explain()
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[prod_brand#573], functions=[avg(prod_value#574), min(prod_value#574)])
   +- Exchange hashpartitioning(prod_brand#573, 200), ENSURE_REQUIREMENTS, [plan_id=2043]
      +- HashAggregate(keys=[prod_brand#573], functions=[partial_avg(prod_value#574), partial_min(prod_value#574)])
         +- Project [prod_brand#573, prod_value#574]
            +- Scan ExistingRDD[prod_id#571,prod_cat#572,prod_brand#573,prod_value#574]

Window functions

Numerical window functions

Code
from pyspark.sql import Window
from pyspark.sql import functions as fn

# First, we create the Window definition
window = Window.partitionBy('prod_brand')

print(type(window))
<class 'pyspark.sql.classic.window.WindowSpec'>

Then, we can use over to aggregate over this window.

Code
avg = fn.avg('prod_value').over(window)

# Finally, we can use it as a classical column
(
    products
        .withColumn('avg_brand_value', fn.round(avg, 2))
        .show()
)
+-------+--------+----------+----------+---------------+
|prod_id|prod_cat|prod_brand|prod_value|avg_brand_value|
+-------+--------+----------+----------+---------------+
|      4|keyboard|  logitech|     59.99|          44.99|
|      5|   mouse|  logitech|     29.99|          44.99|
|      1|   mouse| microsoft|     39.99|          53.32|
|      2|   mouse| microsoft|     59.99|          53.32|
|      3|keyboard| microsoft|     59.99|          53.32|
+-------+--------+----------+----------+---------------+

With SQL queries, using windows?

Code
query = """
    SELECT 
        *, 
        ROUND(AVG(prod_value) OVER w1, 2)  AS avg_brand_value,
        ROUND(AVG(prod_value) OVER w2, 1)  AS avg_prod_value
    FROM 
        products
    WINDOW 
        w1 AS (PARTITION BY prod_brand),
        w2 AS (PARTITION BY prod_cat)
"""


query2 = """
    FROM products |>
    SELECT
        *,  
        ROUND(AVG(prod_value) OVER w1, 2)  AS avg_brand_value,
        ROUND(AVG(prod_value) OVER w2, 1)  AS avg_prod_value
        WINDOW 
            w1 AS (PARTITION BY prod_brand),
            w2 AS (PARTITION BY prod_cat)
"""

spark.sql(query2).show()
+-------+--------+----------+----------+---------------+--------------+
|prod_id|prod_cat|prod_brand|prod_value|avg_brand_value|avg_prod_value|
+-------+--------+----------+----------+---------------+--------------+
|      4|keyboard|  logitech|     59.99|          44.99|          60.0|
|      3|keyboard| microsoft|     59.99|          53.32|          60.0|
|      5|   mouse|  logitech|     29.99|          44.99|          43.3|
|      1|   mouse| microsoft|     39.99|          53.32|          43.3|
|      2|   mouse| microsoft|     59.99|          53.32|          43.3|
+-------+--------+----------+----------+---------------+--------------+
Code
window2 = Window.partitionBy('prod_cat')

avg2 = fn.avg('prod_value').over(window2)

# Finally, we can do it as a classical column
( 
    products
        .withColumn('avg_brand_value', fn.round(avg, 2))
        .withColumn('avg_prod_value', fn.round(avg2, 1))
        .show()
)
+-------+--------+----------+----------+---------------+--------------+
|prod_id|prod_cat|prod_brand|prod_value|avg_brand_value|avg_prod_value|
+-------+--------+----------+----------+---------------+--------------+
|      4|keyboard|  logitech|     59.99|          44.99|          60.0|
|      3|keyboard| microsoft|     59.99|          53.32|          60.0|
|      5|   mouse|  logitech|     29.99|          44.99|          43.3|
|      1|   mouse| microsoft|     39.99|          53.32|          43.3|
|      2|   mouse| microsoft|     59.99|          53.32|          43.3|
+-------+--------+----------+----------+---------------+--------------+

Now we can compare the physical plans associated with the two jobs.

Code
( 
    products
        .withColumn('avg_brand_value', fn.round(avg, 2))
        .withColumn('avg_prod_value', fn.round(avg2, 1))
        .explain()
)
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [prod_id#571, prod_cat#572, prod_brand#573, prod_value#574, avg_brand_value#846, round(_we0#851, 1) AS avg_prod_value#849]
   +- Window [avg(prod_value#574) windowspecdefinition(prod_cat#572, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#851], [prod_cat#572]
      +- Sort [prod_cat#572 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(prod_cat#572, 200), ENSURE_REQUIREMENTS, [plan_id=2384]
            +- Project [prod_id#571, prod_cat#572, prod_brand#573, prod_value#574, round(_we0#848, 2) AS avg_brand_value#846]
               +- Window [avg(prod_value#574) windowspecdefinition(prod_brand#573, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#848], [prod_brand#573]
                  +- Sort [prod_brand#573 ASC NULLS FIRST], false, 0
                     +- Exchange hashpartitioning(prod_brand#573, 200), ENSURE_REQUIREMENTS, [plan_id=2379]
                        +- Scan ExistingRDD[prod_id#571,prod_cat#572,prod_brand#573,prod_value#574]

Code
spark.sql(query).explain()
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [prod_id#571, prod_cat#572, prod_brand#573, prod_value#574, round(_we0#856, 2) AS avg_brand_value#852, round(_we1#857, 1) AS avg_prod_value#853]
   +- Window [avg(prod_value#574) windowspecdefinition(prod_cat#572, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we1#857], [prod_cat#572]
      +- Sort [prod_cat#572 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(prod_cat#572, 200), ENSURE_REQUIREMENTS, [plan_id=2408]
            +- Window [avg(prod_value#574) windowspecdefinition(prod_brand#573, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#856], [prod_brand#573]
               +- Sort [prod_brand#573 ASC NULLS FIRST], false, 0
                  +- Exchange hashpartitioning(prod_brand#573, 200), ENSURE_REQUIREMENTS, [plan_id=2404]
                     +- Scan ExistingRDD[prod_id#571,prod_cat#572,prod_brand#573,prod_value#574]

Windows can be defined on multiple columns

Code
from pyspark.sql import Window
from pyspark.sql import functions as fn

window = Window.partitionBy('prod_brand', 'prod_cat')

avg = fn.avg('prod_value').over(window)


(
    products    
        .withColumn('avg_value', fn.round(avg, 2))
        .show()
)
+-------+--------+----------+----------+---------+
|prod_id|prod_cat|prod_brand|prod_value|avg_value|
+-------+--------+----------+----------+---------+
|      4|keyboard|  logitech|     59.99|    59.99|
|      5|   mouse|  logitech|     29.99|    29.99|
|      3|keyboard| microsoft|     59.99|    59.99|
|      1|   mouse| microsoft|     39.99|    49.99|
|      2|   mouse| microsoft|     59.99|    49.99|
+-------+--------+----------+----------+---------+

Lag and Lead

Code
purchases = spark.createDataFrame(
    [
        (date(2017, 11, 1), 'mouse'),
        (date(2017, 11, 2), 'mouse'),
        (date(2017, 11, 4), 'keyboard'),
        (date(2017, 11, 6), 'keyboard'),
        (date(2017, 11, 9), 'keyboard'),
        (date(2017, 11, 12), 'mouse'),
        (date(2017, 11, 18), 'keyboard')
    ], 
    ['date', 'prod_cat']
)

purchases.show()

window = Window.partitionBy('prod_cat').orderBy('date')

prev_purch = fn.lag('date', 1).over(window)
next_purch = fn.lead('date', 1).over(window)

purchases\
    .withColumn('prev', prev_purch)\
    .withColumn('next', next_purch)\
    .orderBy('prod_cat', 'date')\
    .show()
+----------+--------+
|      date|prod_cat|
+----------+--------+
|2017-11-01|   mouse|
|2017-11-02|   mouse|
|2017-11-04|keyboard|
|2017-11-06|keyboard|
|2017-11-09|keyboard|
|2017-11-12|   mouse|
|2017-11-18|keyboard|
+----------+--------+

+----------+--------+----------+----------+
|      date|prod_cat|      prev|      next|
+----------+--------+----------+----------+
|2017-11-04|keyboard|      NULL|2017-11-06|
|2017-11-06|keyboard|2017-11-04|2017-11-09|
|2017-11-09|keyboard|2017-11-06|2017-11-18|
|2017-11-18|keyboard|2017-11-09|      NULL|
|2017-11-01|   mouse|      NULL|2017-11-02|
|2017-11-02|   mouse|2017-11-01|2017-11-12|
|2017-11-12|   mouse|2017-11-02|      NULL|
+----------+--------+----------+----------+

Rank, DenseRank and RowNumber

Code
contestants = spark.createDataFrame(
    [   
        ('veterans', 'John', 3000),
        ('veterans', 'Bob', 3200),
        ('veterans', 'Mary', 4000),
        ('young', 'Jane', 4000),
        ('young', 'April', 3100),
        ('young', 'Alice', 3700),
        ('young', 'Micheal', 4000),
    ], 
    ['category', 'name', 'points']
)

contestants.show()
+--------+-------+------+
|category|   name|points|
+--------+-------+------+
|veterans|   John|  3000|
|veterans|    Bob|  3200|
|veterans|   Mary|  4000|
|   young|   Jane|  4000|
|   young|  April|  3100|
|   young|  Alice|  3700|
|   young|Micheal|  4000|
+--------+-------+------+
Code
window = (
    Window
        .partitionBy('category')
        .orderBy(contestants.points.desc())
)

rank = fn.rank().over(window)
dense_rank = fn.dense_rank().over(window)
row_number = fn.row_number().over(window)

(
contestants
    .withColumn('rank', rank)
    .withColumn('dense_rank', dense_rank)
    .withColumn('row_number', row_number)
    .orderBy('category', fn.col('points').desc())
    .show()
)
+--------+-------+------+----+----------+----------+
|category|   name|points|rank|dense_rank|row_number|
+--------+-------+------+----+----------+----------+
|veterans|   Mary|  4000|   1|         1|         1|
|veterans|    Bob|  3200|   2|         2|         2|
|veterans|   John|  3000|   3|         3|         3|
|   young|   Jane|  4000|   1|         1|         1|
|   young|Micheal|  4000|   1|         1|         2|
|   young|  Alice|  3700|   3|         2|         3|
|   young|  April|  3100|   4|         3|         4|
+--------+-------+------+----+----------+----------+
Code
spark.stop()