speed up 16.2 by 10x but its still 10 seconds >:(

This commit is contained in:
ryan 2023-11-23 11:17:41 -08:00
parent 1df5532c24
commit 5c0d7867a4

View File

@ -52,11 +52,15 @@ let parse_valves lines =
|> List.remove_assoc ~eq:String.equal from
|> List.iter (fun (valve, _) ->
Hashtbl.replace distances (from, valve) (dist_of valve)));
let get_dist f t = Option.get_exn_or "huge sad" @@ Hashtbl.get distances (f, t) in
let get_dist f t =
match Hashtbl.get distances (f, t) with
| Some dist -> dist
| None -> failwith @@ Printf.sprintf "cant find distance %s -> %s" f t
in
valves, get_dist
;;
let traverse get_dist =
let find_best get_dist =
let rec traverse' valves time elephant current =
let options =
valves
@ -77,10 +81,38 @@ let traverse get_dist =
traverse'
;;
module StrSet = Set.Make (String)
let traverse get_dist valves current time released opened =
let states = Hashtbl.create 1000 in
let rec traverse' current time released opened =
Hashtbl.update
states
~f:(fun _ current ->
match current with
| Some r when r >= released -> Some r
| _ -> Some released)
~k:opened;
valves
|> List.iter (fun (name, flow) ->
if not @@ StrSet.mem name opened
then (
let dist = get_dist current name in
let time' = time - (dist + 1) in
if time' > 0
then (
let released' = released + (flow * time') in
let opened' = StrSet.add name opened in
traverse' name time' released' opened')))
in
traverse' current time released opened;
states
;;
let%expect_test "example" =
let valves, get_dist = parse_valves example_lines in
Printf.printf "part 1: %i\n" @@ traverse get_dist valves 30 false "AA";
Printf.printf "part 2: %i\n" @@ traverse get_dist valves 26 true "AA";
Printf.printf "part 1: %i\n" @@ find_best get_dist valves 30 false "AA";
Printf.printf "part 2: %i\n" @@ find_best get_dist valves 26 true "AA";
[%expect {|
part 1: 1651
part 2: 1707 |}]
@ -88,25 +120,34 @@ let%expect_test "example" =
let%expect_test "Day 16.1 example" =
let valves, get_dist = parse_valves example_lines in
Printf.printf "%i\n" @@ traverse get_dist valves 30 false "AA";
Printf.printf "%i\n" @@ find_best get_dist valves 30 false "AA";
[%expect {| 1651 |}]
;;
let%expect_test "Day 16.2 example" =
let valves, get_dist = parse_valves example_lines in
Printf.printf "%i\n" @@ traverse get_dist valves 26 true "AA";
Printf.printf "%i\n" @@ find_best get_dist valves 26 true "AA";
[%expect {| 1707 |}]
;;
let%expect_test "Day 16.1" =
let valves, get_dist = parse_valves @@ Utils.lines_of_input 16 in
Printf.printf "%i\n" @@ traverse get_dist valves 30 false "AA";
Printf.printf "%i\n" @@ find_best get_dist valves 30 false "AA";
[%expect {| 2265 |}]
;;
(* took 106.465 sec >:( *)
(* let%expect_test "Day 16.2" = *)
(* let valves, get_dist = parse_valves @@ Utils.lines_of_input 16 in *)
(* Printf.printf "%i\n" @@ traverse get_dist valves 26 true "AA"; *)
(* [%expect {| 2811 |}] *)
(* ;; *)
(* still takes 10 seconds >:( *)
let%expect_test "Day 16.2" =
let valves, get_dist = parse_valves @@ Utils.lines_of_input 16 in
let states = traverse get_dist valves "AA" 26 0 StrSet.empty in
let best =
states
|> Iter.of_hashtbl
|> Iter.diagonal
|> Iter.filter_map (fun ((a_set, a_score), (b_set, b_score)) ->
if StrSet.disjoint a_set b_set then Some (a_score + b_score) else None)
|> Iter.fold Int.max 0
in
Printf.printf "%i\n" @@ best;
[%expect {| 2811 |}]
;;