50 #include <ilqgames/examples/minimally_invasive_receding_horizon_simulator.h> 51 #include <ilqgames/solver/ilq_solver.h> 52 #include <ilqgames/solver/problem.h> 53 #include <ilqgames/solver/solution_splicer.h> 54 #include <ilqgames/utils/solver_log.h> 55 #include <ilqgames/utils/strategy.h> 56 #include <ilqgames/utils/types.h> 58 #include <glog/logging.h> 66 using clock = std::chrono::system_clock;
68 std::vector<ActiveProblem> MinimallyInvasiveRecedingHorizonSimulator(
69 Time final_time, Time planner_runtime, GameSolver* original,
71 std::vector<std::shared_ptr<const SolverLog>>* original_logs,
72 std::vector<std::shared_ptr<const SolverLog>>* safety_logs) {
73 CHECK_NOTNULL(original);
74 CHECK_NOTNULL(safety);
75 CHECK_NOTNULL(original_logs);
76 CHECK_NOTNULL(safety_logs);
79 CHECK(original->GetProblem().InitialState().isApprox(
80 safety->GetProblem().InitialState(), constants::kSmallNumber));
81 CHECK_NEAR(original->GetProblem().InitialTime(),
82 safety->GetProblem().InitialTime(), constants::kSmallNumber);
86 const auto& dynamics = *original->GetProblem().Dynamics();
87 const auto& safety_dynamics = *safety->GetProblem().Dynamics();
88 CHECK(
typeid(dynamics) ==
typeid(safety_dynamics));
91 original_logs->clear();
97 auto solver_call_time = clock::now();
99 original_logs->push_back(original->Solve(&success));
102 std::chrono::duration<Time>(clock::now() - solver_call_time).count();
103 VLOG(1) <<
"Solved initial original problem in " << elapsed_time
104 <<
" seconds, with " << original_logs->back()->NumIterates()
107 solver_call_time = clock::now();
108 safety_logs->push_back(safety->Solve(&success));
111 std::chrono::duration<Time>(clock::now() - solver_call_time).count();
112 VLOG(1) <<
"Solved initial safety problem in " << elapsed_time
113 <<
" seconds, with " << safety_logs->back()->NumIterates()
118 SolutionSplicer splicer(*original_logs->front());
119 std::vector<ActiveProblem> active_problem = {ActiveProblem::ORIGINAL};
123 VectorXf x(original->GetProblem().InitialState());
124 Time t = original->GetProblem().InitialTime();
129 constexpr Time kExtraTime = 0.25;
132 if (t >= final_time ||
133 !splicer.ContainsTime(t + planner_runtime +
137 x = dynamics.Integrate(t - kExtraTime, t, x,
138 splicer.CurrentOperatingPoint(),
139 splicer.CurrentStrategies());
142 const bool current_active_problem_flag = active_problem.back();
143 auto current_active_problem =
144 (current_active_problem_flag == ActiveProblem::ORIGINAL) ? original
148 original->GetProblem().OverwriteSolution(splicer.CurrentOperatingPoint(),
149 splicer.CurrentStrategies());
150 safety->GetProblem().OverwriteSolution(splicer.CurrentOperatingPoint(),
151 splicer.CurrentStrategies());
154 original->GetProblem().ResetInitialState(
155 current_active_problem->GetProblem().InitialState());
156 safety->GetProblem().ResetInitialState(
157 current_active_problem->GetProblem().InitialState());
161 original->GetProblem().SetUpNextRecedingHorizon(x, t, planner_runtime);
162 safety->GetProblem().SetUpNextRecedingHorizon(x, t, planner_runtime);
164 solver_call_time = clock::now();
165 original_logs->push_back(original->Solve(&success, planner_runtime));
166 const Time original_elapsed_time =
167 std::chrono::duration<Time>(clock::now() - solver_call_time).count();
169 CHECK_LE(original_elapsed_time, planner_runtime);
170 VLOG(1) <<
"t = " << t <<
": Solved warm-started original problem in " 171 << original_elapsed_time <<
" seconds.";
173 solver_call_time = clock::now();
174 safety_logs->push_back(safety->Solve(&success, planner_runtime));
175 const Time safety_elapsed_time =
176 std::chrono::duration<Time>(clock::now() - solver_call_time).count();
178 CHECK_LE(safety_elapsed_time, planner_runtime);
179 VLOG(1) <<
"t = " << t <<
": Solved warm-started safety problem in " 180 << safety_elapsed_time <<
" seconds.";
183 elapsed_time = std::max(original_elapsed_time, safety_elapsed_time);
185 if (t >= final_time || !splicer.ContainsTime(t))
break;
188 x = dynamics.Integrate(t - elapsed_time, t, x,
189 splicer.CurrentOperatingPoint(),
190 splicer.CurrentStrategies());
193 if (!original_logs->back()->WasConverged())
194 VLOG(2) <<
"Original planner was not converged.";
195 if (!safety_logs->back()->WasConverged())
196 VLOG(2) <<
"Safety planner was not converged.";
201 constexpr
float kSafetyThreshold = -1.0;
202 const float p1_safety_cost = safety_logs->back()->TotalCosts().front();
204 if (p1_safety_cost > kSafetyThreshold ||
205 (safety_logs->back()->WasConverged() &&
206 !original_logs->back()->WasConverged())) {
207 active_problem.push_back(ActiveProblem::SAFETY);
208 splicer.Splice(*safety_logs->back());
209 VLOG(2) <<
"Using safety controller.";
211 active_problem.push_back(ActiveProblem::ORIGINAL);
212 if (original_logs->back()->WasConverged())
213 splicer.Splice(*original_logs->back());
217 return active_problem;