@@ -257,6 +257,8 @@ struct
257257 clock : Clock .t ;
258258 mutable pending : Tcp.Id.Set .t ;
259259 mutable last_active_time : float ;
260+ (* Tasks that will be signalled if the endpoint is destroyed *)
261+ mutable on_destroy : unit Lwt .u Tcp.Id.Map .t ;
260262 }
261263 (* * A generic TCP/IP endpoint *)
262264
@@ -279,12 +281,17 @@ struct
279281
280282 let pending = Tcp.Id.Set. empty in
281283 let last_active_time = Unix. gettimeofday () in
284+ let on_destroy = Tcp.Id.Map. empty in
282285 let tcp_stack =
283286 { recorder; netif; ethif; arp; ipv4; icmpv4; udp4; tcp4; pending;
284- last_active_time; clock }
287+ last_active_time; clock; on_destroy }
285288 in
286289 Lwt. return tcp_stack
287290
291+ let destroy t =
292+ Tcp.Id.Map. iter (fun _ u -> Lwt. wakeup_later u () ) t.on_destroy;
293+ t.on_destroy < - Tcp.Id.Map. empty
294+
288295 let intercept_tcp_syn t ~id ~syn on_syn_callback (buf : Cstruct.t ) =
289296 if syn then begin
290297 if Tcp.Id.Set. mem id t.pending then begin
@@ -295,9 +302,14 @@ struct
295302 Lwt. return_unit
296303 end else begin
297304 t.pending < - Tcp.Id.Set. add id t.pending;
305+ (* Add a task to the "on_destroy" list which will be signalled if
306+ the Endpoint is disconnected from the switch and we should close
307+ connections. *)
308+ let close, close_request = Lwt. task () in
309+ t.on_destroy < - Tcp.Id.Map. add id close_request t.on_destroy;
298310 Lwt. finalize
299311 (fun () ->
300- on_syn_callback ()
312+ on_syn_callback close
301313 >> = fun listeners ->
302314 let src = Stack_tcp_wire. dst id in
303315 let dst = Stack_tcp_wire. src id in
@@ -319,7 +331,7 @@ struct
319331 Mirage_flow_lwt. Proxy (Clock )(Stack_tcp )(Host.Sockets.Stream. Tcp )
320332
321333 let input_tcp t ~id ~syn (ip , port ) (buf : Cstruct.t ) =
322- intercept_tcp_syn t ~id ~syn (fun () ->
334+ intercept_tcp_syn t ~id ~syn (fun close ->
323335 Host.Sockets.Stream.Tcp. connect (ip, port)
324336 >> = function
325337 | Error (`Msg m ) ->
@@ -341,9 +353,21 @@ struct
341353 Lwt. return_unit
342354 | Some socket ->
343355 Lwt. finalize (fun () ->
344- Proxy. proxy t.clock flow socket
356+ Lwt. pick [
357+ Lwt. map
358+ (function Error e -> Error (`Proxy e) | Ok x -> Ok x)
359+ (Proxy. proxy t.clock flow socket);
360+ Lwt. map
361+ (fun () -> Error `Close )
362+ close
363+ ]
345364 >> = function
346- | Error e ->
365+ | Error (`Close) ->
366+ Log. info (fun f ->
367+ f " %s proxy closed due to switch port disconnection"
368+ (Tcp.Flow. to_string tcp));
369+ Lwt. return_unit
370+ | Error (`Proxy e ) ->
347371 Log. debug (fun f ->
348372 f " %s proxy failed with %a"
349373 (Tcp.Flow. to_string tcp) Proxy. pp_error e);
@@ -354,6 +378,7 @@ struct
354378 Log. debug (fun f ->
355379 f " closing flow %s" (string_of_id tcp.Tcp.Flow. id));
356380 tcp.Tcp.Flow. socket < - None ;
381+ t.on_destroy < - Tcp.Id.Map. remove id t.on_destroy;
357382 Tcp.Flow. remove tcp.Tcp.Flow. id;
358383 Host.Sockets.Stream.Tcp. close socket
359384 )
@@ -479,9 +504,9 @@ struct
479504 let id =
480505 Stack_tcp_wire. v ~src_port: 53 ~dst: src ~src: dst ~dst_port: src_port
481506 in
482- Endpoint. intercept_tcp_syn t.endpoint ~id ~syn (fun () ->
507+ Endpoint. intercept_tcp_syn t.endpoint ~id ~syn (fun close ->
483508 ! dns >> = fun t ->
484- Dns_forwarder. handle_tcp ~t
509+ Dns_forwarder. handle_tcp ~t ~close
485510 ) raw
486511 > |= ok
487512
@@ -801,10 +826,11 @@ struct
801826 let now = Unix. gettimeofday () in
802827 let old_ips = IPMap. fold (fun ip endpoint acc ->
803828 let age = now -. endpoint.Endpoint. last_active_time in
804- if age > 300.0 then ip :: acc else acc
829+ if age > 300.0 then (ip, endpoint) :: acc else acc
805830 ) t.endpoints [] in
806- List. iter (fun ip ->
831+ List. iter (fun ( ip , endpoint ) ->
807832 Switch. remove t.switch ip;
833+ Endpoint. destroy endpoint;
808834 t.endpoints < - IPMap. remove ip t.endpoints
809835 ) old_ips;
810836 Lwt. return_unit
0 commit comments