From 829cc571e26cdc3f22576442c3842c2d4c67be54 Mon Sep 17 00:00:00 2001 From: Julien Veyssier Date: Wed, 20 Nov 2024 18:20:41 +0100 Subject: [PATCH] adjust title generation using the same logic as message generation Signed-off-by: Julien Veyssier --- appinfo/routes.php | 2 +- lib/Controller/ChattyLLMController.php | 37 +++++--- lib/Db/ChattyLLM/Session.php | 2 +- lib/Db/ChattyLLM/SessionMapper.php | 21 +++++ lib/Listener/ChattyLLMTaskListener.php | 41 +++++---- .../ChattyLLM/ChattyLLMInputForm.vue | 90 +++++++++++++------ 6 files changed, 133 insertions(+), 60 deletions(-) diff --git a/appinfo/routes.php b/appinfo/routes.php index 3488d40f..7588993c 100644 --- a/appinfo/routes.php +++ b/appinfo/routes.php @@ -24,7 +24,7 @@ ['name' => 'chattyLLM#getMessages', 'url' => '/chat/messages', 'verb' => 'GET'], ['name' => 'chattyLLM#generateForSession', 'url' => '/chat/generate', 'verb' => 'GET'], ['name' => 'chattyLLM#regenerateForSession', 'url' => '/chat/regenerate', 'verb' => 'GET'], - ['name' => 'chattyLLM#checkMessageGenerationSession', 'url' => '/chat/check_session', 'verb' => 'GET'], + ['name' => 'chattyLLM#checkSession', 'url' => '/chat/check_session', 'verb' => 'GET'], ['name' => 'chattyLLM#checkMessageGenerationTask', 'url' => '/chat/check_generation', 'verb' => 'GET'], ['name' => 'chattyLLM#generateTitle', 'url' => '/chat/generate_title', 'verb' => 'GET'], ['name' => 'chattyLLM#checkTitleGenerationTask', 'url' => '/chat/check_title_generation', 'verb' => 'GET'], diff --git a/lib/Controller/ChattyLLMController.php b/lib/Controller/ChattyLLMController.php index 071211dd..e7e87f79 100644 --- a/lib/Controller/ChattyLLMController.php +++ b/lib/Controller/ChattyLLMController.php @@ -403,7 +403,7 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes * @throws \OCP\DB\Exception */ #[NoAdminRequired] - public function checkMessageGenerationSession(int $sessionId): JSONResponse { + public function checkSession(int $sessionId): JSONResponse { if ($this->userId === null) { return new JSONResponse(['error' => $this->l10n->t('User not logged in')], Http::STATUS_UNAUTHORIZED); } @@ -414,22 +414,32 @@ public function checkMessageGenerationSession(int $sessionId): JSONResponse { } try { - $tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId); + $messageTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId); + $titleTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-title:' . $sessionId); } catch (\OCP\TaskProcessing\Exception\Exception $e) { return new JSONResponse(['error' => 'task_query_failed'], Http::STATUS_BAD_REQUEST); } - $tasks = array_filter($tasks, static function (Task $task) { + $messageTasks = array_filter($messageTasks, static function (Task $task) { return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; }); - if (empty($tasks)) { - return new JSONResponse([ - 'taskId' => null, - ]); - } - $task = array_pop($tasks); - return new JSONResponse([ - 'taskId' => $task->getId(), - ]); + $titleTasks = array_filter($titleTasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + $session = $this->sessionMapper->getUserSession($this->userId, $sessionId); + $responseData = [ + 'messageTaskId' => null, + 'titleTaskId' => null, + 'sessionTitle' => $session->getTitle(), + ]; + if (!empty($messageTasks)) { + $task = array_pop($messageTasks); + $responseData['messageTaskId'] = $task->getId(); + } + if (!empty($titleTasks)) { + $task = array_pop($titleTasks); + $responseData['titleTaskId'] = $task->getId(); + } + return new JSONResponse($responseData); } /** @@ -523,8 +533,7 @@ public function checkTitleGenerationTask(int $taskId, int $sessionId): JSONRespo $title = str_replace('"', '', $title); $title = explode(PHP_EOL, $title)[0]; $title = trim($title); - - $this->sessionMapper->updateSessionTitle($this->userId, $sessionId, $title); + // do not write the title here since it's done in the listener return new JSONResponse(['result' => $title]); } catch (\OCP\DB\Exception $e) { diff --git a/lib/Db/ChattyLLM/Session.php b/lib/Db/ChattyLLM/Session.php index a6c19a1a..14a5e7c0 100644 --- a/lib/Db/ChattyLLM/Session.php +++ b/lib/Db/ChattyLLM/Session.php @@ -31,7 +31,7 @@ /** * @method \string getUserId() * @method \void setUserId(string $userId) - * @method \?string getTitle() + * @method \string|null getTitle() * @method \void setTitle(?string $title) * @method \int|null getTimestamp() * @method \void setTimestamp(?int $timestamp) diff --git a/lib/Db/ChattyLLM/SessionMapper.php b/lib/Db/ChattyLLM/SessionMapper.php index 60418d03..276fca67 100644 --- a/lib/Db/ChattyLLM/SessionMapper.php +++ b/lib/Db/ChattyLLM/SessionMapper.php @@ -25,7 +25,10 @@ namespace OCA\Assistant\Db\ChattyLLM; +use OCP\AppFramework\Db\DoesNotExistException; +use OCP\AppFramework\Db\MultipleObjectsReturnedException; use OCP\AppFramework\Db\QBMapper; +use OCP\DB\Exception; use OCP\DB\QueryBuilder\IQueryBuilder; use OCP\IDBConnection; @@ -59,6 +62,24 @@ public function exists(string $userId, int $sessionId): bool { } } + /** + * @param string $userId + * @param int $sessionId + * @return Session + * @throws DoesNotExistException + * @throws MultipleObjectsReturnedException + * @throws Exception + */ + public function getUserSession(string $userId, int $sessionId): Session { + $qb = $this->db->getQueryBuilder(); + $qb->select('id', 'title', 'timestamp') + ->from($this->getTableName()) + ->where($qb->expr()->eq('id', $qb->createPositionalParameter($sessionId, IQueryBuilder::PARAM_INT))) + ->andWhere($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId, IQueryBuilder::PARAM_STR))); + + return $this->findEntity($qb); + } + /** * @param string $userId * @return array diff --git a/lib/Listener/ChattyLLMTaskListener.php b/lib/Listener/ChattyLLMTaskListener.php index a5fb1ea8..b79deade 100644 --- a/lib/Listener/ChattyLLMTaskListener.php +++ b/lib/Listener/ChattyLLMTaskListener.php @@ -7,6 +7,7 @@ use OCA\Assistant\AppInfo\Application; use OCA\Assistant\Db\ChattyLLM\Message; use OCA\Assistant\Db\ChattyLLM\MessageMapper; +use OCA\Assistant\Db\ChattyLLM\SessionMapper; use OCP\EventDispatcher\Event; use OCP\EventDispatcher\IEventListener; use OCP\TaskProcessing\Events\TaskSuccessfulEvent; @@ -19,6 +20,7 @@ class ChattyLLMTaskListener implements IEventListener { public function __construct( private MessageMapper $messageMapper, + private SessionMapper $sessionMapper, private LoggerInterface $logger, ) { } @@ -31,23 +33,32 @@ public function handle(Event $event): void { $task = $event->getTask(); $customId = $task->getCustomId(); $appId = $task->getAppId(); - if ($appId !== (Application::APP_ID . ':chatty-llm') - || $customId === null - || !preg_match('/^chatty-llm:(\d+)$/', $customId, $matches) - ) { + + if ($customId === null || $appId !== (Application::APP_ID . ':chatty-llm')) { return; } - $sessionId = (int)$matches[1]; - - $message = new Message(); - $message->setSessionId($sessionId); - $message->setRole('assistant'); - $message->setContent(trim($task->getOutput()['output'] ?? '')); - $message->setTimestamp(time()); - try { - $this->messageMapper->insert($message); - } catch (\OCP\DB\Exception $e) { - $this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]); + + // title generation + if (preg_match('/^chatty-title:(\d+)$/', $customId, $matches)) { + $sessionId = (int)$matches[1]; + $title = trim($task->getOutput()['output'] ?? ''); + $this->sessionMapper->updateSessionTitle($task->getUserId(), $sessionId, $title); + } + + // message generation + if (preg_match('/^chatty-llm:(\d+)$/', $customId, $matches)) { + $sessionId = (int)$matches[1]; + + $message = new Message(); + $message->setSessionId($sessionId); + $message->setRole('assistant'); + $message->setContent(trim($task->getOutput()['output'] ?? '')); + $message->setTimestamp(time()); + try { + $this->messageMapper->insert($message); + } catch (\OCP\DB\Exception $e) { + $this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]); + } } } } diff --git a/src/components/ChattyLLM/ChattyLLMInputForm.vue b/src/components/ChattyLLM/ChattyLLMInputForm.vue index 6dd099e5..6373a041 100644 --- a/src/components/ChattyLLM/ChattyLLMInputForm.vue +++ b/src/components/ChattyLLM/ChattyLLMInputForm.vue @@ -220,6 +220,7 @@ export default { watch: { async active() { this.loading.llmGeneration = false + this.loading.titleGeneration = false this.allMessagesLoaded = false this.chatContent = '' this.msgCursor = 0 @@ -238,22 +239,45 @@ export default { } // start polling in case a message is currently being generated try { - const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId: this.active.id } }) - console.debug('check session response:', checkSessionResponse) - if (checkSessionResponse.data.taskId === null) { - return + const sessionId = this.active.id + const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId } }) + if (checkSessionResponse.data?.sessionTitle && checkSessionResponse.data?.sessionTitle !== this.active.title) { + this.active.title = checkSessionResponse.data?.sessionTitle + console.debug('update session title with check result') } - try { - this.loading.llmGeneration = true - const message = await this.pollGenerationTask(checkSessionResponse.data.taskId, this.active.id) - console.debug('checkTaskPolling result:', message) - this.messages.push(message) - this.scrollToBottom() - } catch (error) { - console.error('checkGenerationTask error:', error) - showError(t('assistant', 'Error generating a response')) - } finally { - this.loading.llmGeneration = false + console.debug('check session response:', checkSessionResponse.data) + if (checkSessionResponse.data.messageTaskId !== null) { + try { + this.loading.llmGeneration = true + const message = await this.pollGenerationTask(checkSessionResponse.data.messageTaskId, sessionId) + console.debug('checkTaskPolling result:', message) + this.messages.push(message) + this.scrollToBottom() + } catch (error) { + console.error('checkGenerationTask error:', error) + showError(t('assistant', 'Error generating a response')) + } finally { + this.loading.llmGeneration = false + } + } else if (checkSessionResponse.data.titleTaskId !== null) { + try { + this.loading.titleGeneration = true + const titleResponse = await this.pollTitleGenerationTask(checkSessionResponse.data.titleTaskId, sessionId) + console.debug('checkTaskPolling result:', titleResponse) + if (titleResponse?.data?.result == null) { + throw new Error('No title generated, response:', titleResponse) + } + + const session = this.sessions.find(s => s.id === sessionId) + if (session) { + session.title = titleResponse?.data?.result + } + } catch (error) { + console.error('onCheckSessionTitle error:', error) + showError(error?.response?.data?.error ?? t('assistant', 'Error getting the generated title for the conversation')) + } finally { + this.loading.titleGeneration = false + } } } catch (error) { console.error('check session error:', error) @@ -367,18 +391,17 @@ export default { async onGenerateSessionTitle() { try { this.loading.titleGeneration = true - const response = await axios.get(getChatURL('/generate_title'), { params: { sessionId: this.active.id } }) - const titleResponse = await this.pollTitleGenerationTask(response.data.taskId) + const sessionId = this.active.id + const response = await axios.get(getChatURL('/generate_title'), { params: { sessionId } }) + const titleResponse = await this.pollTitleGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', titleResponse) if (titleResponse?.data?.result == null) { throw new Error('No title generated, response:', response) } - for (const session of this.sessions) { - if (session.id === this.active.id) { - session.title = titleResponse?.data?.result - break - } + const session = this.sessions.find(s => s.id === sessionId) + if (session) { + session.title = titleResponse?.data?.result } } catch (error) { console.error('onGenerateSessionTitle error:', error) @@ -592,12 +615,12 @@ export default { if (sessionId === this.active.id) { resolve(response.data) } else { - console.debug('Ignoring received a message for session ' + sessionId + ' that is not selected anymore') + console.debug('Ignoring received message for session ' + sessionId + ' that is not selected anymore') // should we reject here? } }).catch(error => { if (sessionId !== this.active.id) { - console.debug('Stop polling session ' + sessionId + ' because it is not selected anymore') + console.debug('Stop polling messages for session ' + sessionId + ' because it is not selected anymore') clearInterval(this.pollMessageGenerationTimerId) } // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) @@ -606,30 +629,39 @@ export default { clearInterval(this.pollMessageGenerationTimerId) reject(new Error('Message generation task check failed')) } else { - console.debug('checkTaskPolling, task is still scheduled or running', error) + console.debug('checkTaskPolling, task is still scheduled or running') } }) }, 2000) }) }, - async pollTitleGenerationTask(taskId) { + async pollTitleGenerationTask(taskId, sessionId) { return new Promise((resolve, reject) => { this.pollTitleGenerationTimerId = setInterval(() => { axios.get( getChatURL('/check_title_generation'), - { params: { taskId, sessionId: this.active.id } }, + { params: { taskId, sessionId } }, ).then(response => { + if (sessionId === this.active.id) { + resolve(response) + } else { + console.debug('Ignoring received title for session ' + sessionId + ' that is not selected anymore') + // should we reject here? + } clearInterval(this.pollTitleGenerationTimerId) - resolve(response) }).catch(error => { + if (sessionId !== this.active.id) { + console.debug('Stop polling title for session ' + sessionId + ' because it is not selected anymore') + clearInterval(this.pollTitleGenerationTimerId) + } // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) if (error.response?.status !== 417) { console.error('checkTaskPolling error', error) clearInterval(this.pollTitleGenerationTimerId) reject(new Error('Title generation task check failed')) } else { - console.debug('checkTaskPolling, task is still scheduled or running', error) + console.debug('checkTaskPolling, task is still scheduled or running') } }) }, 2000)