Skip to content

Commit 84e8f1f

Browse files
authored
Add boilerplate for graph broadcasting (#20715)
Ref #933. Once we have the new parser, workers will be able to avoid loading the graph; instead, we will broadcast it from the coordinator. This PR adds the necessary boilerplate for this switch, so that we can make it as soon as the new parser is ready.
1 parent aa52192 commit 84e8f1f

File tree

2 files changed

+131
-1
lines changed

2 files changed

+131
-1
lines changed

mypy/build.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,21 @@
5656
from mypy.cache import (
5757
CACHE_VERSION,
5858
DICT_STR_GEN,
59+
LIST_GEN,
5960
LITERAL_NONE,
6061
CacheMeta,
6162
ReadBuffer,
6263
SerializedError,
6364
Tag,
6465
WriteBuffer,
66+
read_bytes,
6567
read_int,
6668
read_int_list,
6769
read_int_opt,
6870
read_str,
6971
read_str_list,
7072
read_str_opt,
73+
write_bytes,
7174
write_int,
7275
write_int_list,
7376
write_int_opt,
@@ -2391,6 +2394,93 @@ def __init__(
23912394
self.add_ancestors()
23922395
self.size_hint = size_hint
23932396

2397+
def write(self, buf: WriteBuffer) -> None:
    """Serialize this State so it can be sent to a build worker.

    Unlike write() methods for most other classes, the result is lossy: some
    bulky attributes are left out because the worker either does not need
    them or can re-create them relatively quickly. The omitted attributes are:
      * self.meta: workers will call self.reload_meta() anyway.
      * self.options: can be restored with Options.clone_for_module().
      * self.error_lines: fresh errors are handled by the coordinator.

    The fields are written in the same order in which State.read() consumes them.
    """
    # Scalar fields describing the module and how it was reached.
    write_int(buf, self.order)
    write_str(buf, self.id)
    write_str_opt(buf, self.path)
    # The inline source text is mostly relevant for mypy -c '<some code>'.
    write_str_opt(buf, self.source)
    write_bool(buf, self.ignore_all)
    write_int(buf, self.caller_line)

    # Import context: tag, bare length, then a (path, line) pair per entry.
    write_tag(buf, LIST_GEN)
    contexts = self.import_context
    write_int_bare(buf, len(contexts))
    for ctx_path, ctx_line in contexts:
        write_str(buf, ctx_path)
        write_int(buf, ctx_line)

    # Hashes and the dependency lists.
    write_bytes(buf, self.interface_hash)
    write_str_opt(buf, self.meta_source_hash)
    write_str_list(buf, self.dependencies)
    write_str_list(buf, self.suppressed)

    # TODO: we can possibly serialize these dictionaries in a more compact way.
    # Most keys in the dictionaries should be the same, so we can write them once.
    # Each dictionary is written as: tag, bare length, then bare key + value per item.
    write_tag(buf, DICT_STR_GEN)
    priorities = self.priorities
    write_int_bare(buf, len(priorities))
    for dep_id, priority in priorities.items():
        write_str_bare(buf, dep_id)
        write_int(buf, priority)

    write_tag(buf, DICT_STR_GEN)
    line_map = self.dep_line_map
    write_int_bare(buf, len(line_map))
    for dep_id, dep_line in line_map.items():
        write_str_bare(buf, dep_id)
        write_int(buf, dep_line)

    write_tag(buf, DICT_STR_GEN)
    hashes = self.dep_hashes
    write_int_bare(buf, len(hashes))
    for dep_id, digest in hashes.items():
        write_str_bare(buf, dep_id)
        write_bytes(buf, digest)

    write_int(buf, self.size_hint)
2441+
2442+
@classmethod
def read(cls, buf: ReadBuffer, manager: BuildManager) -> State:
    """Deserialize a State that was serialized with State.write().

    Fields are consumed in exactly the order write() emits them. The values
    that write() leaves out are filled in here: options are cloned from
    manager.options for this module, meta is None, and error_lines is an
    empty list.

    Raises AssertionError if one of the container tags in the buffer is not
    the expected one (when assertions are enabled).
    """
    order = read_int(buf)
    mod_id = read_str(buf)
    path = read_str_opt(buf)
    source = read_str_opt(buf)
    ignore_all = read_bool(buf)
    caller_line = read_int(buf)

    # Tags are read outside of the assert statements on purpose: assert
    # statements are removed under "python -O", and the tag still has to be
    # consumed to keep the buffer position in sync with what write() produced.
    tag = read_tag(buf)
    assert tag == LIST_GEN
    context_count = read_int_bare(buf)
    import_context = [(read_str(buf), read_int(buf)) for _ in range(context_count)]

    interface_hash = read_bytes(buf)
    meta_source_hash = read_str_opt(buf)
    dependencies = read_str_list(buf)
    suppressed = read_str_list(buf)

    tag = read_tag(buf)
    assert tag == DICT_STR_GEN
    priority_count = read_int_bare(buf)
    priorities = {read_str_bare(buf): read_int(buf) for _ in range(priority_count)}

    tag = read_tag(buf)
    assert tag == DICT_STR_GEN
    line_count = read_int_bare(buf)
    dep_line_map = {read_str_bare(buf): read_int(buf) for _ in range(line_count)}

    tag = read_tag(buf)
    assert tag == DICT_STR_GEN
    hash_count = read_int_bare(buf)
    dep_hashes = {read_str_bare(buf): read_bytes(buf) for _ in range(hash_count)}

    # Options are not transmitted; rebuild the per-module options locally.
    options = manager.options.clone_for_module(mod_id)
    # size_hint is the last field written by write().
    size_hint = read_int(buf)

    return cls(
        manager=manager,
        order=order,
        id=mod_id,
        path=path,
        source=source,
        options=options,
        ignore_all=ignore_all,
        caller_line=caller_line,
        import_context=import_context,
        meta=None,
        interface_hash=interface_hash,
        meta_source_hash=meta_source_hash,
        dependencies=dependencies,
        suppressed=suppressed,
        priorities=priorities,
        dep_line_map=dep_line_map,
        dep_hashes=dep_hashes,
        error_lines=[],
        size_hint=size_hint,
    )
2483+
23942484
def reload_meta(self) -> None:
23952485
"""Force reload of cache meta.
23962486
@@ -3727,11 +3817,19 @@ def find_stale_sccs(
37273817

37283818
def process_graph(graph: Graph, manager: BuildManager) -> None:
37293819
"""Process everything in dependency order."""
3820+
# Broadcast graph to workers before computing SCCs to save a bit of time.
3821+
graph_message = GraphMessage(graph=graph)
3822+
buf = WriteBuffer()
3823+
graph_message.write(buf)
3824+
graph_data = buf.getvalue()
3825+
for worker in manager.workers:
3826+
AckMessage.read(receive(worker.conn))
3827+
worker.conn.write_bytes(graph_data)
3828+
37303829
sccs = sorted_components(graph)
37313830
manager.log(
37323831
"Found %d SCCs; largest has %d nodes" % (len(sccs), max(len(scc.mod_ids) for scc in sccs))
37333832
)
3734-
37353833
scc_by_id = {scc.id: scc for scc in sccs}
37363834
manager.scc_by_id = scc_by_id
37373835
manager.top_order = [scc.id for scc in sccs]
@@ -4186,6 +4284,7 @@ def deserialize_codes(errs: list[SerializedError]) -> list[ErrorTupleRaw]:
41864284
SCC_RESPONSE_MESSAGE: Final[Tag] = 103
41874285
SOURCES_DATA_MESSAGE: Final[Tag] = 104
41884286
SCCS_DATA_MESSAGE: Final[Tag] = 105
4287+
GRAPH_MESSAGE: Final[Tag] = 106
41894288

41904289

41914290
class AckMessage(IPCMessage):
@@ -4336,3 +4435,24 @@ def write(self, buf: WriteBuffer) -> None:
43364435
write_str_list(buf, sorted(scc.mod_ids))
43374436
write_int(buf, scc.id)
43384437
write_int_list(buf, sorted(scc.deps))
4438+
4439+
4440+
class GraphMessage(IPCMessage):
    """A message wrapping the build graph computed by the coordinator."""

    def __init__(self, *, graph: Graph) -> None:
        self.graph = graph

    @classmethod
    def read(cls, buf: ReadBuffer, manager: BuildManager | None = None) -> GraphMessage:
        """Deserialize a graph message from buf.

        A manager is required, since it is passed on to State.read() for every
        module in the graph. The expected layout is the one produced by write():
        the GRAPH_MESSAGE tag, a bare module count, then a bare module id
        followed by a serialized State for each module.
        """
        assert manager is not None
        # Consume the tag outside of the assert statement: asserts are removed
        # under "python -O", and skipping the read would leave the buffer
        # positioned on the tag instead of on the module count.
        tag = read_tag(buf)
        assert tag == GRAPH_MESSAGE
        graph = {read_str_bare(buf): State.read(buf, manager) for _ in range(read_int_bare(buf))}
        return GraphMessage(graph=graph)

    def write(self, buf: WriteBuffer) -> None:
        """Serialize the graph: tag, bare module count, then id + State per module."""
        write_tag(buf, GRAPH_MESSAGE)
        write_int_bare(buf, len(self.graph))
        for mod_id, state in self.graph.items():
            write_str_bare(buf, mod_id)
            state.write(buf)

mypy/build_worker/worker.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from mypy.build import (
2929
AckMessage,
3030
BuildManager,
31+
GraphMessage,
3132
SccRequestMessage,
3233
SccResponseMessage,
3334
SccsDataMessage,
@@ -128,6 +129,15 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
128129

129130
# Notify worker we are done loading graph.
130131
send(server, AckMessage())
132+
133+
# Compare worker graph and coordinator, with parallel parser we will only use the latter.
134+
coordinator_graph = GraphMessage.read(receive(server), manager).graph
135+
assert coordinator_graph.keys() == graph.keys()
136+
for id in graph:
137+
assert graph[id].dependencies_set == coordinator_graph[id].dependencies_set
138+
assert graph[id].suppressed_set == coordinator_graph[id].suppressed_set
139+
send(server, AckMessage())
140+
131141
sccs = SccsDataMessage.read(receive(server)).sccs
132142
manager.scc_by_id = {scc.id: scc for scc in sccs}
133143
manager.top_order = [scc.id for scc in sccs]

0 commit comments

Comments
 (0)