30x speedup on 16.2 from using bitmasks as a set!!
This commit is contained in:
parent
4e57ad4072
commit
94f4badcc4
42
src/day16.ml
42
src/day16.ml
|
|
@ -15,12 +15,10 @@ Valve JJ has flow rate=21; tunnel leads to valve II
|
|||
|}
|
||||
;;
|
||||
|
||||
type room =
|
||||
{ name : string
|
||||
; flow : int
|
||||
; tunnels : string list
|
||||
type valve =
|
||||
{ flow : int
|
||||
; mask : int
|
||||
}
|
||||
[@@deriving show { with_path = false }]
|
||||
|
||||
let parse_valves lines =
|
||||
let rooms = Hashtbl.create 100 in
|
||||
|
|
@ -29,22 +27,28 @@ let parse_valves lines =
|
|||
let int_matches = List.map int_of_string Re.(matches (Pcre.regexp {|\d+|}) line) in
|
||||
let valve_matches = Re.(matches (Pcre.regexp {|[A-Z]{2}|}) line) in
|
||||
match int_matches, valve_matches with
|
||||
| [ flow ], name :: tunnels -> Hashtbl.add rooms name { name; flow; tunnels }
|
||||
| [ flow ], name :: tunnels -> Hashtbl.add rooms name (flow, tunnels)
|
||||
| _ -> failwith "parse error");
|
||||
let valves =
|
||||
Hashtbl.to_list rooms
|
||||
|> List.filter_map (fun (_, { name; flow; tunnels = _ }) ->
|
||||
if flow > 0 then Some (name, flow) else None)
|
||||
|> List.filter (fun (_, (flow, _)) -> flow > 0)
|
||||
|> List.mapi (fun i (name, (flow, _)) ->
|
||||
let mask = Int.(1 lsl i) in
|
||||
name, { flow; mask })
|
||||
in
|
||||
let distances = Hashtbl.create 100 in
|
||||
valves @ [ "AA", 0 ]
|
||||
|> List.iter (fun (from, _) ->
|
||||
valves
|
||||
|> List.map (fun (valve, _) -> valve)
|
||||
|> List.cons "AA"
|
||||
|> List.iter (fun from ->
|
||||
let dist_of =
|
||||
Fun.(
|
||||
Utils.djikstra_hash
|
||||
~get_weight:(fun _ _ -> 1)
|
||||
~get_neighbors:(fun room ->
|
||||
(Option.get_exn_or "sad" @@ Hashtbl.get rooms room).tunnels)
|
||||
match Hashtbl.get rooms room with
|
||||
| Some (_, tunnels) -> tunnels
|
||||
| None -> failwith "sad")
|
||||
~start_nodes:[ from ]
|
||||
%> Option.get_exn_or "big sad")
|
||||
in
|
||||
|
|
@ -60,14 +64,12 @@ let parse_valves lines =
|
|||
valves, get_dist
|
||||
;;
|
||||
|
||||
module StrSet = Set.Make (String)
|
||||
|
||||
let traverse
|
||||
get_dist
|
||||
valves
|
||||
(valves : (string * valve) list)
|
||||
?(current = "AA")
|
||||
?(released = 0)
|
||||
?(opened = StrSet.empty)
|
||||
?(opened = 0)
|
||||
time
|
||||
=
|
||||
let states = Hashtbl.create 1000 in
|
||||
|
|
@ -80,15 +82,15 @@ let traverse
|
|||
| _ -> Some released)
|
||||
~k:opened;
|
||||
valves
|
||||
|> List.iter (fun (name, flow) ->
|
||||
if not @@ StrSet.mem name opened
|
||||
|> List.iter (fun (name, { flow; mask }) ->
|
||||
if Int.(mask land opened) = 0
|
||||
then (
|
||||
let dist = get_dist current name in
|
||||
let time' = time - (dist + 1) in
|
||||
if time' > 0
|
||||
then (
|
||||
let released' = released + (flow * time') in
|
||||
let opened' = StrSet.add name opened in
|
||||
let opened' = Int.(mask lor opened) in
|
||||
traverse' name time' released' opened')))
|
||||
in
|
||||
traverse' current time released opened;
|
||||
|
|
@ -120,7 +122,7 @@ let solve_part_2 lines =
|
|||
let states = traverse get_dist valves 26 in
|
||||
states
|
||||
|> Iter.filter_map (fun (a_set, a_score) ->
|
||||
match Iter.find_pred (fun (b_set, _) -> StrSet.disjoint a_set b_set) states with
|
||||
match Iter.find_pred (fun (b_set, _) -> Int.(a_set land b_set) = 0) states with
|
||||
| Some (_, b_score) -> Some (a_score + b_score)
|
||||
| None -> None)
|
||||
|> Iter.fold Int.max 0
|
||||
|
|
@ -131,7 +133,7 @@ let%expect_test "Day 16.2 example" =
|
|||
[%expect {| 1707 |}]
|
||||
;;
|
||||
|
||||
(* down to 1.5 seconds! bitmasks ???? *)
|
||||
(* god damn bitmasks are fast, 30x speedup from StrSet! down to 50ms now *)
|
||||
let%expect_test "Day 16.2" =
|
||||
Printf.printf "%i\n" @@ solve_part_2 @@ Utils.lines_of_input 16;
|
||||
[%expect {| 2811 |}]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user