Skip to content

Commit a621d08

Browse files
committed
Refactor handleResponseWebhooksAndCompletion to eliminate duplicate response handling
- Consolidate response update logic into single operation to avoid duplicate overwrites - Extract webhook sending logic into dedicated sendStatusWebhook function - Extract terminal completion handling into handleTerminalCompletion function - Ensure proper lock context around populateFromRequest calls accessing pending.request - Maintain race condition protection by using local response copy for safeSend - Remove TODO comment about function being "a mess" - now clean and maintainable This eliminates the scattered response object handling and reduces lock contention while preserving all existing functionality and thread safety guarantees.
1 parent b48d846 commit a621d08

File tree

1 file changed

+63
-56
lines changed

1 file changed

+63
-56
lines changed

internal/runner/runner.go

Lines changed: 63 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -212,85 +212,92 @@ func (r *Runner) processResponseOutput(response *PredictionResponse, pending *Pe
212212
}
213213

214214
// handleResponseWebhooksAndCompletion sends webhooks and handles prediction completion
215-
// TODO: This function is a mess. It needs a hard look to make ti less ... all over the place, duplicate response object handling etc., it works for now...but
216-
// it's a pile of ick.
217215
func (r *Runner) handleResponseWebhooksAndCompletion(response *PredictionResponse, predictionID string, pending *PendingPrediction, log *logging.SugaredLogger) {
218-
// Update pending prediction's response data, preserving accumulated logs
216+
// Handle legacy compatibility: change "starting" to "processing" for webhook purposes
217+
webhookStatus := response.Status
218+
if response.Status == PredictionStarting {
219+
webhookStatus = PredictionProcessing
220+
}
221+
222+
// Update pending response once with all necessary fields
219223
pending.mu.Lock()
220224
existingLogs := pending.response.Logs
221225
pending.response = *response
222-
// Preserve accumulated logs if they exist and response doesn't have logs
226+
pending.response.Status = webhookStatus
227+
228+
// Preserve accumulated logs if new response doesn't have them
223229
if len(existingLogs) > 0 && len(response.Logs) == 0 {
224230
pending.response.Logs = existingLogs
225231
}
226-
// Restore timestamps and other request fields that were lost when we overwrote the response
232+
233+
// Restore request-derived fields and finalize if terminal
227234
pending.response.populateFromRequest(pending.request)
235+
completed := pending.response.Status.IsCompleted()
236+
if completed {
237+
if err := pending.response.finalizeResponse(); err != nil {
238+
log.Errorw("failed to finalize response", "error", err)
239+
}
240+
}
228241
pending.mu.Unlock()
229242

230-
// Send webhooks based on prediction status
231-
switch response.Status {
243+
// Send appropriate webhooks
244+
r.sendStatusWebhook(pending, response, webhookStatus, log)
245+
246+
// Handle terminal completion
247+
if completed {
248+
r.handleTerminalCompletion(response, pending, predictionID, log)
249+
}
250+
}
251+
252+
// sendStatusWebhook sends webhooks based on prediction status
253+
func (r *Runner) sendStatusWebhook(pending *PendingPrediction, response *PredictionResponse, status PredictionStatus, log *logging.SugaredLogger) {
254+
switch status {
232255
case PredictionStarting:
233-
log.Debugw("prediction started", "id", response.ID, "status", response.Status)
234-
// Compat: legacy Cog never sends "start" event - change status to processing
235-
response.Status = PredictionProcessing
236-
pending.mu.Lock()
237-
pending.response.Status = PredictionProcessing
238-
pending.mu.Unlock()
239-
// Send start webhook async (intermediary)
256+
// This case shouldn't happen due to compatibility transformation above
257+
log.Debugw("prediction started", "id", response.ID, "status", status)
240258
go func() { _ = pending.sendWebhook(webhook.EventStart) }()
241259

242260
case PredictionProcessing:
243-
log.Debugw("prediction processing", "id", response.ID, "status", response.Status)
244-
// Send output/logs webhook async (intermediary)
245-
if response.Output != nil {
246-
go func() { _ = pending.sendWebhook(webhook.EventOutput) }()
261+
if response.Status == PredictionStarting {
262+
log.Debugw("prediction started", "id", response.ID, "status", "starting->processing")
263+
go func() { _ = pending.sendWebhook(webhook.EventStart) }()
247264
} else {
248-
go func() { _ = pending.sendWebhook(webhook.EventLogs) }()
265+
log.Debugw("prediction processing", "id", response.ID, "status", status)
266+
if response.Output != nil {
267+
go func() { _ = pending.sendWebhook(webhook.EventOutput) }()
268+
} else {
269+
go func() { _ = pending.sendWebhook(webhook.EventLogs) }()
270+
}
249271
}
250272
}
273+
}
251274

252-
// Always update pending response state, preserving accumulated logs again
275+
// handleTerminalCompletion handles cleanup and response sending for completed predictions
276+
func (r *Runner) handleTerminalCompletion(response *PredictionResponse, pending *PendingPrediction, predictionID string, log *logging.SugaredLogger) {
277+
log.Infow("prediction completed", "id", response.ID, "status", response.Status)
278+
279+
// Prepare local response copy with all fields populated and finalized
253280
pending.mu.Lock()
254-
existingLogs = pending.response.Logs
255-
pending.response = *response
256-
// Preserve accumulated logs if they exist and response doesn't have logs
257-
if len(existingLogs) > 0 && len(response.Logs) == 0 {
258-
pending.response.Logs = existingLogs
259-
}
260-
// Restore timestamps and other request fields that were lost when we overwrote the response
261-
pending.response.populateFromRequest(pending.request)
262-
response.populateFromRequest(pending.request)
263-
completed := pending.response.Status.IsCompleted()
264-
if completed {
265-
if err := pending.response.finalizeResponse(); err != nil {
266-
log.Errorw("failed to finalize response", "error", err)
267-
}
268-
}
281+
finalResponse := *response
282+
finalResponse.populateFromRequest(pending.request)
269283
pending.mu.Unlock()
270284

271-
// Handle terminal vs non-terminal states
272-
if completed {
273-
log.Infow("prediction completed", "id", response.ID, "status", response.Status)
274-
// Finalize the local response copy to avoid race conditions with pending.response
275-
if err := response.finalizeResponse(); err != nil {
276-
log.Errorw("failed to finalize response", "error", err)
277-
}
278-
// Send response and close channel - use local response copy to avoid race where
279-
// pending.response could be modified by another goroutine before safeSend completes
280-
pending.safeSend(*response)
281-
pending.safeClose()
285+
if err := finalResponse.finalizeResponse(); err != nil {
286+
log.Errorw("failed to finalize response", "error", err)
287+
}
282288

283-
// Clean up input paths for completed prediction
284-
for _, inputPath := range pending.inputPaths {
285-
if err := os.Remove(inputPath); err != nil {
286-
log.Errorw("failed to remove input path", "path", inputPath, "error", err)
287-
}
288-
}
289+
// Send response and close channel - use local copy to avoid race conditions
290+
pending.safeSend(finalResponse)
291+
pending.safeClose()
289292

290-
// Watcher exits - manager defer will handle webhook and cleanup
291-
log.Tracew("prediction completed, watcher exiting", "prediction_id", predictionID)
292-
return
293+
// Clean up input paths
294+
for _, inputPath := range pending.inputPaths {
295+
if err := os.Remove(inputPath); err != nil {
296+
log.Errorw("failed to remove input path", "path", inputPath, "error", err)
297+
}
293298
}
299+
300+
log.Tracew("prediction completed, watcher exiting", "prediction_id", predictionID)
294301
}
295302

296303
type (
@@ -878,7 +885,7 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, *Predi
878885

879886
log.Tracew("returning prediction channel", "prediction_id", req.ID)
880887
initialResponse := &PredictionResponse{
881-
Status: "starting",
888+
Status: PredictionStarting,
882889
}
883890
initialResponse.populateFromRequest(req)
884891
return pending.c, initialResponse, nil

0 commit comments

Comments
 (0)