Skip to content

Commit

Permalink
Turned Gumbel off as default (mostly for pytest right now)
Browse files Browse the repository at this point in the history
  • Loading branch information
lisusdaniil committed Jul 9, 2024
1 parent b16a577 commit 37f14c4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
2 changes: 1 addition & 1 deletion config/dICP_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dICP:


functionality:
gumbel: True # If true, use Gumbel-Softmax trick for nearest neighour
gumbel: False # If true, use Gumbel-Softmax trick for nearest neighour
# Not yet implemented
# svd: False # If true, use SVD to solve pt2pt problem, no effect for pt2pl

Expand Down
16 changes: 6 additions & 10 deletions tests/test_ICP.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
def max_iterations():
return 100

@pytest.fixture
def diff_tolerance():
return 0.3

@pytest.fixture
def tolerance():
return 1e-10
Expand All @@ -36,7 +32,7 @@ def target():
data_file_path = os.path.join(current_dir, 'data', 'points_map.npy')
return np.load(data_file_path)

def test_pt2pt_dICP(source, target, max_iterations, diff_tolerance, tolerance):
def test_pt2pt_dICP(source, target, max_iterations, tolerance):
"""
Test differentiable point-to-point ICP algorithm.
"""
Expand Down Expand Up @@ -67,10 +63,10 @@ def test_pt2pt_dICP(source, target, max_iterations, diff_tolerance, tolerance):

# Check that the transformation is correct
err_T = se3op.tran2vec(T_ts_true @ np.linalg.inv(T_ts_pred.detach().numpy()))
assert(np.linalg.norm(err_T) < diff_tolerance)
assert(np.linalg.norm(err_T) < tolerance)

# Check that the transformed source is close to target
assert np.allclose(source_transformed.detach().numpy(), target.detach().numpy(), atol=diff_tolerance)
assert np.allclose(source_transformed.detach().numpy(), target.detach().numpy(), atol=1e-5)

# Check that the gradient is not none
T_ts_pred.sum().backward()
Expand All @@ -81,7 +77,7 @@ def test_pt2pt_dICP(source, target, max_iterations, diff_tolerance, tolerance):
# Confirm gradient is not nan
assert torch.isnan(source.grad).any() == False and torch.isnan(target.grad).any() == False

def test_pt2pl_dICP(source, target, max_iterations, diff_tolerance, tolerance):
def test_pt2pl_dICP(source, target, max_iterations, tolerance):
"""
Test differentiable point-to-plane ICP algorithm.
"""
Expand Down Expand Up @@ -110,10 +106,10 @@ def test_pt2pl_dICP(source, target, max_iterations, diff_tolerance, tolerance):

# Check that the transformation is correct
err_T = se3op.tran2vec(T_ts_true @ np.linalg.inv(T_ts_pred.detach().numpy()))
assert(np.linalg.norm(err_T) < diff_tolerance)
assert(np.linalg.norm(err_T) < tolerance)

# Check that the transformed source is close to target
assert np.allclose(source_transformed.detach().numpy(), target[:,:3].detach().numpy(), atol=diff_tolerance)
assert np.allclose(source_transformed.detach().numpy(), target[:,:3].detach().numpy(), atol=1e-5)

# Check that the gradient is not none
T_ts_pred.sum().backward()
Expand Down

0 comments on commit 37f14c4

Please sign in to comment.