Skip to content

Commit 41b34d6

Browse files
liviazhu authored and anishshri-db committed
[SPARK-55493][SS] Do not mkdirs in streaming checkpoint state directory in StateDataSource
### What changes were proposed in this pull request?

Previously, we try to create a new directory for the state directory in the checkpoint directory if it doesn't exist when running `StateDataSource`. This change creates new readOnly mode so that datasources do not need to mkdirs.

### Why are the changes needed?

Allow usage of StateDataSource on checkpoints that are read-only.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

New unit tests

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: claude opus 4.6

Closes #54277 from liviazhu/liviazhu-db/stds-fix.

Authored-by: Livia Zhu <livia.zhu@databricks.com>
Signed-off-by: Anish Shrigondekar <anish.shrigondekar@databricks.com>
1 parent 6976ae7 commit 41b34d6

File tree

7 files changed

+216
-16
lines changed

7 files changed

+216
-16
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
379379
partitionId, sourceOptions.storeName)
380380
val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
381381
val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
382-
oldSchemaFilePaths = oldSchemaFilePaths)
382+
oldSchemaFilePaths = oldSchemaFilePaths, createSchemaDir = false)
383383
val stateSchema = manager.readSchemaFile()
384384

385385
if (sourceOptions.internalOnlyReadAllColumnFamilies) {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,13 @@ object StreamStreamJoinStateHelper {
9494

9595
// read the key schema from the keyToNumValues store for the join keys
9696
val manager = new StateSchemaCompatibilityChecker(
97-
providerIdForKeyToNumValues, newHadoopConf, oldSchemaFilePaths)
97+
providerIdForKeyToNumValues, newHadoopConf, oldSchemaFilePaths,
98+
createSchemaDir = false)
9899
val kSchema = manager.readSchemaFile().head.keySchema
99100

100101
// read the value schema from the keyWithIndexToValue store for the values
101102
val manager2 = new StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue,
102-
newHadoopConf, oldSchemaFilePaths)
103+
newHadoopConf, oldSchemaFilePaths, createSchemaDir = false)
103104
val vSchema = manager2.readSchemaFile().head.valueSchema
104105

105106
(kSchema, vSchema)
@@ -109,7 +110,7 @@ object StreamStreamJoinStateHelper {
109110
val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
110111

111112
val manager = new StateSchemaCompatibilityChecker(
112-
providerId, newHadoopConf, oldSchemaFilePaths)
113+
providerId, newHadoopConf, oldSchemaFilePaths, createSchemaDir = false)
113114
val kSchema = manager.readSchemaFile().find { schema =>
114115
schema.colFamilyName == storeNames(0)
115116
}.map(_.keySchema).get

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717
package org.apache.spark.sql.execution.datasources.v2.state.metadata
1818

19+
import java.io.FileNotFoundException
1920
import java.util
2021

