diff --git a/_scripts/run_stereo_study.py b/_scripts/run_stereo_study.py index 2b8aa29..90991b9 100644 --- a/_scripts/run_stereo_study.py +++ b/_scripts/run_stereo_study.py @@ -222,19 +222,19 @@ def run_stereo_1d(): print("theta gt:", lifter.theta) - theta_hat, info_local = lifter.local_solver(t_init=lifter.theta, y=y) - print("theta global:", theta_hat, "cost:", info_local["cost"]) + theta_hat, info_local, cost_local = lifter.local_solver(t_init=lifter.theta, y=y) + print("theta global:", theta_hat, "cost:", cost_local) # sanity check x_hat = lifter.get_x(theta=theta_hat) - assert abs(x_hat.T @ Q @ x_hat - info_local["cost"]) / info_local["cost"] < 1e-10 + assert abs(x_hat.T @ Q @ x_hat - cost_local) / cost_local < 1e-10 for A_list, label in zip([A_known, A_all], ["known", "learned"]): lifter.test_constraints(A_list) Constraints = [(lifter.get_A0(), 1.0)] + [(Ai, 0.0) for Ai in A_list] X, info_sdp = solve_sdp_cvxpy(Q=Q, Constraints=Constraints, verbose=False) x_round, info_rank = rank_project(X, 1) error = abs(X[0, 1] - theta_hat) / theta_hat - RDG = abs(info_local["cost"] - info_sdp["cost"]) + RDG = abs(cost_local - info_sdp["cost"]) / cost_local print( f"{label}, theta sdp:{X[0,1]}, theta rounded:{x_round[1]} EVR:{info_rank['EVR']}, cost:{info_sdp['cost']}" ) diff --git a/_test/test_solvers.py b/_test/test_solvers.py index bb2c2cd..e8d019e 100644 --- a/_test/test_solvers.py +++ b/_test/test_solvers.py @@ -167,7 +167,7 @@ def test_solvers(n_seeds=1, noise=0.0): # test that we stay at real solution when initializing at it theta_gt = lifter.get_vec_around_gt(delta=0) try: - theta_hat, msg, cost_solver = lifter.local_solver(theta_gt, y) + theta_hat, info, cost_solver = lifter.local_solver(theta_gt, y) print("local solution:", theta_hat, f"cost: {cost_solver:.4e}") print("ground truth: ", theta_gt) except NotImplementedError: @@ -195,7 +195,7 @@ def test_solvers(n_seeds=1, noise=0.0): # test that we converge to real solution when initializing around it theta_0 = lifter.get_vec_around_gt(delta=NOISE) - theta_hat, msg, cost_solver = lifter.local_solver(theta_0, y) + theta_hat, info, cost_solver = lifter.local_solver(theta_0, y) print("init: ", theta_0) print("local solution:", theta_hat, f"cost: {cost_solver:.4e}") diff --git a/lifters/stereo1d_lifter.py b/lifters/stereo1d_lifter.py index d6c4e7d..64fd030 100644 --- a/lifters/stereo1d_lifter.py +++ b/lifters/stereo1d_lifter.py @@ -186,12 +186,12 @@ def local_solver( if np.abs(dx) < eps: msg = f"converged dx after {i} it" info = {"msg": msg, "cost": self.get_cost(x_op, y)} - return x_op, info + return x_op, info, info["cost"] else: msg = f"converged in du after {i} it" info = {"msg": msg, "cost": self.get_cost(x_op, y)} - return x_op, info - return None, {"msg": "didn't converge", "cost": None} + return x_op, info, info["cost"] + return None, {"msg": "didn't converge", "cost": None}, None def __repr__(self): return f"stereo1d_{self.param_level}"