PySpark Performance Optimization: A Guide to Fast, Scalable Big Data Pipelines

Introduction: Why PySpark Optimization Matters

Apache Spark is one of the most powerful distributed computing frameworks ever built. Yet even experienced engineers routinely leave 60–80% of cluster performance on the table — suffering multi-hour runtimes, out-of-memory crashes, and runaway cloud bills — simply because they haven’t applied the right set of optimizations.

This guide covers 10 of the most impactful PySpark optimization techniques, from low-hanging fruit like caching and broadcast joins to advanced strategies like Adaptive Query Execution and data skew handling. Each technique is explained with clear architecture context, working code examples, and guidance on when to apply it.


📌 SCOPE
This guide targets PySpark 3.x on Spark SQL / DataFrame API. Most techniques also apply to Spark 2.4+ with minor syntax differences.

Quick Reference: All 10 Optimization Techniques

Technique | Category | Performance Gain | Complexity
Caching & Persistence | Memory Management | 2x – 10x | Low
Broadcast Join | Join Strategy | 5x – 50x | Low
Bucketing | Data Layout | 3x – 10x | Medium
Adaptive Query Execution (AQE) | Query Planner | 2x – 5x | Low
Partitioning Strategy | Data Layout | 2x – 8x | Medium
Predicate Pushdown | I/O Reduction | 2x – 20x | Low
Serialization (Kryo) | Memory/Network | 1.5x – 3x | Low
UDF Optimization | Computation | 2x – 10x | Medium
Data Skew Handling | Resource Balance | 3x – 15x | High
Resource Tuning | Cluster Config | 2x – 5x | Medium

1. Caching and Persistence

What is Caching in PySpark?

Every time you call an action on a PySpark DataFrame — count(), show(), write() — Spark re-executes the entire DAG from scratch. If your DataFrame is the product of expensive transformations (joins, aggregations, file reads), this means paying that cost repeatedly. Caching breaks this cycle by storing the computed DataFrame in memory (or on disk) so subsequent actions read from the cache rather than recomputing from source.

Storage Levels Explained

PySpark offers several storage levels that control where and how data is cached:

Storage Level | Memory | Disk | Serialized | Replicated | Best For
MEMORY_ONLY | Yes | No | No | No | Small DataFrames, fast access
MEMORY_AND_DISK | Yes | Yes (spill) | No | No | Medium DataFrames
MEMORY_ONLY_SER | Yes | No | Yes | No | Reducing GC pressure
MEMORY_AND_DISK_SER | Yes | Yes | Yes | No | Large DataFrames
DISK_ONLY | No | Yes | Yes | No | Very large, infrequent access
MEMORY_AND_DISK_2 | Yes | Yes | No | Yes | Fault-tolerant caching

Caching vs Persisting

For DataFrames, cache() is simply a shortcut for persist(StorageLevel.MEMORY_AND_DISK); note that RDD.cache() defaults to MEMORY_ONLY instead. Use persist() when you need explicit control over the storage level.

▸ When to Use cache()

  • DataFrame is used in 2 or more downstream actions or joins
  • The DataFrame is expensive to compute (multi-step joins, aggregations)
  • It fits comfortably in cluster memory (< 60% of total executor memory)

▸ When NOT to Cache

  • DataFrame is used only once — caching adds overhead with no benefit
  • DataFrame is very large and doesn’t fit in memory — causes GC pressure
  • ETL pipelines that run once — write directly instead of caching

Caching Code Examples

from pyspark import StorageLevel
from pyspark.sql import functions as F

# Simple cache (MEMORY_AND_DISK)
df_expensive = (
    spark.read.parquet('/data/raw/events')
    .filter("event_type = 'purchase'")
    .join(users_df, 'user_id')
    .groupBy('user_id')
    .agg(F.sum('amount').alias('total'))
)

df_expensive.cache()  # lazy — not computed yet
df_expensive.count()  # triggers computation + cache fill

# Now these reuse the cache:
df_expensive.show(10)                              # no recomputation
df_expensive.write.parquet('/data/output/totals')  # no recomputation

# Explicit storage level
df_large.persist(StorageLevel.MEMORY_AND_DISK_SER)

# Always unpersist when done to free memory
df_expensive.unpersist()

💡 TIP
Always call an action (like count()) immediately after cache() to trigger computation. Only then is the cache filled. Calling cache() alone does nothing until an action is triggered.

Checking Cache Status

# Check if a DataFrame is cached
print(df_expensive.is_cached)     # True / False

# Check which storage level is in effect
print(df_expensive.storageLevel)  # e.g. Disk Memory Deserialized 1x Replicated

# View per-RDD cache usage in the Spark UI → Storage tab

2. Broadcast Join

How Broadcast Join Works in PySpark

A standard Sort-Merge Join between two large tables requires both tables to be shuffled across the network — an enormously expensive operation. A Broadcast Join avoids this entirely by sending a copy of the smaller table to every executor node. Each executor then performs the join locally against its partition of the larger table, with zero network shuffle.

Automatic vs Manual Broadcast

Spark automatically broadcasts tables below the spark.sql.autoBroadcastJoinThreshold (default 10 MB). For larger tables that still qualify for broadcast, you can trigger it manually using broadcast().

▸ Auto-Broadcast Configuration

# Increase auto-broadcast threshold to 50 MB
spark.conf.set('spark.sql.autoBroadcastJoinThreshold', 50 * 1024 * 1024)
# Disable auto-broadcast (force SMJ for all joins)
spark.conf.set('spark.sql.autoBroadcastJoinThreshold', -1)

▸ Manual Broadcast Hint

from pyspark.sql.functions import broadcast

# Manually hint Spark to broadcast the smaller table
result = orders_df.join(
    broadcast(products_df),  # products_df is 200 MB — fits in executor RAM
    on='product_id',
    how='inner'
)

# Verify: explain() should show BroadcastHashJoin, NOT SortMergeJoin
result.explain(mode='formatted')

Broadcast Join Limits and Thresholds

Table Size | Strategy | Config
< 10 MB | Auto Broadcast | Default threshold
10 MB – 200 MB | Manual Broadcast | Use broadcast() hint
200 MB – 2 GB | Bucketing | bucketBy() + saveAsTable()
> 2 GB | Sort-Merge Join | Default — optimize with AQE

⚠️ WARN
Broadcasting a table too large for executor memory causes OOM errors. Monitor executor memory usage and stay within 30–40% of executor RAM as a safe limit for broadcast tables.
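The size thresholds above can be captured in a small helper, useful for code review or a pipeline lint check. This is a plain-Python sketch of the guide's rules of thumb; the cutoffs come from the table above and are not hard Spark limits:

```python
MB = 1024 * 1024
GB = 1024 * MB

def suggest_join_strategy(table_size_bytes):
    """Map the smaller table's size onto the strategy table above."""
    if table_size_bytes < 10 * MB:
        return 'auto broadcast (under default threshold)'
    elif table_size_bytes < 200 * MB:
        return 'manual broadcast() hint'
    elif table_size_bytes < 2 * GB:
        return 'bucketing (bucketBy + saveAsTable)'
    else:
        return 'sort-merge join, tuned with AQE'

print(suggest_join_strategy(50 * MB))  # → manual broadcast() hint
```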

3. Smart Partitioning Strategy

Understanding PySpark Partitions

Partitions are the fundamental unit of parallelism in Spark. Every DataFrame is divided into partitions, and each partition is processed by a single task on a single core. Too few partitions means under-utilization of the cluster; too many partitions means excessive task scheduling overhead and tiny files.

The Partition Count Problem

▸ Default Partition Behavior

  • On read: Spark creates 1 partition per file block (~128 MB by default)
  • After a shuffle: Spark creates spark.sql.shuffle.partitions partitions (default: 200)
  • After repartition(N): Spark creates exactly the number you specify
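The read-side default can be estimated without touching a cluster. The sketch below approximates the input partition count from total data size and spark.sql.files.maxPartitionBytes; real counts also depend on file boundaries and spark.sql.files.openCostInBytes, so treat it as a rough guide:

```python
import math

def estimate_read_partitions(total_bytes, max_partition_bytes=128 * 1024 * 1024):
    """Approximate how many input partitions a file read produces.

    Spark splits input files into chunks of at most
    spark.sql.files.maxPartitionBytes (128 MB by default), so the count
    is roughly the ceiling of total size over that limit.
    """
    return max(1, math.ceil(total_bytes / max_partition_bytes))

# A 10 GB dataset read with defaults lands in roughly 80 partitions
print(estimate_read_partitions(10 * 1024**3))  # → 80
```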

Repartition vs Coalesce

Method | Shuffle? | Use Case | Performance
repartition(N) | Yes (full shuffle) | Increasing partitions / even distribution | Slower write, faster downstream
coalesce(N) | No (narrow) | Decreasing partitions / reducing file count | Faster, but may produce skewed partitions
repartitionByRange(N, col) | Yes | Range-partitioned output for sorted reads | Medium — best for ordered data

▸ Optimal Partition Size Formula

# Rule of thumb: each partition = 128 MB – 512 MB uncompressed
# Compute optimal partition count:
total_data_bytes = 500 * 1024**3 # 500 GB
target_partition_bytes = 256 * 1024**2 # 256 MB
optimal_partitions = total_data_bytes // target_partition_bytes
# Result: ~2000 partitions
# Set before shuffle-heavy operations
spark.conf.set('spark.sql.shuffle.partitions', optimal_partitions)
# Repartition a DataFrame explicitly
df_balanced = df.repartition(2000, 'customer_id')
# Reduce small files before writing
df_final = df_transformed.coalesce(200)
df_final.write.parquet('/data/output/result')

Write Partitioning for Storage Optimization

▸ partitionBy for Hive-style Directory Partitions

# Write with directory-level partitioning (Hive-style)
# Creates: /data/events/year=2024/month=01/day=15/part-*.parquet
df.write \
    .partitionBy('year', 'month', 'day') \
    .mode('overwrite') \
    .parquet('/data/events')

# Query with partition pruning (only reads matching directories)
df_jan = spark.read.parquet('/data/events') \
    .filter('year = 2024 AND month = 1')  # Only scans year=2024/month=1/

💡 TIP
Partition on low-cardinality columns (date, region, status). High-cardinality partitioning (user_id) creates millions of tiny directories — use bucketing for those instead.
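A quick back-of-envelope check makes this tip concrete. The cardinalities below are hypothetical, chosen only to show the directory-count difference between the two choices:

```python
# Hypothetical cardinalities for illustration
days = 365          # one year of daily partitions
regions = 12        # low-cardinality column
users = 10_000_000  # high-cardinality column

# Hive-style partitioning creates one directory per distinct value combination
low_card_dirs = days * regions  # 4,380 directories: manageable
high_card_dirs = users          # 10 million directories: a metadata disaster

print(low_card_dirs, high_card_dirs)  # → 4380 10000000
```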

4. Adaptive Query Execution (AQE)

What is Adaptive Query Execution?

Traditional Spark planning is static — the query plan is fixed before execution begins, based on estimated statistics that are often wrong. Adaptive Query Execution (AQE), introduced in Spark 3.0, changes this by re-optimizing the query plan at runtime as actual data statistics become available after each shuffle stage.

AQE’s Three Core Features

▸ 1. Dynamic Coalescing of Shuffle Partitions

AQE merges small post-shuffle partitions automatically, reducing the overhead of thousands of tiny tasks without requiring you to manually set spark.sql.shuffle.partitions.

▸ 2. Dynamic Join Strategy Switching

AQE can switch a Sort-Merge Join to a Broadcast Hash Join at runtime if it determines one side is small enough to broadcast — even if the static plan chose SMJ.

▸ 3. Skew Join Optimization

AQE detects skewed partitions and automatically splits them into smaller sub-tasks, distributing the load evenly without manual salting.

Enabling and Configuring AQE

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config('spark.sql.adaptive.enabled', 'true')                          # Enable AQE
    .config('spark.sql.adaptive.coalescePartitions.enabled', 'true')       # Auto-merge small partitions
    .config('spark.sql.adaptive.coalescePartitions.minPartitionNum', '1')  # Min after coalesce
    .config('spark.sql.adaptive.advisoryPartitionSizeInBytes', '128MB')    # Target partition size
    .config('spark.sql.adaptive.skewJoin.enabled', 'true')                 # Auto skew handling
    .config('spark.sql.adaptive.skewJoin.skewedPartitionFactor', '5')      # 5x median = skewed
    .config('spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes', '256MB')
    .getOrCreate()
)


💡 TIP
AQE is enabled by default in Spark 3.2+. For Spark 3.0–3.1, add spark.sql.adaptive.enabled = true explicitly. It is the single highest-ROI configuration change with the lowest effort.

5. Predicate Pushdown and Column Pruning

What is Predicate Pushdown?

Predicate pushdown is the practice of applying filter conditions as early as possible in the query execution — ideally at the data source level, before data is loaded into memory. When reading Parquet or ORC files, Spark can pass filter predicates directly to the file reader, which skips entire row groups or files that do not match the condition.

Column Pruning

Column pruning complements predicate pushdown: instead of reading all columns from a file and discarding unused ones, Spark reads only the columns referenced in the query. In columnar formats like Parquet, this can reduce I/O by 80%+.

How to Maximize Pushdown

▸ Apply Filters Early in the Chain

# BAD — filter applied after expensive joins
result = (
    orders_df
    .join(customers_df, 'customer_id')
    .join(products_df, 'product_id')
    .filter("order_date >= '2024-01-01'")  # Too late!
)

# GOOD — filter applied before the joins (Spark Catalyst also does this automatically)
result = (
    orders_df
    .filter("order_date >= '2024-01-01'")  # Pushes down to the file scan
    .join(customers_df, 'customer_id')
    .join(products_df, 'product_id')
)

# BEST — select only needed columns too (column pruning)
result = (
    orders_df
    .select('order_id', 'customer_id', 'product_id', 'amount', 'order_date')
    .filter("order_date >= '2024-01-01'")
    .join(customers_df.select('customer_id', 'name'), 'customer_id')
    .join(products_df.select('product_id', 'category'), 'product_id')
)

▸ Verifying Pushdown in the Physical Plan

# Look for PushedFilters in the file scan node of the physical plan
spark.read.parquet('/data/orders') \
    .filter("order_date >= '2024-01-01'") \
    .explain()
# The scan line should show the predicate under PushedFilters, e.g.:
# PushedFilters: [IsNotNull(order_date), GreaterThanOrEqual(order_date,2024-01-01)]

Parquet Statistics and Row Group Skipping

Parquet stores min/max statistics for each row group. When a filter condition (e.g., amount > 1000) is outside the min/max range of a row group, Spark skips that row group entirely without reading it. Sorting data before writing dramatically improves this:

# Write data sorted on the filter column to maximize row group skipping
df.sort('order_date', 'amount') \
    .write \
    .option('parquet.block.size', 128 * 1024 * 1024) \
    .parquet('/data/orders_sorted')
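To build intuition for min/max skipping, here is a plain-Python simulation of the mechanism. The row-group statistics are invented for illustration; a real Parquet reader applies the same min/max test per row group:

```python
# Each tuple is (min_amount, max_amount) for one row group.
# Sorted data produces tight, nearly disjoint ranges like these.
row_groups = [(5, 90), (80, 450), (400, 2500), (2300, 9000)]

def groups_to_read(predicate_min, stats):
    """Indices of row groups whose max could satisfy `amount > predicate_min`."""
    return [i for i, (lo, hi) in enumerate(stats) if hi > predicate_min]

# Filter amount > 1000: the first two groups are skipped without any I/O
print(groups_to_read(1000, row_groups))  # → [2, 3]
```

With unsorted data the min/max ranges overlap heavily, so far fewer groups can be ruled out, which is why sorting on the filter column before writing pays off.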

6. Kryo Serialization

Why Serialization Matters in Spark

Spark serializes objects when shuffling data across the network, spilling to disk, and caching RDDs. Java’s default serialization is slow and produces large byte arrays. Kryo serialization is typically 10x faster and produces 2–5x smaller output — reducing shuffle I/O, disk spill, and GC pressure.

Enabling Kryo Serialization

from pyspark.sql import SparkSession

# Serializer settings must be in place before the session starts
spark = (
    SparkSession.builder
    .config('spark.serializer', 'org.apache.spark.serializer.KryoSerializer')
    .config('spark.kryo.registrationRequired', 'false')  # Relaxed mode, fine for dev
    .config('spark.kryo.unsafe', 'true')                 # Faster but less safe
    # For production: register custom classes to maximize efficiency
    .config('spark.kryo.classesToRegister',
            'com.mycompany.model.Order,com.mycompany.model.Customer')
    .getOrCreate()
)

▸ When Kryo Applies

  • RDD operations — full benefit (shuffles, caching, spilling)
  • DataFrame/Dataset API — limited benefit (Spark SQL uses its own encoder)
  • UDFs that return complex objects — significant benefit

💡 TIP
For pure DataFrame/SQL workloads, Kryo has minimal effect because Spark uses Tungsten’s binary encoding. The biggest gains come from RDD-based code and UDFs returning custom objects.

7. UDF Optimization

The Hidden Cost of Python UDFs

Python UDFs (User-Defined Functions) are the single biggest performance trap for PySpark developers. When you call a Python UDF, Spark must serialize each row from the JVM to a Python process using Pickle, execute the Python function, then deserialize the result back to the JVM. This round-trip is 5–100x slower than native Spark SQL functions.

UDF Performance Hierarchy

UDF Type | Execution | Speed | Use When
Built-in SQL functions (F.col, F.when, etc.) | JVM native | Fastest | Always prefer — covers 90% of cases
Spark SQL expressions (selectExpr) | JVM native | Fastest | String-based transformations
Pandas UDF (vectorized) | Arrow + Python | Fast | Complex logic on batches
Python UDF (row-by-row) | Python pickle | Slow | Only when no alternative exists
RDD map() with Python lambda | Python pickle | Slowest | Avoid in DataFrame workloads

Replacing Python UDFs with Built-in Functions

▸ Before: Slow Python UDF

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Slow Python UDF — serializes every row to Python
@udf(returnType=StringType())
def classify_spend(amount):
    if amount is None:
        return 'unknown'
    elif amount < 100:
        return 'low'
    elif amount < 1000:
        return 'medium'
    else:
        return 'high'

df = df.withColumn('spend_category', classify_spend('amount'))

▸ After: Fast Built-in Functions

from pyspark.sql.functions import when, col

# Fast — runs entirely in the JVM, zero Python overhead
df = df.withColumn(
    'spend_category',
    when(col('amount').isNull(), 'unknown')
    .when(col('amount') < 100, 'low')
    .when(col('amount') < 1000, 'medium')
    .otherwise('high')
)

When UDFs Are Unavoidable: Use Pandas UDFs

▸ Vectorized Pandas UDF (Arrow-accelerated)

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd

# A Pandas UDF operates on entire column vectors via Apache Arrow,
# ~10-100x faster than a row-by-row Python UDF
@pandas_udf(DoubleType())
def compound_interest(principal: pd.Series, rate: pd.Series, years: pd.Series) -> pd.Series:
    return principal * (1 + rate) ** years

df = df.withColumn(
    'future_value',
    compound_interest('principal', 'interest_rate', 'tenure_years')
)

8. Data Skew Handling

Understanding Data Skew in PySpark

Data skew occurs when some partitions contain significantly more data than others — typically because a small number of key values dominate the dataset. A join on a skewed column causes a few tasks to process 100x more data than others, while the rest of the cluster sits idle waiting for the stragglers. This is one of the hardest performance problems to diagnose and fix.

Diagnosing Data Skew

▸ Signs of Skew in Spark UI

  • Most tasks complete in seconds — but 1–5 tasks take hours
  • Stage completion bar shows a long tail of slow tasks
  • One executor’s memory usage is 10x higher than others

▸ Diagnosing with Code

from pyspark.sql import functions as F

# Check key distribution before joining
key_dist = df.groupBy('customer_id') \
    .count() \
    .orderBy('count', ascending=False)

key_dist.show(20)  # If a handful of keys hold millions of rows, you have skew

# Statistical profile of the key distribution
key_dist.select(
    F.min('count').alias('min_count'),
    F.max('count').alias('max_count'),
    F.mean('count').alias('avg_count'),
    F.stddev('count').alias('std_count')
).show()

Fixing Skew: The Salting Technique

▸ How Salting Works

Salting adds a random suffix to skewed keys, distributing what was one hot partition into N smaller partitions. The join is then performed on the salted key, and results are aggregated back.

from pyspark.sql.functions import col, rand, lit, concat_ws, explode, array

SALT_FACTOR = 20  # Distribute each hot key across 20 buckets

# Step 1: Salt the large (skewed) table with a random bucket per row
orders_salted = orders_df \
    .withColumn('salt', (rand() * SALT_FACTOR).cast('int')) \
    .withColumn('salted_key', concat_ws('_', col('customer_id'), col('salt')))

# Step 2: Explode the small (dimension) table to match every salt value
customers_salted = customers_df \
    .withColumn('salt', explode(array([lit(i) for i in range(SALT_FACTOR)]))) \
    .withColumn('salted_key', concat_ws('_', col('customer_id'), col('salt'))) \
    .drop('salt', 'customer_id')  # avoid ambiguous duplicate columns after the join

# Step 3: Join on the salted key — no more hot partitions
result = orders_salted.join(customers_salted, 'salted_key', 'inner') \
    .drop('salt', 'salted_key')

⚠️ WARN
Salting multiplies the size of the smaller table by SALT_FACTOR. For very large dimension tables, use AQE’s automatic skew handling or partial salting (only for skewed keys) instead.
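To see why salting balances load, here is a cluster-free simulation in plain Python. The key counts are invented; the point is that one hot key's rows spread roughly evenly across SALT_FACTOR buckets instead of landing in a single partition:

```python
import random
from collections import Counter

random.seed(42)
SALT_FACTOR = 20

# Skewed distribution: one hot customer dominates the dataset
rows = ['hot_customer'] * 100_000 + [f'cust_{i}' for i in range(5_000)]

# Without salting, every 'hot_customer' row hashes to the same partition.
# With salting, each of its rows picks one of SALT_FACTOR buckets at random:
salted = [f'{key}_{random.randrange(SALT_FACTOR)}' for key in rows]

hot_buckets = Counter(k for k in salted if k.startswith('hot_customer'))
# 100,000 rows are now split into 20 buckets of roughly 5,000 rows each
print(len(hot_buckets), max(hot_buckets.values()))
```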

9. Optimal File Formats and Compression

Choosing the Right File Format

Format | Type | Splittable? | Schema Evolution | Best For
Parquet | Columnar binary | Yes | Yes (add columns) | Analytics, OLAP, most workloads
ORC | Columnar binary | Yes | Limited | Hive-centric workloads, Hive ACID
Delta Lake | Parquet + tx log | Yes | Full ACID | Lakehouse, upserts, time travel
Avro | Row binary | Yes | Full | Streaming, Kafka integration
JSON | Row text | Not when gzipped | Schema-free | Dev/debug only — avoid in production
CSV | Row text | Not when gzipped | None | Input ingestion only — never store as CSV

Compression Codec Comparison

Codec | Ratio | Speed | Splittable | Recommended Use
Snappy | Medium | Very fast | Yes (Parquet/ORC) | Default choice — best balance
Zstd | High | Fast | Yes (Parquet/ORC) | Best for storage-constrained envs
Gzip | High | Slow | No | Avoid for large files — not splittable
LZ4 | Low | Fastest | Yes (Parquet/ORC) | Hot data, low-latency pipelines
Uncompressed | None | N/A | Yes | Only for benchmarking

# Write optimal Parquet with Snappy (the default)
df.write \
    .option('compression', 'snappy') \
    .parquet('/data/output/events')

# Write with Zstd for better compression (Spark 3.2+)
df.write \
    .option('compression', 'zstd') \
    .parquet('/data/output/events_compressed')

Avoiding the Small Files Problem

▸ Why Small Files Kill Performance

Writing thousands of tiny files (< 10 MB each) causes: slow metadata operations on the namenode, excessive task creation overhead, and poor read performance. Always coalesce before writing:

# Bad: 200 shuffle partitions = 200 tiny files
df.write.parquet('/data/output')

# Good: reduce to ~50 files of ~256 MB each
target_file_size_mb = 256
avg_row_size_bytes = 500  # measure on a sample rather than guessing
estimated_data_mb = df.count() * avg_row_size_bytes / (1024 * 1024)
num_files = max(1, int(estimated_data_mb / target_file_size_mb))
df.coalesce(num_files).write.parquet('/data/output')

10. Executor and Memory Tuning

Spark Memory Architecture

Each Spark executor has a fixed memory budget divided into three main regions: storage memory (for cached DataFrames), execution memory (for shuffles, sorts, aggregations), and user memory (for your application objects). Tuning these ratios for your workload type dramatically improves throughput and eliminates OOM errors.
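These regions follow Spark's unified memory model, which can be computed by hand. The sketch below uses the documented formulas (300 MB reserved, then spark.memory.fraction splits off unified memory and spark.memory.storageFraction sets the soft storage boundary); actual runtime numbers can differ slightly:

```python
MB = 1024**2
GB = 1024**3

def spark_memory_regions(executor_memory_bytes,
                         memory_fraction=0.6,
                         storage_fraction=0.5,
                         reserved_bytes=300 * MB):
    """Split an executor heap per Spark's unified memory model (defaults shown).

    unified   = (heap - 300 MB reserved) * spark.memory.fraction
    storage   = unified * spark.memory.storageFraction  (soft boundary)
    execution = unified - storage
    user      = heap - reserved - unified
    """
    unified = (executor_memory_bytes - reserved_bytes) * memory_fraction
    storage = unified * storage_fraction
    return {
        'unified': unified,
        'storage': storage,
        'execution': unified - storage,
        'user': executor_memory_bytes - reserved_bytes - unified,
    }

regions = spark_memory_regions(8 * GB, memory_fraction=0.75, storage_fraction=0.4)
print(round(regions['unified'] / GB, 2))  # → 5.78 (GB of unified memory on an 8g executor)
```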

Critical Memory Configuration Parameters

Parameter | Default | Recommended (Batch) | Effect
spark.executor.memory | 1g | 8g – 32g | Total executor JVM heap
spark.executor.cores | 1 | 4 – 5 | Tasks per executor (4–5 is the sweet spot)
spark.executor.instances | auto | Cluster dependent | Total number of executors
spark.memory.fraction | 0.6 | 0.7 – 0.8 | Fraction of heap for Spark execution + storage
spark.memory.storageFraction | 0.5 | 0.3 – 0.5 | Fraction of Spark memory reserved for storage
spark.executor.memoryOverhead | 10% | 2g – 4g | Off-heap memory (Python, native libs)
spark.driver.memory | 1g | 4g – 8g | Driver heap (larger for collect/toPandas)
spark.sql.files.maxPartitionBytes | 128MB | 256MB | Max partition size when reading files

Production Configuration Template

▸ For Memory-Intensive Analytics Workloads

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName('ProductionAnalytics')
    # Executor sizing (assuming 32-core, 256GB nodes)
    .config('spark.executor.cores', '5')
    .config('spark.executor.memory', '32g')
    .config('spark.executor.memoryOverhead', '4g')
    .config('spark.driver.memory', '8g')
    .config('spark.driver.maxResultSize', '4g')
    # Memory management
    .config('spark.memory.fraction', '0.75')
    .config('spark.memory.storageFraction', '0.4')
    # Shuffle optimization
    .config('spark.sql.shuffle.partitions', '400')
    .config('spark.sql.files.maxPartitionBytes', str(256 * 1024 * 1024))
    # AQE + adaptive tuning
    .config('spark.sql.adaptive.enabled', 'true')
    .config('spark.sql.adaptive.coalescePartitions.enabled', 'true')
    # Serialization
    .config('spark.serializer', 'org.apache.spark.serializer.KryoSerializer')
    .getOrCreate()
)

Off-Heap Memory for Python Workers

▸ PySpark-Specific Memory Considerations

  • Each Python worker process uses memory outside the JVM heap
  • Set spark.executor.memoryOverhead to at least 2 GB for PySpark jobs
  • For Pandas UDFs, add extra overhead per executor (Pandas DataFrames can be large)
  • Monitor GC time in Spark UI — if > 5% of task time, increase executor memory

Summary: PySpark Optimization Checklist

Production Optimization Checklist

Use this checklist before promoting any PySpark job to production:

Query & Join Optimization

  1. Apply filters and column selects before joins (predicate pushdown)
  2. Use broadcast() for dimension tables < 200 MB
  3. Verify shuffle elimination with explain() for bucketed joins
  4. Enable AQE: spark.sql.adaptive.enabled = true

Data Layout & I/O

  • Write data in Parquet or ORC format with Snappy or Zstd compression
  • Use partitionBy for time-series data with date/region filters
  • Coalesce before writing to avoid small files (target 128–512 MB per file)
  • Sort data on filter columns before writing for row group skipping

Memory & Resources

  • Set spark.executor.cores to 4–5 per executor
  • Set spark.executor.memoryOverhead to at least 2g for PySpark
  • Tune spark.sql.shuffle.partitions to match your data volume
  • Cache DataFrames reused in 2+ actions; always unpersist() when done

Code Quality

  1. Replace Python UDFs with built-in SQL functions wherever possible
  2. Use Pandas UDFs (vectorized) when Python logic is unavoidable
  3. Diagnose and salt skewed join keys before deploying at scale
  4. Enable Kryo serialization for RDD-heavy workloads

Final Thoughts

PySpark optimization is not a one-time activity — it is an ongoing discipline. The techniques in this guide range from zero-config wins (enabling AQE, choosing Parquet) to architectural decisions (bucketing, partitioning strategy) that shape how your entire data platform performs.

Start with the quick wins: AQE, broadcast joins, and predicate pushdown. Then layer in the architectural optimizations as your pipelines mature. Monitor the Spark UI obsessively — it tells you everything about where your jobs are spending time.


🎯 GOAL
The best-optimized Spark job is the one that reads the least data, shuffles the least data, and uses exactly as much memory as it needs — no more, no less.

