30x speedup on 16.2 from using bitmasks as a set!!

This commit is contained in:
ryan 2023-12-04 23:28:25 -08:00
parent 4e57ad4072
commit 94f4badcc4

View File

@ -15,12 +15,10 @@ Valve JJ has flow rate=21; tunnel leads to valve II
|}
;;
type room =
{ name : string
; flow : int
; tunnels : string list
type valve =
{ flow : int
; mask : int
}
[@@deriving show { with_path = false }]
let parse_valves lines =
let rooms = Hashtbl.create 100 in
@ -29,22 +27,28 @@ let parse_valves lines =
let int_matches = List.map int_of_string Re.(matches (Pcre.regexp {|\d+|}) line) in
let valve_matches = Re.(matches (Pcre.regexp {|[A-Z]{2}|}) line) in
match int_matches, valve_matches with
| [ flow ], name :: tunnels -> Hashtbl.add rooms name { name; flow; tunnels }
| [ flow ], name :: tunnels -> Hashtbl.add rooms name (flow, tunnels)
| _ -> failwith "parse error");
let valves =
Hashtbl.to_list rooms
|> List.filter_map (fun (_, { name; flow; tunnels = _ }) ->
if flow > 0 then Some (name, flow) else None)
|> List.filter (fun (_, (flow, _)) -> flow > 0)
|> List.mapi (fun i (name, (flow, _)) ->
let mask = Int.(1 lsl i) in
name, { flow; mask })
in
let distances = Hashtbl.create 100 in
valves @ [ "AA", 0 ]
|> List.iter (fun (from, _) ->
valves
|> List.map (fun (valve, _) -> valve)
|> List.cons "AA"
|> List.iter (fun from ->
let dist_of =
Fun.(
Utils.djikstra_hash
~get_weight:(fun _ _ -> 1)
~get_neighbors:(fun room ->
(Option.get_exn_or "sad" @@ Hashtbl.get rooms room).tunnels)
match Hashtbl.get rooms room with
| Some (_, tunnels) -> tunnels
| None -> failwith "sad")
~start_nodes:[ from ]
%> Option.get_exn_or "big sad")
in
@ -60,14 +64,12 @@ let parse_valves lines =
valves, get_dist
;;
module StrSet = Set.Make (String)
let traverse
get_dist
valves
(valves : (string * valve) list)
?(current = "AA")
?(released = 0)
?(opened = StrSet.empty)
?(opened = 0)
time
=
let states = Hashtbl.create 1000 in
@ -80,15 +82,15 @@ let traverse
| _ -> Some released)
~k:opened;
valves
|> List.iter (fun (name, flow) ->
if not @@ StrSet.mem name opened
|> List.iter (fun (name, { flow; mask }) ->
if Int.(mask land opened) = 0
then (
let dist = get_dist current name in
let time' = time - (dist + 1) in
if time' > 0
then (
let released' = released + (flow * time') in
let opened' = StrSet.add name opened in
let opened' = Int.(mask lor opened) in
traverse' name time' released' opened')))
in
traverse' current time released opened;
@ -120,7 +122,7 @@ let solve_part_2 lines =
let states = traverse get_dist valves 26 in
states
|> Iter.filter_map (fun (a_set, a_score) ->
match Iter.find_pred (fun (b_set, _) -> StrSet.disjoint a_set b_set) states with
match Iter.find_pred (fun (b_set, _) -> Int.(a_set land b_set) = 0) states with
| Some (_, b_score) -> Some (a_score + b_score)
| None -> None)
|> Iter.fold Int.max 0
@ -131,7 +133,7 @@ let%expect_test "Day 16.2 example" =
[%expect {| 1707 |}]
;;
(* down to 1.5 seconds! bitmasks ???? *)
(* god damn bitmasks are fast, 30x speedup from StrSet! down to 50ms now *)
let%expect_test "Day 16.2" =
Printf.printf "%i\n" @@ solve_part_2 @@ Utils.lines_of_input 16;
[%expect {| 2811 |}]