2122
import scala.jdk.CollectionConverters._
@@ -222,7 +223,8 @@ class StateMetadataPartitionReader(
222223
1
223224
}
224225
OperatorStateMetadataReader.createReader(
225-
operatorIdPath, hadoopConf, operatorStateMetadataVersion, batchId).read() match {
226+
operatorIdPath, hadoopConf, operatorStateMetadataVersion, batchId,
227+
createMetadataDir = false).read() match {
226228
case Some(metadata) => metadata
227229
case None => throw StateDataSourceErrors.failedToReadOperatorMetadata(checkpointLocation,
228230
batchId)
@@ -231,7 +233,7 @@ class StateMetadataPartitionReader(
231233
} catch {
232234
// if the operator metadata is not present, catch the exception
233235
// and return an empty array
234-
case ex: Exception =>
236+
case ex: FileNotFoundException =>
235237
logWarning(log"Failed to find operator metadata for " +
236238
log"path=${MDC(LogKeys.CHECKPOINT_LOCATION, checkpointLocation)} " +
237239
log"with exception=${MDC(LogKeys.EXCEPTION, ex)}")

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,14 @@ object OperatorStateMetadataReader {
199199
stateCheckpointPath: Path,
200200
hadoopConf: Configuration,
201201
version: Int,
202-
batchId: Long): OperatorStateMetadataReader = {
202+
batchId: Long,
203+
createMetadataDir: Boolean = true): OperatorStateMetadataReader = {
203204
version match {
204205
case 1 =>
205206
new OperatorStateMetadataV1Reader(stateCheckpointPath, hadoopConf)
206207
case 2 =>
207-
new OperatorStateMetadataV2Reader(stateCheckpointPath, hadoopConf, batchId)
208+
new OperatorStateMetadataV2Reader(stateCheckpointPath, hadoopConf, batchId,
209+
createMetadataDir)
208210
case _ =>
209211
throw new IllegalArgumentException(s"Failed to create reader for operator metadata " +
210212
s"with version=$version")
@@ -319,7 +321,8 @@ class OperatorStateMetadataV2Writer(
319321
class OperatorStateMetadataV2Reader(
320322
stateCheckpointPath: Path,
321323
hadoopConf: Configuration,
322-
batchId: Long) extends OperatorStateMetadataReader {
324+
batchId: Long,
325+
createMetadataDir: Boolean = true) extends OperatorStateMetadataReader with Logging {
323326

324327
// Check that the requested batchId is available in the checkpoint directory
325328
val baseCheckpointDir = stateCheckpointPath.getParent.getParent
@@ -331,7 +334,12 @@ class OperatorStateMetadataV2Reader(
331334
private val metadataDirPath = OperatorStateMetadataV2.metadataDirPath(stateCheckpointPath)
332335
private lazy val fm = CheckpointFileManager.create(metadataDirPath, hadoopConf)
333336

334-
fm.mkdirs(metadataDirPath.getParent)
337+
if (createMetadataDir && !fm.exists(metadataDirPath.getParent)) {
338+
fm.mkdirs(metadataDirPath.getParent)
339+
} else if (!createMetadataDir) {
340+
logInfo(log"Skipping metadata directory creation (createMetadataDir=false) " +
341+
log"at ${MDC(LogKeys.CHECKPOINT_LOCATION, baseCheckpointDir.toString)}")
342+
}
335343

336344
override def version: Int = 2
337345

@@ -352,7 +360,8 @@ class OperatorStateMetadataV2Reader(
352360

353361
// List the available batches in the operator metadata directory
354362
private def listOperatorMetadataBatches(): Array[Long] = {
355-
if (!fm.exists(metadataDirPath)) {
363+
// If parent doesn't exist, return empty array rather than throwing an exception
364+
if (!fm.exists(metadataDirPath.getParent) || !fm.exists(metadataDirPath)) {
356365
return Array.empty
357366
}
358367

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ class StateSchemaCompatibilityChecker(
8181
providerId: StateStoreProviderId,
8282
hadoopConf: Configuration,
8383
oldSchemaFilePaths: List[Path] = List.empty,
84-
newSchemaFilePath: Option[Path] = None) extends Logging {
84+
newSchemaFilePath: Option[Path] = None,
85+
createSchemaDir: Boolean = true) extends Logging {
8586

8687
// For OperatorStateMetadataV1: Only one schema file present per operator
8788
// per query
@@ -96,7 +97,12 @@ class StateSchemaCompatibilityChecker(
9697

9798
private val fm = CheckpointFileManager.create(schemaFileLocation, hadoopConf)
9899

99-
fm.mkdirs(schemaFileLocation.getParent)
100+
if (createSchemaDir && !fm.exists(schemaFileLocation.getParent)) {
101+
fm.mkdirs(schemaFileLocation.getParent)
102+
} else if (!createSchemaDir) {
103+
logInfo(log"Skipping schema directory creation (createSchemaDir=false) " +
104+
log"at ${MDC(LogKeys.CHECKPOINT_LOCATION, schemaFileLocation.getParent.toString)}")
105+
}
100106

101107
private val conf = SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(new SQLConf())
102108

@@ -112,7 +118,7 @@ class StateSchemaCompatibilityChecker(
112118
def readSchemaFiles(): Map[String, List[StateStoreColFamilySchema]] = {
113119
val stateSchemaFilePaths = (oldSchemaFilePaths ++ List(schemaFileLocation)).distinct
114120
stateSchemaFilePaths.flatMap { schemaFile =>
115-
if (fm.exists(schemaFile)) {
121+
if (fm.exists(schemaFile.getParent) && fm.exists(schemaFile)) {
116122
val inStream = fm.open(schemaFile)
117123
StateSchemaCompatibilityChecker.readSchemaFile(inStream)
118124
} else {
@@ -163,7 +169,7 @@ class StateSchemaCompatibilityChecker(
163169
* otherwise
164170
*/
165171
private def getExistingKeyAndValueSchema(): List[StateStoreColFamilySchema] = {
166-
if (fm.exists(schemaFileLocation)) {
172+
if (fm.exists(schemaFileLocation.getParent) && fm.exists(schemaFileLocation)) {
167173
readSchemaFile()
168174
} else {
169175
List.empty

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import org.apache.spark.sql.functions.col
3636
import org.apache.spark.sql.internal.SQLConf
3737
import org.apache.spark.sql.streaming.{OutputMode, TimeMode, TransformWithStateSuiteUtils}
3838
import org.apache.spark.sql.types.{IntegerType, StructType}
39+
import org.apache.spark.util.Utils
3940

4041
class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase {
4142
import testImplicits._
@@ -1501,3 +1502,159 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
15011502
}
15021503
}
15031504
}
1505+
1506+
/**
1507+
* Test suite that verifies the state data source reader does not create empty state
1508+
* directories when reading state for all stateful operators.
1509+
*/
1510+
class StateDataSourceNoEmptyDirCreationSuite extends StateDataSourceTestBase {
1511+
1512+
/**
1513+
* Asserts that the cause chain of the given exception contains
1514+
* an instance of the expected type.
1515+
*/
1516+
private def assertCauseChainContains(
1517+
e: Throwable,
1518+
expectedType: Class[_ <: Throwable]): Unit = {
1519+
var current: Throwable = e
1520+
while (current != null) {
1521+
if (expectedType.isInstance(current)) return
1522+
current = current.getCause
1523+
}
1524+
fail(
1525+
s"Expected ${expectedType.getSimpleName} in cause chain, " +
1526+
s"but got: ${e.getClass.getSimpleName}: ${e.getMessage}")
1527+
}
1528+
1529+
/**
1530+
* Runs a stateful query to create the checkpoint structure, deletes the state directory,
1531+
* then attempts to read via the state data source and verifies that the state directory
1532+
* is not recreated.
1533+
*
1534+
* @param runQuery function that runs one batch of a stateful query given a checkpoint path
1535+
* @param readState function that attempts to read state given a checkpoint path
1536+
* @param expectedCause the exception type expected in the cause chain
1537+
*/
1538+
private def assertStateDirectoryNotRecreatedOnRead(
1539+
runQuery: String => Unit,
1540+
readState: String => Unit,
1541+
expectedCause: Class[_ <: Throwable] =
1542+
classOf[StateDataSourceReadStateSchemaFailure]): Unit = {
1543+
withTempDir { tempDir =>
1544+
val checkpointPath = tempDir.getAbsolutePath
1545+
1546+
// Step 1: Run the stateful query to create the full checkpoint structure
1547+
runQuery(checkpointPath)
1548+
1549+
// Step 2: Delete the state directory
1550+
val stateDir = new File(tempDir, "state")
1551+
assert(stateDir.exists(), "State directory should exist after running the query")
1552+
Utils.deleteRecursively(stateDir)
1553+
assert(!stateDir.exists(), "State directory should be deleted")
1554+
1555+
// Step 3: Attempt to read state - expected to fail since state is deleted
1556+
val e = intercept[Exception] {
1557+
readState(checkpointPath)
1558+
}
1559+
assertCauseChainContains(e, expectedCause)
1560+
1561+
// Step 4: Verify the state directory was NOT recreated by the reader
1562+
assert(!stateDir.exists(),
1563+
"State data source reader should not recreate the deleted state directory")
1564+
}
1565+
}
1566+
1567+
test("streaming aggregation: no empty state dir created on read") {
1568+
assertStateDirectoryNotRecreatedOnRead(
1569+
runQuery = checkpointPath => {
1570+
runLargeDataStreamingAggregationQuery(checkpointPath)
1571+
},
1572+
readState = checkpointPath => {
1573+
spark.read
1574+
.format("statestore")
1575+
.option(StateSourceOptions.PATH, checkpointPath)
1576+
.load()
1577+
.collect()
1578+
}
1579+
)
1580+
}
1581+
1582+
test("drop duplicates: no empty state dir created on read") {
1583+
assertStateDirectoryNotRecreatedOnRead(
1584+
runQuery = checkpointPath => {
1585+
runDropDuplicatesQuery(checkpointPath)
1586+
},
1587+
readState = checkpointPath => {
1588+
spark.read
1589+
.format("statestore")
1590+
.option(StateSourceOptions.PATH, checkpointPath)
1591+
.load()
1592+
.collect()
1593+
}
1594+
)
1595+
}
1596+
1597+
test("flatMapGroupsWithState: no empty state dir created on read") {
1598+
assertStateDirectoryNotRecreatedOnRead(
1599+
runQuery = checkpointPath => {
1600+
runFlatMapGroupsWithStateQuery(checkpointPath)
1601+
},
1602+
readState = checkpointPath => {
1603+
spark.read
1604+
.format("statestore")
1605+
.option(StateSourceOptions.PATH, checkpointPath)
1606+
.load()
1607+
.collect()
1608+
}
1609+
)
1610+
}
1611+
1612+
test("stream-stream join: no empty state dir created on read") {
1613+
assertStateDirectoryNotRecreatedOnRead(
1614+
runQuery = checkpointPath => {
1615+
runStreamStreamJoinQuery(checkpointPath)
1616+
},
1617+
readState = checkpointPath => {
1618+
spark.read
1619+
.format("statestore")
1620+
.option(StateSourceOptions.PATH, checkpointPath)
1621+
.option(StateSourceOptions.JOIN_SIDE, "left")
1622+
.load()
1623+
.collect()
1624+
}
1625+
)
1626+
}
1627+
1628+
test("transformWithState: no empty state dir created on read") {
1629+
assertStateDirectoryNotRecreatedOnRead(
1630+
runQuery = checkpointPath => {
1631+
runTransformWithStateQuery(checkpointPath)
1632+
},
1633+
readState = checkpointPath => {
1634+
spark.read
1635+
.format("statestore")
1636+
.option(StateSourceOptions.PATH, checkpointPath)
1637+
.option(StateSourceOptions.STATE_VAR_NAME, "countState")
1638+
.load()
1639+
.collect()
1640+
},
1641+
expectedCause = classOf[IllegalArgumentException]
1642+
)
1643+
}
1644+
1645+
test("session window aggregation: no empty state dir created on read") {
1646+
assertStateDirectoryNotRecreatedOnRead(
1647+
runQuery = checkpointPath => {
1648+
runSessionWindowAggregationQuery(checkpointPath)
1649+
},
1650+
readState = checkpointPath => {
1651+
spark.read
1652+
.format("statestore")
1653+
.option(StateSourceOptions.PATH, checkpointPath)
1654+
.load()
1655+
.collect()
1656+
}
1657+
)
1658+
}
1659+
1660+
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
2222
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
2323
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
2424
import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
25-
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore}
25+
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StateStore}
2626
import org.apache.spark.sql.functions._
2727
import org.apache.spark.sql.internal.SQLConf
2828
import org.apache.spark.sql.streaming._
@@ -445,6 +445,31 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest {
445445
)
446446
}
447447

448+
/**
449+
* Runs one batch of a transformWithState query (using RunningCountStatefulProcessor)
450+
* to create checkpoint structure with state. Uses RocksDBStateStoreProvider.
451+
*/
452+
protected def runTransformWithStateQuery(checkpointRoot: String): Unit = {
453+
withSQLConf(
454+
SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
455+
SQLConf.SHUFFLE_PARTITIONS.key -> TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString
456+
) {
457+
val inputData = MemoryStream[String]
458+
val result = inputData.toDS()
459+
.groupByKey(x => x)
460+
.transformWithState(new org.apache.spark.sql.streaming.RunningCountStatefulProcessor(),
461+
TimeMode.None(),
462+
OutputMode.Update())
463+
464+
testStream(result, OutputMode.Update())(
465+
StartStream(checkpointLocation = checkpointRoot),
466+
AddData(inputData, "a"),
467+
CheckNewAnswer(("a", "1")),
468+
StopStream
469+
)
470+
}
471+
}
472+
448473
/**
449474
* Helper function to create a query that combines deduplication and aggregation.
450475
* This creates a more complex query with multiple stateful operators:

0 commit comments

Comments (0)