Skip to content

Commit a143f61

Browse files
authored
feat: optimizing partition compute (#77)
* BugFix: Fixed a bug in collecting statistics in parallel at the partition level * feat: add jdbc.partition-compute-parallelism option * Optimize partition computation parallelism logic and documentation - Optimize partition computation parallelism logic in OBMySQLPartition to dynamically adjust based on partition count and user configuration - Replace Chinese comments with English comments - Update documentation for jdbc.partition-compute-parallelism parameter to highlight driver node execution and performance tuning recommendations
1 parent ffefa2e commit a143f61

File tree

5 files changed

+106
-44
lines changed

5 files changed

+106
-44
lines changed

docs/spark-catalog-oceanbase.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,13 @@ Precautions for direct-load:
352352
<td>Int</td>
353353
<td>Controls the parallelism level for statistical queries (e.g., COUNT, MIN, MAX) by adding /*+ PARALLEL(N) */ hint to generated SQL.</td>
354354
</tr>
355+
<tr>
356+
<td>spark.sql.catalog.your_catalog_name.jdbc.partition-compute-parallelism</td>
357+
<td>No</td>
358+
<td style="word-wrap: break-word;">32</td>
359+
<td>Int</td>
360+
<td>Controls the parallelism level for partition computation. This parameter determines the number of threads used when computing partitions for partitioned tables (mainly through parallel SQL queries against OceanBase partition statistics). The computation task runs on the driver node, so higher values can improve performance for tables with many partitions. When setting a larger value for this parameter, consider increasing the driver node's CPU cores and memory to achieve better performance.</td>
361+
</tr>
355362
<tr>
356363
<td>spark.sql.catalog.your_catalog_name.jdbc.query-timeout-hint-degree</td>
357364
<td>No</td>

docs/spark-catalog-oceanbase_cn.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ select * from spark_catalog.default.orders;
348348
<td>Int</td>
349349
<td>通过向生成的 SQL 添加 /*+ PARALLEL(N) */ hint 来控制统计查询(例如 COUNT、MIN、MAX)的并行级别。</td>
350350
</tr>
351+
<tr>
352+
<td>spark.sql.catalog.your_catalog_name.jdbc.partition-compute-parallelism</td>
353+
<td>否</td>
354+
<td style="word-wrap: break-word;">32</td>
355+
<td>Int</td>
356+
<td>控制分区计算的并行级别。此参数确定计算分区表分区时使用的线程数,主要通过并行 SQL 查询 OceanBase 分区统计信息来实现。该计算任务在 driver 节点运行,对于分区数量较多的表,设置更高的值可以显著提升性能。当指定的该参数值较大的时候,适当调大 driver 节点的 CPU 核数和内存,可以取得更好的性能。</td>
357+
</tr>
351358
<tr>
352359
<td>spark.sql.catalog.your_catalog_name.jdbc.query-timeout-hint-degree</td>
353360
<td>否</td>

spark-connector-oceanbase/spark-connector-oceanbase-3.1/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcReader.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ class OBJdbcReader(
5151
private lazy val stmt: PreparedStatement =
5252
conn.prepareStatement(buildQuerySql(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
5353
private lazy val rs: ResultSet = {
54+
partition match {
55+
case part: OBMySQLPartition =>
56+
part.unevenlyWhereValue.zipWithIndex.foreach {
57+
case (value, index) => stmt.setObject(index + 1, value)
58+
}
59+
case _ =>
60+
}
5461
stmt.setFetchSize(config.getJdbcFetchSize)
5562
stmt.setQueryTimeout(config.getJdbcQueryTimeout)
5663
stmt.executeQuery()
@@ -96,16 +103,23 @@ class OBJdbcReader(
96103
.map(p => s"($p)")
97104
.mkString(" AND ")
98105

99-
val whereClause: String = {
100-
if (filterWhereClause.nonEmpty) {
106+
val part: OBMySQLPartition = partition.asInstanceOf[OBMySQLPartition]
107+
val whereClause = {
108+
if (part.whereClause != null && filterWhereClause.nonEmpty) {
109+
"WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
110+
} else if (part.whereClause != null) {
111+
"WHERE " + part.whereClause
112+
} else if (filterWhereClause.nonEmpty) {
101113
"WHERE " + filterWhereClause
102114
} else {
103115
""
104116
}
105117
}
106-
val part: OBMySQLPartition = partition.asInstanceOf[OBMySQLPartition]
118+
val hint =
119+
s"/*+ PARALLEL(${config.getJdbcParallelHintDegree}) */"
120+
107121
s"""
108-
|SELECT $columnStr FROM ${config.getDbTable} ${part.partitionClause}
122+
|SELECT $hint $columnStr FROM ${config.getDbTable} ${part.partitionClause}
109123
|$whereClause ${part.limitOffsetClause}
110124
|""".stripMargin
111125
}

spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/java/com/oceanbase/spark/config/OceanBaseConfig.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,14 @@ public class OceanBaseConfig extends Config implements Serializable {
273273
.intConf()
274274
.createWithDefault(4);
275275

276+
public static final ConfigEntry<Integer> JDBC_PARTITION_COMPUTE_PARALLELISM =
277+
new ConfigBuilder("jdbc.partition-compute-parallelism")
278+
.doc(
279+
"Controls the parallelism level for partition computation. This parameter determines the number of threads used when computing partitions for partitioned tables. Higher values can improve performance for tables with many partitions.")
280+
.version(ConfigConstants.VERSION_1_3_0)
281+
.intConf()
282+
.createWithDefault(32);
283+
276284
public static final ConfigEntry<Long> JDBC_MAX_RECORDS_PER_PARTITION =
277285
new ConfigBuilder("jdbc.max-records-per-partition")
278286
.doc(
@@ -533,6 +541,10 @@ public Integer getJdbcStatsParallelHintDegree() {
533541
return get(JDBC_STATISTICS_PARALLEL_HINT_DEGREE);
534542
}
535543

544+
public Integer getJdbcPartitionComputeParallelism() {
545+
return get(JDBC_PARTITION_COMPUTE_PARALLELISM);
546+
}
547+
536548
public Optional<Long> getJdbcMaxRecordsPrePartition() {
537549
return Optional.ofNullable(get(JDBC_MAX_RECORDS_PER_PARTITION));
538550
}

spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBMySQLPartition.scala

Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ import org.apache.spark.sql.connector.read.InputPartition
2525

2626
import java.sql.Connection
2727
import java.util.{Objects, Optional}
28+
import java.util.concurrent.{Executors, TimeUnit}
2829
import java.util.concurrent.TimeUnit
2930

3031
import scala.collection.mutable
3132
import scala.collection.mutable.ArrayBuffer
32-
import scala.concurrent.{Await, Future}
33-
import scala.concurrent.ExecutionContext.Implicits.global
33+
import scala.concurrent.{Await, ExecutionContext, Future}
3434
import scala.concurrent.duration.Duration
3535

3636
/** Data corresponding to one partition of a JDBCLimitRDD. */
@@ -147,7 +147,7 @@ object OBMySQLPartition extends Logging {
147147
computeUnevenlyWherePartInfoForNonPartTable(connection, config, priKeyColumnName)
148148
} else {
149149
// For partition table
150-
computeUnevenlyWherePartInfoForPartTable(connection, config, obPartInfos, priKeyColumnName)
150+
computeUnevenlyWherePartInfoForPartTable(config, obPartInfos, priKeyColumnName)
151151
}
152152
}
153153

@@ -434,48 +434,70 @@ object OBMySQLPartition extends Logging {
434434
}
435435

436436
private def computeUnevenlyWherePartInfoForPartTable(
437-
conn: Connection,
438437
config: OceanBaseConfig,
439438
obPartInfos: Array[OBPartInfo],
440439
priKeyColumnName: String): Array[InputPartition] = {
441440
val startTime = System.nanoTime()
442-
val futures = obPartInfos.map(
443-
obPartInfo => {
444-
Future {
445-
val partitionName = obPartInfo.subPartName match {
446-
case x if Objects.isNull(x) => PARTITION_QUERY_FORMAT.format(obPartInfo.partName)
447-
case _ => PARTITION_QUERY_FORMAT.format(obPartInfo.subPartName)
448-
}
449-
val unevenlyPriKeyTableInfo =
450-
obtainUnevenlyPriKeyTableInfo(conn, config, partitionName, priKeyColumnName)
451-
val partitions =
452-
computeUnevenlyWhereSparkPart(
453-
conn,
454-
unevenlyPriKeyTableInfo,
455-
partitionName,
456-
priKeyColumnName,
457-
config)
458-
partitions
459-
}
460-
})
461-
val arr = futures.flatMap(
462-
future => {
463-
Await.result(future, Duration(10, TimeUnit.MINUTES))
464-
})
465-
val endTime = System.nanoTime()
466-
logInfo(s"Time cost: ${(endTime - startTime) / 1000000} ms")
467441

468-
arr.zipWithIndex.map {
469-
case (partInfo, index) =>
470-
OBMySQLPartition(
471-
partInfo.partitionClause,
472-
limitOffsetClause = EMPTY_STRING,
473-
whereClause = partInfo.whereClause,
474-
useHiddenPKColumn = partInfo.useHiddenPKColumn,
475-
unevenlyWhereValue = partInfo.unevenlyWhereValue,
476-
idx = index
477-
)
478-
}.toArray
442+
// Create custom thread pool with optimized parallelism
443+
val maxParallelism = config.getJdbcPartitionComputeParallelism
444+
val partitionCount = obPartInfos.length
445+
val parallelism = Math.min(partitionCount, maxParallelism)
446+
val executor = Executors.newFixedThreadPool(parallelism)
447+
val executionContext = ExecutionContext.fromExecutor(executor)
448+
449+
try {
450+
val futures = obPartInfos.map(
451+
obPartInfo => {
452+
Future {
453+
val conn = OBJdbcUtils.getConnection(config)
454+
try {
455+
val partitionName = obPartInfo.subPartName match {
456+
case x if Objects.isNull(x) => PARTITION_QUERY_FORMAT.format(obPartInfo.partName)
457+
case _ => PARTITION_QUERY_FORMAT.format(obPartInfo.subPartName)
458+
}
459+
val unevenlyPriKeyTableInfo =
460+
obtainUnevenlyPriKeyTableInfo(conn, config, partitionName, priKeyColumnName)
461+
val partitions =
462+
computeUnevenlyWhereSparkPart(
463+
conn,
464+
unevenlyPriKeyTableInfo,
465+
partitionName,
466+
priKeyColumnName,
467+
config)
468+
partitions
469+
} finally {
470+
conn.close()
471+
}
472+
}(executionContext)
473+
})
474+
val arr = futures.flatMap(
475+
future => {
476+
Await.result(future, Duration(10, TimeUnit.MINUTES))
477+
})
478+
val endTime = System.nanoTime()
479+
logInfo(
480+
s"Partition computation completed with parallelism=$parallelism, time cost: ${(endTime - startTime) / 1000000} ms")
481+
482+
arr.zipWithIndex.map {
483+
case (partInfo, index) =>
484+
OBMySQLPartition(
485+
partInfo.partitionClause,
486+
limitOffsetClause = EMPTY_STRING,
487+
whereClause = partInfo.whereClause,
488+
useHiddenPKColumn = partInfo.useHiddenPKColumn,
489+
unevenlyWhereValue = partInfo.unevenlyWhereValue,
490+
idx = index
491+
)
492+
}.toArray
493+
} finally {
494+
// Shutdown thread pool
495+
executor.shutdown()
496+
if (!executor.awaitTermination(30, TimeUnit.SECONDS)) {
497+
executor.shutdownNow()
498+
logWarning("Thread pool did not terminate gracefully, forcing shutdown")
499+
}
500+
}
479501
}
480502

481503
/**

0 commit comments

Comments
 (0)