Skip to content

Commit

Permalink
Fix stereo1D test
Browse files Browse the repository at this point in the history
  • Loading branch information
duembgen committed Mar 13, 2024
1 parent 3dc99b3 commit 92cd04f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions _scripts/run_stereo_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,19 +222,19 @@ def run_stereo_1d():

print("theta gt:", lifter.theta)

theta_hat, info_local = lifter.local_solver(t_init=lifter.theta, y=y)
print("theta global:", theta_hat, "cost:", info_local["cost"])
theta_hat, info_local, cost_local = lifter.local_solver(t_init=lifter.theta, y=y)
print("theta global:", theta_hat, "cost:", cost_local)

# sanity check
x_hat = lifter.get_x(theta=theta_hat)
assert abs(x_hat.T @ Q @ x_hat - info_local["cost"]) / info_local["cost"] < 1e-10
assert abs(x_hat.T @ Q @ x_hat - cost_local) / cost_local < 1e-10
for A_list, label in zip([A_known, A_all], ["known", "learned"]):
lifter.test_constraints(A_list)
Constraints = [(lifter.get_A0(), 1.0)] + [(Ai, 0.0) for Ai in A_list]
X, info_sdp = solve_sdp_cvxpy(Q=Q, Constraints=Constraints, verbose=False)
x_round, info_rank = rank_project(X, 1)
error = abs(X[0, 1] - theta_hat) / theta_hat
RDG = abs(info_local["cost"] - info_sdp["cost"])
RDG = abs(cost_local - info_sdp["cost"]) / cost_local
print(
f"{label}, theta sdp:{X[0,1]}, theta rounded:{x_round[1]} EVR:{info_rank['EVR']}, cost:{info_sdp['cost']}"
)
Expand Down
4 changes: 2 additions & 2 deletions _test/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_solvers(n_seeds=1, noise=0.0):
# test that we stay at real solution when initializing at it
theta_gt = lifter.get_vec_around_gt(delta=0)
try:
theta_hat, msg, cost_solver = lifter.local_solver(theta_gt, y)
theta_hat, info, cost_solver = lifter.local_solver(theta_gt, y)
print("local solution:", theta_hat, f"cost: {cost_solver:.4e}")
print("ground truth: ", theta_gt)
except NotImplementedError:
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_solvers(n_seeds=1, noise=0.0):

# test that we converge to real solution when initializing around it
theta_0 = lifter.get_vec_around_gt(delta=NOISE)
theta_hat, msg, cost_solver = lifter.local_solver(theta_0, y)
theta_hat, info, cost_solver = lifter.local_solver(theta_0, y)

print("init: ", theta_0)
print("local solution:", theta_hat, f"cost: {cost_solver:.4e}")
Expand Down
6 changes: 3 additions & 3 deletions lifters/stereo1d_lifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ def local_solver(
if np.abs(dx) < eps:
msg = f"converged dx after {i} it"
info = {"msg": msg, "cost": self.get_cost(x_op, y)}
return x_op, info
return x_op, info, info["cost"]
else:
msg = f"converged in du after {i} it"
info = {"msg": msg, "cost": self.get_cost(x_op, y)}
return x_op, info
return None, {"msg": "didn't converge", "cost": None}
return x_op, info, info["cost"]
return None, {"msg": "didn't converge", "cost": None}, None

def __repr__(self):
return f"stereo1d_{self.param_level}"

0 comments on commit 92cd04f

Please sign in to comment.