Introduction: Why PySpark Optimization Matters
Apache Spark is one of the most powerful distributed computing frameworks ever built. Yet even experienced engineers routinely leave 60–80% of cluster performance on the table — suffering multi-hour runtimes, out-of-memory crashes, and runaway cloud bills — simply because they haven’t applied the right set of optimizations.
This guide covers 10 of the most impactful PySpark optimization techniques, from low-hanging fruit like caching and broadcast joins to advanced strategies like Adaptive Query Execution and data skew handling. Each technique is explained with clear architecture context, working code examples, and guidance on when to apply it.
📌 SCOPE | This guide targets PySpark 3.x on Spark SQL / DataFrame API. Most techniques also apply to Spark 2.4+ with minor syntax differences. |
Quick Reference: All 10 Optimization Techniques
| Technique | Category | Performance Gain | Complexity |
| --- | --- | --- | --- |
| Caching & Persistence | Memory Management | 2x – 10x | Low |
| Broadcast Join | Join Strategy | 5x – 50x | Low |
| File Formats & Compression | Data Layout | 3x – 10x | Medium |
| Adaptive Query Execution (AQE) | Query Planner | 2x – 5x | Low |
| Partitioning Strategy | Data Layout | 2x – 8x | Medium |
| Predicate Pushdown | I/O Reduction | 2x – 20x | Low |
| Serialization (Kryo) | Memory/Network | 1.5x – 3x | Low |
| UDF Optimization | Computation | 2x – 10x | Medium |
| Data Skew Handling | Resource Balance | 3x – 15x | High |
| Resource Tuning | Cluster Config | 2x – 5x | Medium |
1. Caching and Persistence
What is Caching in PySpark?
Every time you call an action on a PySpark DataFrame — count(), show(), write() — Spark re-executes the entire DAG from scratch. If your DataFrame is the product of expensive transformations (joins, aggregations, file reads), this means paying that cost repeatedly. Caching breaks this cycle by storing the computed DataFrame in memory (or on disk) so subsequent actions read from the cache rather than recomputing from source.
Storage Levels Explained
Spark offers several storage levels that control where and how data is cached. Note that PySpark always stores cached data in serialized form, so the *_SER variants below exist only in the Scala/Java API:
| Storage Level | Memory | Disk | Serialized | Replicated | Best For |
| --- | --- | --- | --- | --- | --- |
| MEMORY_ONLY | Yes | No | No | No | Small DataFrames, fast access |
| MEMORY_AND_DISK | Yes | Yes (spill) | No | No | Medium DataFrames |
| MEMORY_ONLY_SER | Yes | No | Yes | No | Reducing GC pressure |
| MEMORY_AND_DISK_SER | Yes | Yes | Yes | No | Large DataFrames |
| DISK_ONLY | No | Yes | Yes | No | Very large, infrequent access |
| MEMORY_AND_DISK_2 | Yes | Yes | No | Yes | Fault-tolerant caching |
Caching vs Persisting
cache() is simply a shortcut for persist(StorageLevel.MEMORY_AND_DISK). Use persist() when you need explicit control over the storage level.
▸ When to Use cache()
- DataFrame is used in 2 or more downstream actions or joins
- The DataFrame is expensive to compute (multi-step joins, aggregations)
- It fits comfortably in cluster memory (< 60% of total executor memory)
▸ When NOT to Cache
- DataFrame is used only once — caching adds overhead with no benefit
- DataFrame is very large and doesn’t fit in memory — causes GC pressure
- ETL pipelines that run once — write directly instead of caching
Caching Code Examples
```python
from pyspark import StorageLevel
from pyspark.sql import functions as F

# Simple cache (MEMORY_AND_DISK)
df_expensive = spark.read.parquet('/data/raw/events') \
    .filter("event_type = 'purchase'") \
    .join(users_df, 'user_id') \
    .groupBy('user_id').agg(F.sum('amount').alias('total'))

df_expensive.cache()    # lazy — not computed yet
df_expensive.count()    # triggers computation + cache fill

# Now these reuse the cache:
df_expensive.show(10)                               # no recomputation
df_expensive.write.parquet('/data/output/totals')   # no recomputation

# Explicit storage level (PySpark always stores data serialized,
# so the *_SER levels are not exposed in pyspark.StorageLevel)
df_large.persist(StorageLevel.DISK_ONLY)

# Always unpersist when done to free memory
df_expensive.unpersist()
```
💡 TIP | Always call an action (like count()) immediately after cache() to trigger computation. Only then is the cache filled. Calling cache() alone does nothing until an action is triggered. |
Checking Cache Status
```python
# Check if a DataFrame is cached
print(df_expensive.is_cached)      # True / False
print(df_expensive.storageLevel)   # e.g. Disk Memory Serialized 1x Replicated

# View per-RDD cache details in the Spark UI → Storage tab,
# or check whether a registered table/view is cached:
df_expensive.createOrReplaceTempView('expensive_totals')
print(spark.catalog.isCached('expensive_totals'))
```
2. Broadcast Join
How Broadcast Join Works in PySpark
A standard Sort-Merge Join between two large tables requires both tables to be shuffled across the network — an enormously expensive operation. A Broadcast Join avoids this entirely by sending a copy of the smaller table to every executor node. Each executor then performs the join locally against its partition of the larger table, with zero network shuffle.
Automatic vs Manual Broadcast
Spark automatically broadcasts tables below the spark.sql.autoBroadcastJoinThreshold (default 10 MB). For larger tables that still qualify for broadcast, you can trigger it manually using broadcast().
▸ Auto-Broadcast Configuration
```python
# Increase auto-broadcast threshold to 50 MB
spark.conf.set('spark.sql.autoBroadcastJoinThreshold', 50 * 1024 * 1024)

# Disable auto-broadcast (force SMJ for all joins)
spark.conf.set('spark.sql.autoBroadcastJoinThreshold', -1)
```
▸ Manual Broadcast Hint
```python
from pyspark.sql.functions import broadcast

# Manually hint Spark to broadcast the smaller table
result = orders_df.join(
    broadcast(products_df),   # products_df is 200 MB — fits in executor RAM
    on='product_id',
    how='inner'
)

# Verify: explain() should show BroadcastHashJoin, NOT SortMergeJoin
result.explain(mode='formatted')
```
Broadcast Join Limits and Thresholds
| Table Size | Strategy | Config |
| --- | --- | --- |
| < 10 MB | Auto Broadcast | Default threshold |
| 10 MB – 200 MB | Manual Broadcast | Use broadcast() hint |
| 200 MB – 2 GB | Bucketing | bucketBy() + saveAsTable() |
| > 2 GB | Sort-Merge Join | Default — optimize with AQE |
⚠️ WARN | Broadcasting a table too large for executor memory causes OOM errors. Monitor executor memory usage and stay within 30–40% of executor RAM as a safe limit for broadcast tables. |
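The 200 MB – 2 GB row in the table above points to bucketing rather than broadcasting. As a rough sketch (table and DataFrame names are illustrative, and it assumes a SparkSession with a catalog/warehouse available for saveAsTable), pre-bucketing both sides on the join key lets Spark join them without a shuffle:

```python
# Bucket both sides on the join key with the same bucket count
(orders_df.write
    .bucketBy(64, 'customer_id')     # pre-hash rows into 64 buckets
    .sortBy('customer_id')
    .mode('overwrite')
    .saveAsTable('orders_bucketed'))

(customers_df.write
    .bucketBy(64, 'customer_id')     # same bucket count and key on both sides
    .sortBy('customer_id')
    .mode('overwrite')
    .saveAsTable('customers_bucketed'))

# Joining the bucketed tables avoids the shuffle: verify with explain()
spark.table('orders_bucketed') \
    .join(spark.table('customers_bucketed'), 'customer_id') \
    .explain()   # no Exchange node before the SortMergeJoin
```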
3. Smart Partitioning Strategy
Understanding PySpark Partitions
Partitions are the fundamental unit of parallelism in Spark. Every DataFrame is divided into partitions, and each partition is processed by a single task on a single core. Too few partitions means under-utilization of the cluster; too many partitions means excessive task scheduling overhead and tiny files.
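To see how work is currently split, a quick check (assuming an existing DataFrame df and SparkSession spark):

```python
print(df.rdd.getNumPartitions())                        # partitions of this DataFrame
print(spark.conf.get('spark.sql.shuffle.partitions'))   # partitions created after a shuffle
print(spark.sparkContext.defaultParallelism)            # default parallelism of the cluster
```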
The Partition Count Problem
▸ Default Partition Behavior
- On read: Spark creates one partition per HDFS block (~128 MB by default)
- After a shuffle: Spark creates spark.sql.shuffle.partitions partitions (default: 200)
- After repartition(N): Spark creates exactly the number you specify
Repartition vs Coalesce
| Method | Shuffle? | Use Case | Performance |
| --- | --- | --- | --- |
| repartition(N) | Yes (full shuffle) | Increasing partitions / even distribution | Slower write, faster downstream |
| coalesce(N) | No (narrow) | Decreasing partitions / reducing file count | Faster, but may produce skewed partitions |
| repartitionByRange(N, col) | Yes | Range-partitioned output for sorted reads | Medium — best for ordered data |
▸ Optimal Partition Size Formula
```python
# Rule of thumb: each partition = 128 MB – 512 MB uncompressed
# Compute optimal partition count:
total_data_bytes = 500 * 1024**3         # 500 GB
target_partition_bytes = 256 * 1024**2   # 256 MB
optimal_partitions = total_data_bytes // target_partition_bytes   # ~2000 partitions

# Set before shuffle-heavy operations
spark.conf.set('spark.sql.shuffle.partitions', optimal_partitions)

# Repartition a DataFrame explicitly
df_balanced = df.repartition(2000, 'customer_id')

# Reduce small files before writing
df_final = df_transformed.coalesce(200)
df_final.write.parquet('/data/output/result')
```
Write Partitioning for Storage Optimization
▸ partitionBy for Hive-style Directory Partitions
```python
# Write with directory-level partitioning (Hive-style)
# Creates: /data/events/year=2024/month=01/day=15/part-*.parquet
df.write \
    .partitionBy('year', 'month', 'day') \
    .mode('overwrite') \
    .parquet('/data/events')

# Query with partition pruning (only reads matching directories)
df_jan = spark.read.parquet('/data/events') \
    .filter('year=2024 AND month=01')   # only scans year=2024/month=01/
```
💡 TIP | Partition on low-cardinality columns (date, region, status). High-cardinality partitioning (user_id) creates millions of tiny directories — use bucketing for those instead. |
4. Adaptive Query Execution (AQE)
What is Adaptive Query Execution?
Traditional Spark planning is static — the query plan is fixed before execution begins, based on estimated statistics that are often wrong. Adaptive Query Execution (AQE), introduced in Spark 3.0, changes this by re-optimizing the query plan at runtime as actual data statistics become available after each shuffle stage.
AQE’s Three Core Features
▸ 1. Dynamic Coalescing of Shuffle Partitions
AQE merges small post-shuffle partitions automatically, reducing the overhead of thousands of tiny tasks without requiring you to manually set spark.sql.shuffle.partitions.
▸ 2. Dynamic Join Strategy Switching
AQE can switch a Sort-Merge Join to a Broadcast Hash Join at runtime if it determines one side is small enough to broadcast — even if the static plan chose SMJ.
▸ 3. Skew Join Optimization
AQE detects skewed partitions and automatically splits them into smaller sub-tasks, distributing the load evenly without manual salting.
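With AQE enabled, every shuffle-bearing query is wrapped in an adaptive plan node. A minimal sketch of what that looks like, assuming an existing SparkSession spark and two hypothetical DataFrames orders_df and customers_df:

```python
# Minimal sketch — assumes `spark`, `orders_df`, and `customers_df` already exist
spark.conf.set('spark.sql.adaptive.enabled', 'true')

joined = orders_df.join(customers_df, 'customer_id').groupBy('customer_id').count()
joined.explain()
# Root node reads "AdaptiveSparkPlan isFinalPlan=false": the plan shown is provisional.
# The re-optimized final plan (coalesced partitions, switched join strategy,
# split skewed partitions) is visible in the Spark UI SQL tab after the query runs.
```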
Enabling and Configuring AQE
```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config('spark.sql.adaptive.enabled', 'true')                           # Enable AQE
    .config('spark.sql.adaptive.coalescePartitions.enabled', 'true')        # Auto-merge small partitions
    .config('spark.sql.adaptive.coalescePartitions.minPartitionNum', '1')   # Min after coalesce
    .config('spark.sql.adaptive.advisoryPartitionSizeInBytes', '128MB')     # Target partition size
    .config('spark.sql.adaptive.skewJoin.enabled', 'true')                  # Auto skew handling
    .config('spark.sql.adaptive.skewJoin.skewedPartitionFactor', '5')       # 5x median = skewed
    .config('spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes', '256MB')
    .getOrCreate()
)
```
💡 TIP | AQE is enabled by default in Spark 3.2+. For Spark 3.0–3.1, add spark.sql.adaptive.enabled = true explicitly. It is the single highest-ROI configuration change with the lowest effort. |
5. Predicate Pushdown and Column Pruning
What is Predicate Pushdown?
Predicate pushdown is the practice of applying filter conditions as early as possible in the query execution — ideally at the data source level, before data is loaded into memory. When reading Parquet or ORC files, Spark can pass filter predicates directly to the file reader, which skips entire row groups or files that do not match the condition.
Column Pruning
Column pruning complements predicate pushdown: instead of reading all columns from a file and discarding unused ones, Spark reads only the columns referenced in the query. In columnar formats like Parquet, this can reduce I/O by 80%+.
How to Maximize Pushdown
▸ Apply Filters Early in the Chain
```python
# BAD — filter applied after expensive joins
result = (orders_df
    .join(customers_df, 'customer_id')
    .join(products_df, 'product_id')
    .filter("order_date >= '2024-01-01'"))   # Too late!

# GOOD — filter applied before the joins (Spark Catalyst also does this automatically)
result = (orders_df
    .filter("order_date >= '2024-01-01'")    # Pushes down to the file scan
    .join(customers_df, 'customer_id')
    .join(products_df, 'product_id'))

# BEST — select only needed columns too (column pruning)
result = (orders_df
    .select('order_id', 'customer_id', 'product_id', 'amount', 'order_date')
    .filter("order_date >= '2024-01-01'")
    .join(customers_df.select('customer_id', 'name'), 'customer_id')
    .join(products_df.select('product_id', 'category'), 'product_id'))
```
▸ Verifying Pushdown in the Physical Plan
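A quick check, assuming the result DataFrame from the example above: the Parquet scan node of the physical plan lists the predicates it pushed and the pruned read schema.

```python
# Look for PushedFilters and ReadSchema in the FileScan parquet node
result.explain(True)
# Output (abridged) typically includes something like:
# ... FileScan parquet [...] PushedFilters: [IsNotNull(order_date),
#       GreaterThanOrEqual(order_date,2024-01-01)], ReadSchema: struct<order_id:string,...>
```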
Parquet Statistics and Row Group Skipping
Parquet stores min/max statistics for each row group. When a filter condition (e.g., amount > 1000) is outside the min/max range of a row group, Spark skips that row group entirely without reading it. Sorting data before writing dramatically improves this:
```python
# Write data sorted on the filter column to maximize row group skipping
df.sort('order_date', 'amount') \
    .write \
    .option('parquet.block.size', 128 * 1024 * 1024) \
    .parquet('/data/orders_sorted')
```
6. Kryo Serialization
Why Serialization Matters in Spark
Spark serializes objects when shuffling data across the network, spilling to disk, and caching RDDs. Java’s default serialization is slow and produces large byte arrays. Kryo serialization is typically 10x faster and produces 2–5x smaller output — reducing shuffle I/O, disk spill, and GC pressure.
Enabling Kryo Serialization
```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config('spark.serializer', 'org.apache.spark.serializer.KryoSerializer')
    .config('spark.kryo.registrationRequired', 'false')   # Recommended for dev
    .config('spark.kryo.unsafe', 'true')                  # Faster but less safe
    # For production: register custom classes to maximize efficiency.
    # Serializer/Kryo settings are core configs read when the SparkContext starts,
    # so set them here in the builder, not with spark.conf.set() afterwards.
    .config('spark.kryo.classesToRegister',
            'com.mycompany.model.Order,com.mycompany.model.Customer')
    .getOrCreate()
)
```
▸ When Kryo Applies
- RDD operations — full benefit (shuffles, caching, spilling)
- DataFrame/Dataset API — limited benefit (Spark SQL uses its own encoder)
- UDFs that return complex objects — significant benefit
💡 TIP | For pure DataFrame/SQL workloads, Kryo has minimal effect because Spark uses Tungsten’s binary encoding. The biggest gains come from RDD-based code and UDFs returning custom objects. |
7. UDF Optimization
The Hidden Cost of Python UDFs
Python UDFs (User-Defined Functions) are the single biggest performance trap for PySpark developers. When you call a Python UDF, Spark must serialize each row from the JVM to a Python process using Pickle, execute the Python function, then deserialize the result back to the JVM. This round-trip is 5–100x slower than native Spark SQL functions.
UDF Performance Hierarchy
| UDF Type | Execution | Speed | Use When |
| --- | --- | --- | --- |
| Built-in SQL functions (F.col, F.when, etc.) | JVM native | Fastest | Always prefer — covers 90% of cases |
| Spark SQL expressions (selectExpr) | JVM native | Fastest | String-based transformations |
| Pandas UDF (vectorized) | Arrow + Python | Fast | Complex logic on batches |
| Python UDF (row-by-row) | Python pickle | Slow | Only when no alternative exists |
| RDD map() with Python lambda | Python pickle | Slowest | Avoid in DataFrame workloads |
Replacing Python UDFs with Built-in Functions
▸ Before: Slow Python UDF
```python
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Slow Python UDF — serializes every row to Python
@udf(returnType=StringType())
def classify_spend(amount):
    if amount is None:
        return 'unknown'
    elif amount < 100:
        return 'low'
    elif amount < 1000:
        return 'medium'
    else:
        return 'high'

df = df.withColumn('spend_category', classify_spend('amount'))
```
▸ After: Fast Built-in Functions
```python
from pyspark.sql.functions import when, col

# Fast — runs entirely in the JVM, zero Python overhead
df = df.withColumn('spend_category',
    when(col('amount').isNull(), 'unknown')
    .when(col('amount') < 100, 'low')
    .when(col('amount') < 1000, 'medium')
    .otherwise('high'))
```
When UDFs Are Unavoidable: Use Pandas UDFs
▸ Vectorized Pandas UDF (Arrow-accelerated)
```python
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

# Pandas UDF operates on entire column vectors via Apache Arrow
# ~10-100x faster than a row-by-row Python UDF
@pandas_udf(DoubleType())
def compound_interest(principal: pd.Series, rate: pd.Series, years: pd.Series) -> pd.Series:
    return principal * (1 + rate) ** years

df = df.withColumn('future_value',
                   compound_interest('principal', 'interest_rate', 'tenure_years'))
```
8. Data Skew Handling
Understanding Data Skew in PySpark
Data skew occurs when some partitions contain significantly more data than others — typically because a small number of key values dominate the dataset. A join on a skewed column causes a few tasks to process 100x more data than others, while the rest of the cluster sits idle waiting for the stragglers. This is one of the hardest performance problems to diagnose and fix.
Diagnosing Data Skew
▸ Signs of Skew in Spark UI
- Most tasks complete in seconds — but 1–5 tasks take hours
- Stage completion bar shows a long tail of slow tasks
- One executor’s memory usage is 10x higher than others
▸ Diagnosing with Code
```python
from pyspark.sql import functions as F

# Check key distribution before joining
key_dist = df.groupBy('customer_id') \
    .count() \
    .orderBy('count', ascending=False)

key_dist.show(20)   # If the top few keys have millions of rows, you have skew

# Statistical profile of key distribution
key_dist.select(
    F.min('count').alias('min_count'),
    F.max('count').alias('max_count'),
    F.mean('count').alias('avg_count'),
    F.stddev('count').alias('std_count')
).show()
```
Fixing Skew: The Salting Technique
▸ How Salting Works
Salting adds a random suffix to skewed keys, distributing what was one hot partition into N smaller partitions. The join is then performed on the salted key, and results are aggregated back.
```python
from pyspark.sql.functions import col, rand, lit, concat_ws, explode, array

SALT_FACTOR = 20   # Distribute each hot key across 20 buckets

# Step 1: Salt the large (skewed) table
orders_salted = orders_df \
    .withColumn('salt', (rand() * SALT_FACTOR).cast('int')) \
    .withColumn('salted_key', concat_ws('_', col('customer_id'), col('salt')))

# Step 2: Explode the small (dimension) table to match all salt values
customers_salted = customers_df \
    .withColumn('salt_array', array([lit(i) for i in range(SALT_FACTOR)])) \
    .withColumn('salt', explode('salt_array')) \
    .withColumn('salted_key', concat_ws('_', col('customer_id'), col('salt'))) \
    .drop('salt_array', 'salt', 'customer_id')   # drop customer_id to avoid a duplicate column after the join

# Step 3: Join on salted key — no more hot partitions
result = orders_salted.join(customers_salted, 'salted_key', 'inner') \
    .drop('salt', 'salted_key')
```
⚠️ WARN | Salting multiplies the size of the smaller table by SALT_FACTOR. For very large dimension tables, use AQE’s automatic skew handling or partial salting (only for skewed keys) instead. |
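A sketch of that partial-salting idea, under the assumption that the hot keys were already identified with the key-distribution query earlier (names like key_dist, orders_df, and customers_df carry over from the previous examples):

```python
from pyspark.sql import functions as F

SALT_FACTOR = 20
hot_keys = [row['customer_id'] for row in key_dist.limit(5).collect()]   # top-5 heaviest keys

# Large table: salt only rows whose key is hot; everything else keeps salt 0
orders_partial = orders_df.withColumn(
    'salt',
    F.when(F.col('customer_id').isin(hot_keys),
           (F.rand() * SALT_FACTOR).cast('int'))
     .otherwise(F.lit(0)))

# Small table: replicate only the hot keys across all salt values
salts = spark.range(SALT_FACTOR).select(F.col('id').cast('int').alias('salt'))
customers_hot = customers_df.filter(F.col('customer_id').isin(hot_keys)).crossJoin(salts)
customers_rest = customers_df.filter(~F.col('customer_id').isin(hot_keys)).withColumn('salt', F.lit(0))

# Join on (key, salt): only the hot keys were multiplied, not the whole dimension table
result = orders_partial.join(customers_hot.unionByName(customers_rest),
                             ['customer_id', 'salt'], 'inner').drop('salt')
```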
9. Optimal File Formats and Compression
Choosing the Right File Format
| Format | Type | Splittable? | Schema Evolution | Best For |
| --- | --- | --- | --- | --- |
| Parquet | Columnar binary | Yes | Yes (add columns) | Analytics, OLAP, most workloads |
| ORC | Columnar binary | Yes | Limited | Hive-centric workloads, Hive ACID |
| Delta Lake | Parquet + tx log | Yes | Full ACID | Lakehouse, upserts, time travel |
| Avro | Row binary | Yes | Full | Streaming, Kafka integration |
| JSON | Row text | No (gzip) | Schema-free | Dev/debug only — avoid in production |
| CSV | Row text | No (gzip) | None | Input ingestion only — never store as CSV |
Compression Codec Comparison
| Codec | Ratio | Speed | Splittable | Recommended Use |
| --- | --- | --- | --- | --- |
| Snappy | Medium | Very fast | Yes (Parquet/ORC) | Default choice — best balance |
| Zstd | High | Fast | Yes (Parquet/ORC) | Best for storage-constrained envs |
| Gzip | High | Slow | No | Avoid for large files — not splittable |
| LZ4 | Low | Fastest | Yes (Parquet/ORC) | Hot data, low-latency pipelines |
| Uncompressed | None | N/A | Yes | Only for benchmarking |
```python
# Write optimal Parquet with Snappy (default)
df.write \
    .option('compression', 'snappy') \
    .parquet('/data/output/events')

# Write with Zstd for better compression (Parquet 2.x, Spark 3.2+)
df.write \
    .option('compression', 'zstd') \
    .option('parquet.compression.codec', 'zstd') \
    .parquet('/data/output/events_compressed')
```
Avoiding the Small Files Problem
▸ Why Small Files Kill Performance
Writing thousands of tiny files (< 10 MB each) causes: slow metadata operations on the namenode, excessive task creation overhead, and poor read performance. Always coalesce before writing:
```python
# Bad: 200 shuffle partitions = 200 tiny files
df.write.parquet('/data/output')

# Good: reduce to ~50 files of ~256 MB each
target_file_size_mb = 256
estimated_data_mb = df.count() * 500 / (1024 * 1024)   # rough estimate (~500 bytes per row)
num_files = max(1, int(estimated_data_mb / target_file_size_mb))

df.coalesce(num_files).write.parquet('/data/output')
```
10. Executor and Memory Tuning
Spark Memory Architecture
Each Spark executor has a fixed memory budget divided into three main regions: storage memory (for cached DataFrames), execution memory (for shuffles, sorts, aggregations), and user memory (for your application objects). Tuning these ratios for your workload type dramatically improves throughput and eliminates OOM errors.
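As a back-of-the-envelope sketch of how those regions are carved up, using the unified-memory formula from the Spark documentation and the example values from the template later in this section (32 GB heap, fraction 0.75, storageFraction 0.4):

```python
executor_memory_gb = 32
reserved_mb = 300                            # fixed reservation per executor
memory_fraction = 0.75                       # spark.memory.fraction
storage_fraction = 0.4                       # spark.memory.storageFraction

usable_mb = executor_memory_gb * 1024 - reserved_mb
unified_mb = usable_mb * memory_fraction     # execution + storage (shared, evictable)
storage_mb = unified_mb * storage_fraction   # storage floor protected from eviction
user_mb = usable_mb * (1 - memory_fraction)  # user data structures, UDF objects

print(f'unified: {unified_mb:.0f} MB, storage floor: {storage_mb:.0f} MB, user: {user_mb:.0f} MB')
```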
Critical Memory Configuration Parameters
| Parameter | Default | Recommended (Batch) | Effect |
| --- | --- | --- | --- |
| spark.executor.memory | 1g | 8g – 32g | Total executor JVM heap |
| spark.executor.cores | 1 | 4 – 5 | Tasks per executor (4–5 is sweet spot) |
| spark.executor.instances | auto | Cluster dependent | Total number of executors |
| spark.memory.fraction | 0.6 | 0.7 – 0.8 | % of heap for Spark execution+storage |
| spark.memory.storageFraction | 0.5 | 0.3 – 0.5 | % of Spark memory reserved for storage |
| spark.executor.memoryOverhead | 10% | 2g – 4g | Off-heap memory (Python, native libs) |
| spark.driver.memory | 1g | 4g – 8g | Driver heap (larger for collect/toPandas) |
| spark.sql.files.maxPartitionBytes | 128MB | 256MB | Max partition size when reading files |
Production Configuration Template
▸ For Memory-Intensive Analytics Workloads
```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName('ProductionAnalytics')
    # Executor sizing (assuming 32-core, 256 GB nodes)
    .config('spark.executor.cores', '5')
    .config('spark.executor.memory', '32g')
    .config('spark.executor.memoryOverhead', '4g')
    .config('spark.driver.memory', '8g')
    .config('spark.driver.maxResultSize', '4g')
    # Memory management
    .config('spark.memory.fraction', '0.75')
    .config('spark.memory.storageFraction', '0.4')
    # Shuffle optimization
    .config('spark.sql.shuffle.partitions', '400')
    .config('spark.sql.files.maxPartitionBytes', str(256 * 1024 * 1024))
    # AQE + adaptive tuning
    .config('spark.sql.adaptive.enabled', 'true')
    .config('spark.sql.adaptive.coalescePartitions.enabled', 'true')
    # Serialization
    .config('spark.serializer', 'org.apache.spark.serializer.KryoSerializer')
    .getOrCreate()
)
```
Off-Heap Memory for Python Workers
▸ PySpark-Specific Memory Considerations
- Each Python worker process uses memory outside the JVM heap
- Set spark.executor.memoryOverhead to at least 2 GB for PySpark jobs
- For Pandas UDFs, add extra overhead per executor (Pandas DataFrames can be large); see the Arrow batch-size note after this list
- Monitor GC time in Spark UI — if > 5% of task time, increase executor memory
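One knob worth knowing for the Pandas-UDF point above: the Arrow batch size caps how many rows land in each pandas batch a vectorized UDF receives, which directly bounds the Python worker's peak memory. A minimal sketch:

```python
# Shrink the Arrow batch size for Pandas UDFs (default is 10000 rows per batch);
# a smaller value trades a little throughput for smaller per-batch pandas DataFrames
spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', 5000)
```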
Summary: PySpark Optimization Checklist
Production Optimization Checklist
Use this checklist before promoting any PySpark job to production:
Query & Join Optimization
- Apply filters and column selects before joins (predicate pushdown)
- Use broadcast() for dimension tables < 200 MB
- Verify shuffle elimination with explain() for bucketed joins
- Enable AQE: spark.sql.adaptive.enabled = true
Data Layout & I/O
- Write data in Parquet or ORC format with Snappy or Zstd compression
- Use partitionBy for time-series data with date/region filters
- Coalesce before writing to avoid small files (target 128–512 MB per file)
- Sort data on filter columns before writing for row group skipping
Memory & Resources
- Set spark.executor.cores to 4–5 per executor
- Set spark.executor.memoryOverhead to at least 2g for PySpark
- Tune spark.sql.shuffle.partitions to match your data volume
- Cache DataFrames reused in 2+ actions; always unpersist() when done
Code Quality
- Replace Python UDFs with built-in SQL functions wherever possible
- Use Pandas UDFs (vectorized) when Python logic is unavoidable
- Diagnose and salt skewed join keys before deploying at scale
- Enable Kryo serialization for RDD-heavy workloads
Final Thoughts
PySpark optimization is not a one-time activity — it is an ongoing discipline. The techniques in this guide range from zero-config wins (enabling AQE, choosing Parquet) to architectural decisions (bucketing, partitioning strategy) that shape how your entire data platform performs.
Start with the quick wins: AQE, broadcast joins, and predicate pushdown. Then layer in the architectural optimizations as your pipelines mature. Monitor the Spark UI obsessively — it tells you everything about where your jobs are spending time.
🎯 GOAL | The best-optimized Spark job is the one that reads the least data, shuffles the least data, and uses exactly as much memory as it needs — no more, no less. |