@@ -7,6 +7,9 @@ let src =
77
88module Log = (val Logs. src_log src : Logs.LOG )
99
10+ (* Maximum size of a UDP DNS response before we must truncate *)
11+ let max_udp_response = 512
12+
1013module Config = struct
1114 type t = [
1215 | `Upstream of Dns_forward.Config .t
@@ -326,17 +329,35 @@ struct
326329 Log. info (fun f -> f " Will use the host's DNS resolver" );
327330 Lwt. return { local_ip; builtin_names; resolver = Host }
328331
332+ let search f low high =
333+ if not (f low)
334+ then None (* none of the elements satisfy the predicate *)
335+ else
336+ let rec loop low high =
337+ if low = high
338+ then Some low
339+ else
340+ let mid = (low + high + 1 ) / 2 in
341+ (* since low <> high, mid <> low but it might be mid = high *)
342+ if f mid
343+ then loop mid high
344+ else
345+ if mid = high
346+ then Some low
347+ else loop low mid in
348+ loop low high
349+
329350 let answer t is_tcp buf =
330351 let open Dns.Packet in
331352 let len = Cstruct. len buf in
332353 match Dns.Protocol.Server. parse (Cstruct. sub buf 0 len) with
333354 | None ->
334355 Lwt. return (Error (`Msg " failed to parse DNS packet" ))
335356 | Some ({ questions = [ question ]; _ } as request ) ->
336- let reply answers =
357+ let reply ~ tc answers =
337358 let id = request.id in
338359 let detail =
339- { request.detail with Dns.Packet. qr = Dns.Packet. Response ; ra = true }
360+ { request.detail with Dns.Packet. qr = Dns.Packet. Response ; ra = true ; tc }
340361 in
341362 let questions = request.questions in
342363 let authorities = [] and additionals = [] in
@@ -354,31 +375,70 @@ struct
354375 { Dns.Packet. id; detail; questions; answers; authorities;
355376 additionals }
356377 in
378+ let marshal_reply answers =
379+ let buf = marshal @@ reply ~tc: false answers in
380+ if is_tcp
381+ then Some buf (* No need to truncate for TCP *)
382+ else begin
383+ (* If the packet is too big then set the TC bit and truncate by dropping answers *)
384+ let take n from =
385+ let rec loop n from acc = match n, from with
386+ | 0 , _ -> acc
387+ | _ , [] -> acc
388+ | n , x :: xs -> loop (n - 1 ) xs (x :: acc) in
389+ List. rev @@ loop n from [] in
390+ if Cstruct. len buf > max_udp_response then begin
391+ match search (fun num ->
392+ (* use only the first 'num' answers *)
393+ Cstruct. len (marshal @@ reply ~tc: true (take num answers)) < = max_udp_response
394+ ) 0 (List. length answers) with
395+ | None -> None
396+ | Some num -> Some (marshal @@ reply ~tc: true (take num answers))
397+ end
398+ else Some buf
399+ end in
357400 begin
358401 (* Consider the builtins (from the command-line) to have higher priority
359402 than the addresses in the /etc/hosts file. *)
360403 match try_builtins t.builtin_names question with
361404 | `Does_not_exist ->
362- Lwt. return (Ok (marshal nxdomain))
405+ Lwt. return (Ok (Some ( marshal nxdomain) ))
363406 | `Answers answers ->
364- Lwt. return (Ok (marshal @@ reply answers))
407+ Lwt. return (Ok (marshal_reply answers))
365408 | `Dont_know ->
366409 match try_etc_hosts question with
367410 | Some answers ->
368- Lwt. return (Ok (marshal @@ reply answers))
411+ Lwt. return (Ok (marshal_reply answers))
369412 | None ->
370413 match is_tcp, t.resolver with
371414 | true , Upstream { dns_tcp_resolver; _ } ->
372- Dns_tcp_resolver. answer buf dns_tcp_resolver
415+ begin
416+ Dns_tcp_resolver. answer buf dns_tcp_resolver
417+ >> = function
418+ | Error e -> Lwt. return (Error e)
419+ | Ok buf -> Lwt. return (Ok (Some buf))
420+ end
373421 | false , Upstream { dns_udp_resolver; _ } ->
374- Dns_udp_resolver. answer buf dns_udp_resolver
422+ begin
423+ Dns_udp_resolver. answer buf dns_udp_resolver
424+ >> = function
425+ | Error e -> Lwt. return (Error e)
426+ | Ok buf ->
427+ (* We need to parse and re-marshal so we can set the TC bit and truncate *)
428+ begin match Dns.Protocol.Server. parse buf with
429+ | None ->
430+ Lwt. return (Error (`Msg " Failed to unmarshal DNS response from upstream" ))
431+ | Some { answers; _ } ->
432+ Lwt. return (Ok (marshal_reply answers))
433+ end
434+ end
375435 | _ , Host ->
376436 D. resolve question
377437 >> = function
378438 | [] ->
379- Lwt. return (Ok (marshal nxdomain))
439+ Lwt. return (Ok (Some ( marshal nxdomain) ))
380440 | answers ->
381- Lwt. return (Ok (marshal @@ reply answers))
441+ Lwt. return (Ok (marshal_reply answers))
382442 end
383443 | _ ->
384444 Lwt. return (Error (`Msg " DNS packet had multiple questions" ))
@@ -395,7 +455,10 @@ struct
395455 | Error (`Msg m ) ->
396456 Log. warn (fun f -> f " %s lookup failed: %s" (describe buf) m);
397457 Lwt. return (Ok () )
398- | Ok buffer ->
458+ | Ok None ->
459+ Log. err (fun f -> f " %s unable to marshal response" (describe buf));
460+ Lwt. return (Ok () )
461+ | Ok (Some buffer ) ->
399462 Udp. write ~src_port: 53 ~dst: src ~dst_port: src_port udp buffer
400463
401464 let handle_tcp ~t =
@@ -414,7 +477,10 @@ struct
414477 | Error (`Msg m ) ->
415478 Log. warn (fun f -> f " %s lookup failed: %s" (describe request) m);
416479 Lwt. return_unit
417- | Ok buffer ->
480+ | Ok None ->
481+ Log. err (fun f -> f " %s unable to marshal response to" (describe request));
482+ Lwt. return_unit
483+ | Ok (Some buffer ) ->
418484 Dns_tcp_framing. write packets buffer >> = function
419485 | Error (`Msg m ) ->
420486 Log. warn (fun f ->
0 commit comments