From 0f05c02f567b406694c6527ae6f605ef8a461994 Mon Sep 17 00:00:00 2001 From: Raktim Mukhopadhyay Date: Tue, 4 Jun 2024 01:36:46 -0400 Subject: [PATCH] Update test_pkbc.py --- tests/test_pkbc.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_pkbc.py b/tests/test_pkbc.py index bd67736..4f1590f 100644 --- a/tests/test_pkbc.py +++ b/tests/test_pkbc.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd import plotly.graph_objects as go +import matplotlib.pyplot as plt from QuadratiK.spherical_clustering import PKBC, PKBD @@ -67,6 +68,7 @@ def test_pkbc(self): ) self.assertIsInstance(pkbd_cluster_fit_numpy.plot(3), type(go.Figure())) + self.assertIsInstance(pkbd_cluster_fit_numpy.plot(3, y_true), type(go.Figure())) self.assertIsInstance(pkbd_cluster_fit_numpy.summary(), str) with self.assertRaises(Exception): @@ -90,6 +92,18 @@ def test_pkbc(self): with self.assertRaises(Exception): PKBC(num_clust=3, stopping_rule="some").fit(data) + with self.assertRaises(Exception): + PKBC(num_clust=pd.DataFrame(3)).fit(data) + with self.assertRaises(ValueError): X = np.random.randn(10, 2) pkbd_cluster_fit_numpy.predict(X, 3) + + pkbd = PKBD() + x1 = pkbd.rpkb(100, np.array([1, 0]), 0.8, "rejvmf", random_state=42) + x2 = pkbd.rpkb(100, np.array([0, 1]), 0.8, "rejacg", random_state=42) + data_two = np.concatenate((x1, x2), axis=0) + pkbd_cluster_two_dim = PKBC(num_clust=3, random_state=42).fit(data_two) + y_true = pd.DataFrame(np.repeat(np.arange(1, 3), repeats=100)) + self.assertIsInstance(pkbd_cluster_two_dim.plot(3), type(plt.figure())) + self.assertIsInstance(pkbd_cluster_two_dim.plot(3, y_true), type(plt.figure()))