@@ -212,85 +212,92 @@ func (r *Runner) processResponseOutput(response *PredictionResponse, pending *Pe
212212}
213213
214214// handleResponseWebhooksAndCompletion sends webhooks and handles prediction completion
215- // TODO: This function is a mess. It needs a hard look to make ti less ... all over the place, duplicate response object handling etc., it works for now...but
216- // it's a pile of ick.
217215func (r * Runner ) handleResponseWebhooksAndCompletion (response * PredictionResponse , predictionID string , pending * PendingPrediction , log * logging.SugaredLogger ) {
218- // Update pending prediction's response data, preserving accumulated logs
216+ // Handle legacy compatibility: change "starting" to "processing" for webhook purposes
217+ webhookStatus := response .Status
218+ if response .Status == PredictionStarting {
219+ webhookStatus = PredictionProcessing
220+ }
221+
222+ // Update pending response once with all necessary fields
219223 pending .mu .Lock ()
220224 existingLogs := pending .response .Logs
221225 pending .response = * response
222- // Preserve accumulated logs if they exist and response doesn't have logs
226+ pending .response .Status = webhookStatus
227+
228+ // Preserve accumulated logs if new response doesn't have them
223229 if len (existingLogs ) > 0 && len (response .Logs ) == 0 {
224230 pending .response .Logs = existingLogs
225231 }
226- // Restore timestamps and other request fields that were lost when we overwrote the response
232+
233+ // Restore request-derived fields and finalize if terminal
227234 pending .response .populateFromRequest (pending .request )
235+ completed := pending .response .Status .IsCompleted ()
236+ if completed {
237+ if err := pending .response .finalizeResponse (); err != nil {
238+ log .Errorw ("failed to finalize response" , "error" , err )
239+ }
240+ }
228241 pending .mu .Unlock ()
229242
230- // Send webhooks based on prediction status
231- switch response .Status {
243+ // Send appropriate webhooks
244+ r .sendStatusWebhook (pending , response , webhookStatus , log )
245+
246+ // Handle terminal completion
247+ if completed {
248+ r .handleTerminalCompletion (response , pending , predictionID , log )
249+ }
250+ }
251+
252+ // sendStatusWebhook sends webhooks based on prediction status
253+ func (r * Runner ) sendStatusWebhook (pending * PendingPrediction , response * PredictionResponse , status PredictionStatus , log * logging.SugaredLogger ) {
254+ switch status {
232255 case PredictionStarting :
233- log .Debugw ("prediction started" , "id" , response .ID , "status" , response .Status )
234- // Compat: legacy Cog never sends "start" event - change status to processing
235- response .Status = PredictionProcessing
236- pending .mu .Lock ()
237- pending .response .Status = PredictionProcessing
238- pending .mu .Unlock ()
239- // Send start webhook async (intermediary)
256+ // This case shouldn't happen due to compatibility transformation above
257+ log .Debugw ("prediction started" , "id" , response .ID , "status" , status )
240258 go func () { _ = pending .sendWebhook (webhook .EventStart ) }()
241259
242260 case PredictionProcessing :
243- log .Debugw ("prediction processing" , "id" , response .ID , "status" , response .Status )
244- // Send output/logs webhook async (intermediary)
245- if response .Output != nil {
246- go func () { _ = pending .sendWebhook (webhook .EventOutput ) }()
261+ if response .Status == PredictionStarting {
262+ log .Debugw ("prediction started" , "id" , response .ID , "status" , "starting->processing" )
263+ go func () { _ = pending .sendWebhook (webhook .EventStart ) }()
247264 } else {
248- go func () { _ = pending .sendWebhook (webhook .EventLogs ) }()
265+ log .Debugw ("prediction processing" , "id" , response .ID , "status" , status )
266+ if response .Output != nil {
267+ go func () { _ = pending .sendWebhook (webhook .EventOutput ) }()
268+ } else {
269+ go func () { _ = pending .sendWebhook (webhook .EventLogs ) }()
270+ }
249271 }
250272 }
273+ }
251274
252- // Always update pending response state, preserving accumulated logs again
275+ // handleTerminalCompletion handles cleanup and response sending for completed predictions
276+ func (r * Runner ) handleTerminalCompletion (response * PredictionResponse , pending * PendingPrediction , predictionID string , log * logging.SugaredLogger ) {
277+ log .Infow ("prediction completed" , "id" , response .ID , "status" , response .Status )
278+
279+ // Prepare local response copy with all fields populated and finalized
253280 pending .mu .Lock ()
254- existingLogs = pending .response .Logs
255- pending .response = * response
256- // Preserve accumulated logs if they exist and response doesn't have logs
257- if len (existingLogs ) > 0 && len (response .Logs ) == 0 {
258- pending .response .Logs = existingLogs
259- }
260- // Restore timestamps and other request fields that were lost when we overwrote the response
261- pending .response .populateFromRequest (pending .request )
262- response .populateFromRequest (pending .request )
263- completed := pending .response .Status .IsCompleted ()
264- if completed {
265- if err := pending .response .finalizeResponse (); err != nil {
266- log .Errorw ("failed to finalize response" , "error" , err )
267- }
268- }
281+ finalResponse := * response
282+ finalResponse .populateFromRequest (pending .request )
269283 pending .mu .Unlock ()
270284
271- // Handle terminal vs non-terminal states
272- if completed {
273- log .Infow ("prediction completed" , "id" , response .ID , "status" , response .Status )
274- // Finalize the local response copy to avoid race conditions with pending.response
275- if err := response .finalizeResponse (); err != nil {
276- log .Errorw ("failed to finalize response" , "error" , err )
277- }
278- // Send response and close channel - use local response copy to avoid race where
279- // pending.response could be modified by another goroutine before safeSend completes
280- pending .safeSend (* response )
281- pending .safeClose ()
285+ if err := finalResponse .finalizeResponse (); err != nil {
286+ log .Errorw ("failed to finalize response" , "error" , err )
287+ }
282288
283- // Clean up input paths for completed prediction
284- for _ , inputPath := range pending .inputPaths {
285- if err := os .Remove (inputPath ); err != nil {
286- log .Errorw ("failed to remove input path" , "path" , inputPath , "error" , err )
287- }
288- }
289+ // Send response and close channel - use local copy to avoid race conditions
290+ pending .safeSend (finalResponse )
291+ pending .safeClose ()
289292
290- // Watcher exits - manager defer will handle webhook and cleanup
291- log .Tracew ("prediction completed, watcher exiting" , "prediction_id" , predictionID )
292- return
293+ // Clean up input paths
294+ for _ , inputPath := range pending .inputPaths {
295+ if err := os .Remove (inputPath ); err != nil {
296+ log .Errorw ("failed to remove input path" , "path" , inputPath , "error" , err )
297+ }
293298 }
299+
300+ log .Tracew ("prediction completed, watcher exiting" , "prediction_id" , predictionID )
294301}
295302
296303type (
@@ -878,7 +885,7 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, *Predi
878885
879886 log .Tracew ("returning prediction channel" , "prediction_id" , req .ID )
880887 initialResponse := & PredictionResponse {
881- Status : "starting" ,
888+ Status : PredictionStarting ,
882889 }
883890 initialResponse .populateFromRequest (req )
884891 return pending .c , initialResponse , nil
0 commit comments