Skip to content

Commit 3b557e3

Browse files
committed
Refactor out shared LSTM/GGNN training loop.
github.com/ChrisCummins/ProGraML/issues/69
1 parent 9c4c442 commit 3b557e3

File tree

4 files changed

+65
-89
lines changed

4 files changed

+65
-89
lines changed

programl/task/dataflow/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ py_library(
5555
deps = [
5656
":graph_loader",
5757
"//programl/models:async_batch_builder",
58+
"//programl/models:base_batch_builder",
59+
"//programl/models:model",
5860
"//programl/models/ggnn",
5961
"//programl/proto:checkpoint_py",
6062
"//programl/proto:epoch_py",

programl/task/dataflow/dataflow.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import warnings
2222
from typing import Tuple
2323

24-
from labm8.py import app, pbutil
24+
from labm8.py import app, humanize, pbutil
2525
from sklearn.exceptions import UndefinedMetricWarning
2626

27+
from programl.models.base_batch_builder import BaseBatchBuilder
28+
from programl.models.model import Model
2729
from programl.proto import checkpoint_pb2, epoch_pb2
2830

2931
app.DEFINE_string(
@@ -208,3 +210,57 @@ def CreateLoggingDirectories(
208210
(log_dir / "checkpoints").mkdir()
209211
(log_dir / "graph_loader").mkdir()
210212
return log_dir
213+
214+
215+
def run_training_loop(
    log_dir: pathlib.Path,
    epochs,
    val_batches: BaseBatchBuilder,
    start_epoch_step: int,
    model: Model,
    val_graph_count: int,
):
    """Run the train-then-validate epoch loop shared by the LSTM and GGNN tasks.

    For each epoch yielded by `epochs`, this trains the model on the epoch's
    training batches, validates on `val_batches`, writes the results to
    `log_dir/epochs` as a single-item EpochList proto, and saves a model
    checkpoint to `log_dir/checkpoints`.

    Args:
        log_dir: The root logging directory. Assumes the `epochs` and
            `checkpoints` subdirectories already exist (see
            CreateLoggingDirectories()).
        epochs: An iterable of (graph_count, graph_cumsum, batches) tuples,
            one tuple per training epoch.
        val_batches: A batch builder whose `.batches` attribute provides the
            validation batches.
        start_epoch_step: The epoch number assigned to the first epoch.
        model: The model to train and checkpoint.
        val_graph_count: The total number of validation graphs, used for
            progress logging.

    Returns:
        The `log_dir` argument.
    """
    for step, epoch_data in enumerate(epochs, start=start_epoch_step):
        graph_count, graph_cumsum, train_batches = epoch_data
        epoch_start = time.time()
        cumsum_str = f"{humanize.Commas(graph_cumsum)} graphs"

        train_results = model.RunBatches(
            epoch_pb2.TRAIN,
            train_batches,
            log_prefix=f"Train to {cumsum_str}",
            total_graph_count=graph_count,
        )
        val_results = model.RunBatches(
            epoch_pb2.VAL,
            val_batches.batches,
            log_prefix=f"Val at {cumsum_str}",
            total_graph_count=val_graph_count,
        )

        # A single-epoch EpochList may seem redundant, but it means a run's
        # per-epoch files can be concatenated into one valid epoch list:
        #   cat *.EpochList.pbtxt > epochs.pbtxt
        results = epoch_pb2.Epoch(
            walltime_seconds=time.time() - epoch_start,
            epoch_num=step,
            train_results=train_results,
            val_results=val_results,
        )
        epoch = epoch_pb2.EpochList(epoch=[results])
        print(epoch, end="")

        epoch_path = log_dir / "epochs" / f"{step:03d}.EpochList.pbtxt"
        pbutil.ToFile(epoch, epoch_path)
        app.Log(1, "Wrote %s", epoch_path)

        checkpoint_path = log_dir / "checkpoints" / f"{step:03d}.Checkpoint.pb"
        pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

    return log_dir

programl/task/dataflow/ggnn.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -173,50 +173,9 @@ def TrainDataflowGGNN(
173173
)
174174
)
175175

176-
for (
177-
epoch_step,
178-
(train_graph_count, train_graph_cumsum, train_batches),
179-
) in enumerate(epochs, start=start_epoch_step):
180-
start_time = time.time()
181-
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"
182-
183-
train_results = model.RunBatches(
184-
epoch_pb2.TRAIN,
185-
train_batches,
186-
log_prefix=f"Train to {hr_graph_cumsum}",
187-
total_graph_count=train_graph_count,
188-
)
189-
val_results = model.RunBatches(
190-
epoch_pb2.VAL,
191-
val_batches.batches,
192-
log_prefix=f"Val at {hr_graph_cumsum}",
193-
total_graph_count=val_graph_count,
194-
)
195-
196-
# Write the epoch to file as an epoch list. This may seem redundant since
197-
# epoch list contains a single item, but it means that we can easily
198-
# concatenate a sequence of these epoch protos to produce a valid epoch
199-
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
200-
epoch = epoch_pb2.EpochList(
201-
epoch=[
202-
epoch_pb2.Epoch(
203-
walltime_seconds=time.time() - start_time,
204-
epoch_num=epoch_step,
205-
train_results=train_results,
206-
val_results=val_results,
207-
)
208-
]
209-
)
210-
print(epoch, end="")
211-
212-
epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
213-
pbutil.ToFile(epoch, epoch_path)
214-
app.Log(1, "Wrote %s", epoch_path)
215-
216-
checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
217-
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)
218-
219-
return log_dir
176+
return dataflow.run_training_loop(
177+
log_dir, epochs, val_batches, start_epoch_step, model, val_graph_count
178+
)
220179

221180

222181
def TestDataflowGGNN(

programl/task/dataflow/train_lstm.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -160,50 +160,9 @@ def TrainDataflowLSTM(
160160
)
161161
)
162162

163-
for (
164-
epoch_step,
165-
(train_graph_count, train_graph_cumsum, train_batches),
166-
) in enumerate(epochs, start=start_epoch_step):
167-
start_time = time.time()
168-
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"
169-
170-
train_results = model.RunBatches(
171-
epoch_pb2.TRAIN,
172-
train_batches,
173-
log_prefix=f"Train to {hr_graph_cumsum}",
174-
total_graph_count=train_graph_count,
175-
)
176-
val_results = model.RunBatches(
177-
epoch_pb2.VAL,
178-
val_batches.batches,
179-
log_prefix=f"Val at {hr_graph_cumsum}",
180-
total_graph_count=FLAGS.val_graph_count,
181-
)
182-
183-
# Write the epoch to file as an epoch list. This may seem redundant since
184-
# epoch list contains a single item, but it means that we can easily
185-
# concatenate a sequence of these epoch protos to produce a valid epoch
186-
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
187-
epoch = epoch_pb2.EpochList(
188-
epoch=[
189-
epoch_pb2.Epoch(
190-
walltime_seconds=time.time() - start_time,
191-
epoch_num=epoch_step,
192-
train_results=train_results,
193-
val_results=val_results,
194-
)
195-
]
196-
)
197-
print(epoch, end="")
198-
199-
epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
200-
pbutil.ToFile(epoch, epoch_path)
201-
app.Log(1, "Wrote %s", epoch_path)
202-
203-
checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
204-
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)
205-
206-
return log_dir
163+
return dataflow.run_training_loop(
164+
log_dir, epochs, val_batches, start_epoch_step, model, FLAGS.val_graph_count
165+
)
207166

208167

209168
def TestDataflowLSTM(

0 commit comments

Comments (0)