From 06201f6ce2536dd09bd5a74ea4b83ef02e297f4d Mon Sep 17 00:00:00 2001 From: Julien Veyssier Date: Wed, 20 Nov 2024 16:29:10 +0100 Subject: [PATCH 1/3] fix polling when switching sessions in the frontend, prevent scheduling multiple llm tasks for one session Signed-off-by: Julien Veyssier --- appinfo/routes.php | 1 + lib/AppInfo/Application.php | 2 + lib/Controller/ChattyLLMController.php | 76 +++++++++++++++++-- lib/Listener/ChattyLLMTaskListener.php | 53 +++++++++++++ .../ChattyLLM/ChattyLLMInputForm.vue | 63 +++++++++++---- tests/psalm-baseline.xml | 4 +- 6 files changed, 178 insertions(+), 21 deletions(-) create mode 100644 lib/Listener/ChattyLLMTaskListener.php diff --git a/appinfo/routes.php b/appinfo/routes.php index 19443cf5..3488d40f 100644 --- a/appinfo/routes.php +++ b/appinfo/routes.php @@ -24,6 +24,7 @@ ['name' => 'chattyLLM#getMessages', 'url' => '/chat/messages', 'verb' => 'GET'], ['name' => 'chattyLLM#generateForSession', 'url' => '/chat/generate', 'verb' => 'GET'], ['name' => 'chattyLLM#regenerateForSession', 'url' => '/chat/regenerate', 'verb' => 'GET'], + ['name' => 'chattyLLM#checkMessageGenerationSession', 'url' => '/chat/check_session', 'verb' => 'GET'], ['name' => 'chattyLLM#checkMessageGenerationTask', 'url' => '/chat/check_generation', 'verb' => 'GET'], ['name' => 'chattyLLM#generateTitle', 'url' => '/chat/generate_title', 'verb' => 'GET'], ['name' => 'chattyLLM#checkTitleGenerationTask', 'url' => '/chat/check_title_generation', 'verb' => 'GET'], diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index 8005fb48..c0a0ff3a 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -4,6 +4,7 @@ use OCA\Assistant\Capabilities; use OCA\Assistant\Listener\BeforeTemplateRenderedListener; +use OCA\Assistant\Listener\ChattyLLMTaskListener; use OCA\Assistant\Listener\CSPListener; use OCA\Assistant\Listener\FreePrompt\FreePromptReferenceListener; use OCA\Assistant\Listener\SpeechToText\SpeechToTextReferenceListener; @@ -55,6 +56,7 @@ public function register(IRegistrationContext $context): void { $context->registerEventListener(TaskSuccessfulEvent::class, TaskSuccessfulListener::class); $context->registerEventListener(TaskFailedEvent::class, TaskFailedListener::class); + $context->registerEventListener(TaskSuccessfulEvent::class, ChattyLLMTaskListener::class); $context->registerNotifierService(Notifier::class); diff --git a/lib/Controller/ChattyLLMController.php b/lib/Controller/ChattyLLMController.php index 32183246..071211dd 100644 --- a/lib/Controller/ChattyLLMController.php +++ b/lib/Controller/ChattyLLMController.php @@ -297,7 +297,11 @@ public function generateForSession(int $sessionId): JSONResponse { . PHP_EOL . 'assistant: '; - $taskId = $this->scheduleLLMTask($stichedPrompt); + try { + $taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId); + } catch (\Exception $e) { + return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST); + } return new JSONResponse(['taskId' => $taskId]); } @@ -374,7 +378,7 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes $message->setRole('assistant'); $message->setContent(trim($task->getOutput()['output'] ?? '')); $message->setTimestamp(time()); - $this->messageMapper->insert($message); + // do not insert here, it is done by the listener return new JSONResponse($message); } catch (\OCP\DB\Exception $e) { $this->logger->warning('Failed to add a chat message into DB', ['exception' => $e]); @@ -388,6 +392,46 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes return new JSONResponse(['error' => 'unknown_error', 'task_status' => $task->getstatus()], Http::STATUS_BAD_REQUEST); } + /** + * Check the status of a session + * + * Used by the frontend to determine if it should poll a generation task status. + * + * @param int $sessionId + * @return JSONResponse + * @throws \JsonException + * @throws \OCP\DB\Exception + */ + #[NoAdminRequired] + public function checkMessageGenerationSession(int $sessionId): JSONResponse { + if ($this->userId === null) { + return new JSONResponse(['error' => $this->l10n->t('User not logged in')], Http::STATUS_UNAUTHORIZED); + } + + $sessionExists = $this->sessionMapper->exists($this->userId, $sessionId); + if (!$sessionExists) { + return new JSONResponse(['error' => $this->l10n->t('Session not found')], Http::STATUS_NOT_FOUND); + } + + try { + $tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId); + } catch (\OCP\TaskProcessing\Exception\Exception $e) { + return new JSONResponse(['error' => 'task_query_failed'], Http::STATUS_BAD_REQUEST); + } + $tasks = array_filter($tasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + if (empty($tasks)) { + return new JSONResponse([ + 'taskId' => null, + ]); + } + $task = array_pop($tasks); + return new JSONResponse([ + 'taskId' => $task->getId(), + ]); + } + /** * Schedule a task to generate a title for the chat session * @@ -430,7 +474,11 @@ public function generateTitle(int $sessionId): JSONResponse { . PHP_EOL . PHP_EOL . $userInstructions; - $taskId = $this->scheduleLLMTask($stichedPrompt); + try { + $taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId, false); + } catch (\Exception $e) { + return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST); + } return new JSONResponse(['taskId' => $taskId]); } catch (\OCP\DB\Exception $e) { $this->logger->warning('Failed to generate a title for the chat session', ['exception' => $e]); @@ -525,14 +573,32 @@ private function getStichedMessages(int $sessionId): string { * Schedule the LLM task * * @param string $content + * @param int $sessionId + * @param bool $isMessage * @return int|null * @throws Exception * @throws PreConditionNotMetException * @throws UnauthorizedException * @throws ValidationException + * @throws \JsonException */ - private function scheduleLLMTask(string $content): ?int { - $task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId); + private function scheduleLLMTask(string $content, int $sessionId, bool $isMessage = true): ?int { + $customId = ($isMessage + ? 'chatty-llm:' + : 'chatty-title:') . $sessionId; + try { + $tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', $customId); + } catch (\OCP\TaskProcessing\Exception\Exception $e) { + throw new \Exception('task_query_failed'); + } + $tasks = array_filter($tasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + // prevent scheduling multiple llm tasks simultaneously for one session + if (!empty($tasks)) { + throw new \Exception('session_already_thinking'); + } + $task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId); $this->taskProcessingManager->scheduleTask($task); return $task->getId(); } diff --git a/lib/Listener/ChattyLLMTaskListener.php b/lib/Listener/ChattyLLMTaskListener.php new file mode 100644 index 00000000..a5fb1ea8 --- /dev/null +++ b/lib/Listener/ChattyLLMTaskListener.php @@ -0,0 +1,53 @@ + + */ +class ChattyLLMTaskListener implements IEventListener { + + public function __construct( + private MessageMapper $messageMapper, + private LoggerInterface $logger, + ) { + } + + public function handle(Event $event): void { + if (!($event instanceof TaskSuccessfulEvent)) { + return; + } + + $task = $event->getTask(); + $customId = $task->getCustomId(); + $appId = $task->getAppId(); + if ($appId !== (Application::APP_ID . ':chatty-llm') + || $customId === null + || !preg_match('/^chatty-llm:(\d+)$/', $customId, $matches) + ) { + return; + } + $sessionId = (int)$matches[1]; + + $message = new Message(); + $message->setSessionId($sessionId); + $message->setRole('assistant'); + $message->setContent(trim($task->getOutput()['output'] ?? '')); + $message->setTimestamp(time()); + try { + $this->messageMapper->insert($message); + } catch (\OCP\DB\Exception $e) { + $this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]); + } + } +} diff --git a/src/components/ChattyLLM/ChattyLLMInputForm.vue b/src/components/ChattyLLM/ChattyLLMInputForm.vue index 9c33964a..6dd099e5 100644 --- a/src/components/ChattyLLM/ChattyLLMInputForm.vue +++ b/src/components/ChattyLLM/ChattyLLMInputForm.vue @@ -219,6 +219,7 @@ export default { watch: { async active() { + this.loading.llmGeneration = false this.allMessagesLoaded = false this.chatContent = '' this.msgCursor = 0 @@ -226,13 +227,37 @@ export default { this.editingTitle = false this.$refs.inputComponent.focus() - if (this.active != null && !this.loading.newSession) { + if (this.active !== null && !this.loading.newSession) { await this.fetchMessages() this.scrollToBottom() } else { // when no active session or creating a new session this.allMessagesLoaded = true this.loading.newSession = false + return + } + // start polling in case a message is currently being generated + try { + const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId: this.active.id } }) + console.debug('check session response:', checkSessionResponse) + if (checkSessionResponse.data.taskId === null) { + return + } + try { + this.loading.llmGeneration = true + const message = await this.pollGenerationTask(checkSessionResponse.data.taskId, this.active.id) + console.debug('checkTaskPolling result:', message) + this.messages.push(message) + this.scrollToBottom() + } catch (error) { + console.error('checkGenerationTask error:', error) + showError(t('assistant', 'Error generating a response')) + } finally { + this.loading.llmGeneration = false + } + } catch (error) { + console.error('check session error:', error) + showError(t('assistant', 'Error checking if the session is thinking')) } }, }, @@ -328,7 +353,7 @@ export default { this.messages.push({ role, content, timestamp }) this.chatContent = '' this.scrollToBottom() - await this.newMessage(role, content, timestamp) + await this.newMessage(role, content, timestamp, this.active.id) }, onLoadOlderMessages() { @@ -466,13 +491,13 @@ export default { } }, - async newMessage(role, content, timestamp) { + async newMessage(role, content, timestamp, sessionId) { try { this.loading.newHumanMessage = true const firstHumanMessage = this.messages.length === 1 && this.messages[0].role === Roles.HUMAN const response = await axios.put(getChatURL('/new_message'), { - sessionId: this.active.id, + sessionId, role, content, timestamp, @@ -485,11 +510,11 @@ export default { this.messages[this.messages.length - 1] = response.data if (firstHumanMessage) { - const session = this.sessions.find((session) => session.id === this.active.id) + const session = this.sessions.find((session) => session.id === sessionId) session.title = content } - await this.runGenerationTask() + await this.runGenerationTask(sessionId) } catch (error) { this.loading.newHumanMessage = false console.error('newMessage error:', error) @@ -521,12 +546,12 @@ export default { } }, - async runGenerationTask() { + async runGenerationTask(sessionId) { try { this.loading.llmGeneration = true - const response = await axios.get(getChatURL('/generate'), { params: { sessionId: this.active.id } }) + const response = await axios.get(getChatURL('/generate'), { params: { sessionId } }) console.debug('scheduleGenerationTask response:', response) - const message = await this.pollGenerationTask(response.data.taskId) + const message = await this.pollGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', message) this.messages.push(message) this.scrollToBottom() @@ -540,10 +565,11 @@ export default { async runRegenerationTask(messageId) { try { + const sessionId = this.active.id this.loading.llmGeneration = true - const response = await axios.get(getChatURL('/regenerate'), { params: { messageId, sessionId: this.active.id } }) + const response = await axios.get(getChatURL('/regenerate'), { params: { messageId, sessionId } }) console.debug('scheduleRegenerationTask response:', response) - const message = await this.pollGenerationTask(response.data.taskId) + const message = await this.pollGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', message) this.messages[this.messages.length - 1] = message this.scrollToBottom() @@ -555,16 +581,25 @@ export default { } }, - async pollGenerationTask(taskId) { + async pollGenerationTask(taskId, sessionId) { return new Promise((resolve, reject) => { this.pollMessageGenerationTimerId = setInterval(() => { axios.get( getChatURL('/check_generation'), - { params: { taskId, sessionId: this.active.id } }, + { params: { taskId, sessionId } }, ).then(response => { clearInterval(this.pollMessageGenerationTimerId) - resolve(response.data) + if (sessionId === this.active.id) { + resolve(response.data) + } else { + console.debug('Ignoring received a message for session ' + sessionId + ' that is not selected anymore') + // should we reject here? + } }).catch(error => { + if (sessionId !== this.active.id) { + console.debug('Stop polling session ' + sessionId + ' because it is not selected anymore') + clearInterval(this.pollMessageGenerationTimerId) + } // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) if (error.response?.status !== 417) { console.error('checkTaskPolling error', error) diff --git a/tests/psalm-baseline.xml b/tests/psalm-baseline.xml index 1b3c8054..a1dc3d18 100644 --- a/tests/psalm-baseline.xml +++ b/tests/psalm-baseline.xml @@ -1,5 +1,5 @@ - + - $content], Application::APP_ID . ':chatty-llm', $this->userId)]]> + $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId)]]> From 829cc571e26cdc3f22576442c3842c2d4c67be54 Mon Sep 17 00:00:00 2001 From: Julien Veyssier Date: Wed, 20 Nov 2024 18:20:41 +0100 Subject: [PATCH 2/3] adjust title generation using the same logic as message generation Signed-off-by: Julien Veyssier --- appinfo/routes.php | 2 +- lib/Controller/ChattyLLMController.php | 37 +++++--- lib/Db/ChattyLLM/Session.php | 2 +- lib/Db/ChattyLLM/SessionMapper.php | 21 +++++ lib/Listener/ChattyLLMTaskListener.php | 41 +++++---- .../ChattyLLM/ChattyLLMInputForm.vue | 90 +++++++++++++------ 6 files changed, 133 insertions(+), 60 deletions(-) diff --git a/appinfo/routes.php b/appinfo/routes.php index 3488d40f..7588993c 100644 --- a/appinfo/routes.php +++ b/appinfo/routes.php @@ -24,7 +24,7 @@ ['name' => 'chattyLLM#getMessages', 'url' => '/chat/messages', 'verb' => 'GET'], ['name' => 'chattyLLM#generateForSession', 'url' => '/chat/generate', 'verb' => 'GET'], ['name' => 'chattyLLM#regenerateForSession', 'url' => '/chat/regenerate', 'verb' => 'GET'], - ['name' => 'chattyLLM#checkMessageGenerationSession', 'url' => '/chat/check_session', 'verb' => 'GET'], + ['name' => 'chattyLLM#checkSession', 'url' => '/chat/check_session', 'verb' => 'GET'], ['name' => 'chattyLLM#checkMessageGenerationTask', 'url' => '/chat/check_generation', 'verb' => 'GET'], ['name' => 'chattyLLM#generateTitle', 'url' => '/chat/generate_title', 'verb' => 'GET'], ['name' => 'chattyLLM#checkTitleGenerationTask', 'url' => '/chat/check_title_generation', 'verb' => 'GET'], diff --git a/lib/Controller/ChattyLLMController.php b/lib/Controller/ChattyLLMController.php index 071211dd..e7e87f79 100644 --- a/lib/Controller/ChattyLLMController.php +++ b/lib/Controller/ChattyLLMController.php @@ -403,7 +403,7 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes * @throws \OCP\DB\Exception */ #[NoAdminRequired] - public function checkMessageGenerationSession(int $sessionId): JSONResponse { + public function checkSession(int $sessionId): JSONResponse { if ($this->userId === null) { return new JSONResponse(['error' => $this->l10n->t('User not logged in')], Http::STATUS_UNAUTHORIZED); } @@ -414,22 +414,32 @@ public function checkMessageGenerationSession(int $sessionId): JSONResponse { } try { - $tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId); + $messageTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId); + $titleTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-title:' . $sessionId); } catch (\OCP\TaskProcessing\Exception\Exception $e) { return new JSONResponse(['error' => 'task_query_failed'], Http::STATUS_BAD_REQUEST); } - $tasks = array_filter($tasks, static function (Task $task) { + $messageTasks = array_filter($messageTasks, static function (Task $task) { return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; }); - if (empty($tasks)) { - return new JSONResponse([ - 'taskId' => null, - ]); - } - $task = array_pop($tasks); - return new JSONResponse([ - 'taskId' => $task->getId(), - ]); + $titleTasks = array_filter($titleTasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + $session = $this->sessionMapper->getUserSession($this->userId, $sessionId); + $responseData = [ + 'messageTaskId' => null, + 'titleTaskId' => null, + 'sessionTitle' => $session->getTitle(), + ]; + if (!empty($messageTasks)) { + $task = array_pop($messageTasks); + $responseData['messageTaskId'] = $task->getId(); + } + if (!empty($titleTasks)) { + $task = array_pop($titleTasks); + $responseData['titleTaskId'] = $task->getId(); + } + return new JSONResponse($responseData); } /** @@ -523,8 +533,7 @@ public function checkTitleGenerationTask(int $taskId, int $sessionId): JSONRespo $title = str_replace('"', '', $title); $title = explode(PHP_EOL, $title)[0]; $title = trim($title); - - $this->sessionMapper->updateSessionTitle($this->userId, $sessionId, $title); + // do not write the title here since it's done in the listener return new JSONResponse(['result' => $title]); } catch (\OCP\DB\Exception $e) { diff --git a/lib/Db/ChattyLLM/Session.php b/lib/Db/ChattyLLM/Session.php index a6c19a1a..14a5e7c0 100644 --- a/lib/Db/ChattyLLM/Session.php +++ b/lib/Db/ChattyLLM/Session.php @@ -31,7 +31,7 @@ /** * @method \string getUserId() * @method \void setUserId(string $userId) - * @method \?string getTitle() + * @method \string|null getTitle() * @method \void setTitle(?string $title) * @method \int|null getTimestamp() * @method \void setTimestamp(?int $timestamp) diff --git a/lib/Db/ChattyLLM/SessionMapper.php b/lib/Db/ChattyLLM/SessionMapper.php index 60418d03..276fca67 100644 --- a/lib/Db/ChattyLLM/SessionMapper.php +++ b/lib/Db/ChattyLLM/SessionMapper.php @@ -25,7 +25,10 @@ namespace OCA\Assistant\Db\ChattyLLM; +use OCP\AppFramework\Db\DoesNotExistException; +use OCP\AppFramework\Db\MultipleObjectsReturnedException; use OCP\AppFramework\Db\QBMapper; +use OCP\DB\Exception; use OCP\DB\QueryBuilder\IQueryBuilder; use OCP\IDBConnection; @@ -59,6 +62,24 @@ public function exists(string $userId, int $sessionId): bool { } } + /** + * @param string $userId + * @param int $sessionId + * @return Session + * @throws DoesNotExistException + * @throws MultipleObjectsReturnedException + * @throws Exception + */ + public function getUserSession(string $userId, int $sessionId): Session { + $qb = $this->db->getQueryBuilder(); + $qb->select('id', 'title', 'timestamp') + ->from($this->getTableName()) + ->where($qb->expr()->eq('id', $qb->createPositionalParameter($sessionId, IQueryBuilder::PARAM_INT))) + ->andWhere($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId, IQueryBuilder::PARAM_STR))); + + return $this->findEntity($qb); + } + /** * @param string $userId * @return array diff --git a/lib/Listener/ChattyLLMTaskListener.php b/lib/Listener/ChattyLLMTaskListener.php index a5fb1ea8..b79deade 100644 --- a/lib/Listener/ChattyLLMTaskListener.php +++ b/lib/Listener/ChattyLLMTaskListener.php @@ -7,6 +7,7 @@ use OCA\Assistant\AppInfo\Application; use OCA\Assistant\Db\ChattyLLM\Message; use OCA\Assistant\Db\ChattyLLM\MessageMapper; +use OCA\Assistant\Db\ChattyLLM\SessionMapper; use OCP\EventDispatcher\Event; use OCP\EventDispatcher\IEventListener; use OCP\TaskProcessing\Events\TaskSuccessfulEvent; @@ -19,6 +20,7 @@ class ChattyLLMTaskListener implements IEventListener { public function __construct( private MessageMapper $messageMapper, + private SessionMapper $sessionMapper, private LoggerInterface $logger, ) { } @@ -31,23 +33,32 @@ public function handle(Event $event): void { $task = $event->getTask(); $customId = $task->getCustomId(); $appId = $task->getAppId(); - if ($appId !== (Application::APP_ID . ':chatty-llm') - || $customId === null - || !preg_match('/^chatty-llm:(\d+)$/', $customId, $matches) - ) { + + if ($customId === null || $appId !== (Application::APP_ID . ':chatty-llm')) { return; } - $sessionId = (int)$matches[1]; - - $message = new Message(); - $message->setSessionId($sessionId); - $message->setRole('assistant'); - $message->setContent(trim($task->getOutput()['output'] ?? '')); - $message->setTimestamp(time()); - try { - $this->messageMapper->insert($message); - } catch (\OCP\DB\Exception $e) { - $this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]); + + // title generation + if (preg_match('/^chatty-title:(\d+)$/', $customId, $matches)) { + $sessionId = (int)$matches[1]; + $title = trim($task->getOutput()['output'] ?? ''); + $this->sessionMapper->updateSessionTitle($task->getUserId(), $sessionId, $title); + } + + // message generation + if (preg_match('/^chatty-llm:(\d+)$/', $customId, $matches)) { + $sessionId = (int)$matches[1]; + + $message = new Message(); + $message->setSessionId($sessionId); + $message->setRole('assistant'); + $message->setContent(trim($task->getOutput()['output'] ?? '')); + $message->setTimestamp(time()); + try { + $this->messageMapper->insert($message); + } catch (\OCP\DB\Exception $e) { + $this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]); + } } } } diff --git a/src/components/ChattyLLM/ChattyLLMInputForm.vue b/src/components/ChattyLLM/ChattyLLMInputForm.vue index 6dd099e5..6373a041 100644 --- a/src/components/ChattyLLM/ChattyLLMInputForm.vue +++ b/src/components/ChattyLLM/ChattyLLMInputForm.vue @@ -220,6 +220,7 @@ export default { watch: { async active() { this.loading.llmGeneration = false + this.loading.titleGeneration = false this.allMessagesLoaded = false this.chatContent = '' this.msgCursor = 0 @@ -238,22 +239,45 @@ export default { } // start polling in case a message is currently being generated try { - const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId: this.active.id } }) - console.debug('check session response:', checkSessionResponse) - if (checkSessionResponse.data.taskId === null) { - return + const sessionId = this.active.id + const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId } }) + if (checkSessionResponse.data?.sessionTitle && checkSessionResponse.data?.sessionTitle !== this.active.title) { + this.active.title = checkSessionResponse.data?.sessionTitle + console.debug('update session title with check result') } - try { - this.loading.llmGeneration = true - const message = await this.pollGenerationTask(checkSessionResponse.data.taskId, this.active.id) - console.debug('checkTaskPolling result:', message) - this.messages.push(message) - this.scrollToBottom() - } catch (error) { - console.error('checkGenerationTask error:', error) - showError(t('assistant', 'Error generating a response')) - } finally { - this.loading.llmGeneration = false + console.debug('check session response:', checkSessionResponse.data) + if (checkSessionResponse.data.messageTaskId !== null) { + try { + this.loading.llmGeneration = true + const message = await this.pollGenerationTask(checkSessionResponse.data.messageTaskId, sessionId) + console.debug('checkTaskPolling result:', message) + this.messages.push(message) + this.scrollToBottom() + } catch (error) { + console.error('checkGenerationTask error:', error) + showError(t('assistant', 'Error generating a response')) + } finally { + this.loading.llmGeneration = false + } + } else if (checkSessionResponse.data.titleTaskId !== null) { + try { + this.loading.titleGeneration = true + const titleResponse = await this.pollTitleGenerationTask(checkSessionResponse.data.titleTaskId, sessionId) + console.debug('checkTaskPolling result:', titleResponse) + if (titleResponse?.data?.result == null) { + throw new Error('No title generated, response:', titleResponse) + } + + const session = this.sessions.find(s => s.id === sessionId) + if (session) { + session.title = titleResponse?.data?.result + } + } catch (error) { + console.error('onCheckSessionTitle error:', error) + showError(error?.response?.data?.error ?? t('assistant', 'Error getting the generated title for the conversation')) + } finally { + this.loading.titleGeneration = false + } } } catch (error) { console.error('check session error:', error) @@ -367,18 +391,17 @@ export default { async onGenerateSessionTitle() { try { this.loading.titleGeneration = true - const response = await axios.get(getChatURL('/generate_title'), { params: { sessionId: this.active.id } }) - const titleResponse = await this.pollTitleGenerationTask(response.data.taskId) + const sessionId = this.active.id + const response = await axios.get(getChatURL('/generate_title'), { params: { sessionId } }) + const titleResponse = await this.pollTitleGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', titleResponse) if (titleResponse?.data?.result == null) { throw new Error('No title generated, response:', response) } - for (const session of this.sessions) { - if (session.id === this.active.id) { - session.title = titleResponse?.data?.result - break - } + const session = this.sessions.find(s => s.id === sessionId) + if (session) { + session.title = titleResponse?.data?.result } } catch (error) { console.error('onGenerateSessionTitle error:', error) @@ -592,12 +615,12 @@ export default { if (sessionId === this.active.id) { resolve(response.data) } else { - console.debug('Ignoring received a message for session ' + sessionId + ' that is not selected anymore') + console.debug('Ignoring received message for session ' + sessionId + ' that is not selected anymore') // should we reject here? } }).catch(error => { if (sessionId !== this.active.id) { - console.debug('Stop polling session ' + sessionId + ' because it is not selected anymore') + console.debug('Stop polling messages for session ' + sessionId + ' because it is not selected anymore') clearInterval(this.pollMessageGenerationTimerId) } // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) @@ -606,30 +629,39 @@ export default { clearInterval(this.pollMessageGenerationTimerId) reject(new Error('Message generation task check failed')) } else { - console.debug('checkTaskPolling, task is still scheduled or running', error) + console.debug('checkTaskPolling, task is still scheduled or running') } }) }, 2000) }) }, - async pollTitleGenerationTask(taskId) { + async pollTitleGenerationTask(taskId, sessionId) { return new Promise((resolve, reject) => { this.pollTitleGenerationTimerId = setInterval(() => { axios.get( getChatURL('/check_title_generation'), - { params: { taskId, sessionId: this.active.id } }, + { params: { taskId, sessionId } }, ).then(response => { + if (sessionId === this.active.id) { + resolve(response) + } else { + console.debug('Ignoring received title for session ' + sessionId + ' that is not selected anymore') + // should we reject here? + } clearInterval(this.pollTitleGenerationTimerId) - resolve(response) }).catch(error => { + if (sessionId !== this.active.id) { + console.debug('Stop polling title for session ' + sessionId + ' because it is not selected anymore') + clearInterval(this.pollTitleGenerationTimerId) + } // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) if (error.response?.status !== 417) { console.error('checkTaskPolling error', error) clearInterval(this.pollTitleGenerationTimerId) reject(new Error('Title generation task check failed')) } else { - console.debug('checkTaskPolling, task is still scheduled or running', error) + console.debug('checkTaskPolling, task is still scheduled or running') } }) }, 2000) From 23eaccd9dcdf26820a487e210efdfea9ba6e020c Mon Sep 17 00:00:00 2001 From: Julien Veyssier Date: Mon, 25 Nov 2024 17:26:16 +0100 Subject: [PATCH 3/3] address review comments Signed-off-by: Julien Veyssier --- .../ChattyLLM/ChattyLLMInputForm.vue | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/components/ChattyLLM/ChattyLLMInputForm.vue b/src/components/ChattyLLM/ChattyLLMInputForm.vue index 6373a041..0f5fb028 100644 --- a/src/components/ChattyLLM/ChattyLLMInputForm.vue +++ b/src/components/ChattyLLM/ChattyLLMInputForm.vue @@ -219,8 +219,9 @@ export default { watch: { async active() { - this.loading.llmGeneration = false - this.loading.titleGeneration = false + // set loading to true since we know we check that + this.loading.llmGeneration = true + this.loading.titleGeneration = true this.allMessagesLoaded = false this.chatContent = '' this.msgCursor = 0 @@ -248,7 +249,6 @@ export default { console.debug('check session response:', checkSessionResponse.data) if (checkSessionResponse.data.messageTaskId !== null) { try { - this.loading.llmGeneration = true const message = await this.pollGenerationTask(checkSessionResponse.data.messageTaskId, sessionId) console.debug('checkTaskPolling result:', message) this.messages.push(message) @@ -256,12 +256,10 @@ export default { } catch (error) { console.error('checkGenerationTask error:', error) showError(t('assistant', 'Error generating a response')) - } finally { - this.loading.llmGeneration = false } - } else if (checkSessionResponse.data.titleTaskId !== null) { + } + if (checkSessionResponse.data.titleTaskId !== null) { try { - this.loading.titleGeneration = true const titleResponse = await this.pollTitleGenerationTask(checkSessionResponse.data.titleTaskId, sessionId) console.debug('checkTaskPolling result:', titleResponse) if (titleResponse?.data?.result == null) { @@ -275,13 +273,14 @@ export default { } catch (error) { console.error('onCheckSessionTitle error:', error) showError(error?.response?.data?.error ?? t('assistant', 'Error getting the generated title for the conversation')) - } finally { - this.loading.titleGeneration = false } } } catch (error) { console.error('check session error:', error) showError(t('assistant', 'Error checking if the session is thinking')) + } finally { + this.loading.llmGeneration = false + this.loading.titleGeneration = false } }, }, @@ -607,6 +606,11 @@ export default { async pollGenerationTask(taskId, sessionId) { return new Promise((resolve, reject) => { this.pollMessageGenerationTimerId = setInterval(() => { + if (sessionId !== this.active.id) { + console.debug('Stop polling messages for session ' + sessionId + ' because it is not selected anymore') + clearInterval(this.pollMessageGenerationTimerId) + return + } axios.get( getChatURL('/check_generation'), { params: { taskId, sessionId } }, @@ -619,10 +623,6 @@ export default { // should we reject here? } }).catch(error => { - if (sessionId !== this.active.id) { - console.debug('Stop polling messages for session ' + sessionId + ' because it is not selected anymore') - clearInterval(this.pollMessageGenerationTimerId) - } // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) if (error.response?.status !== 417) { console.error('checkTaskPolling error', error) @@ -639,6 +639,11 @@ export default { async pollTitleGenerationTask(taskId, sessionId) { return new Promise((resolve, reject) => { this.pollTitleGenerationTimerId = setInterval(() => { + if (sessionId !== this.active.id) { + console.debug('Stop polling title for session ' + sessionId + ' because it is not selected anymore') + clearInterval(this.pollTitleGenerationTimerId) + return + } axios.get( getChatURL('/check_title_generation'), { params: { taskId, sessionId } }, @@ -651,10 +656,6 @@ export default { } clearInterval(this.pollTitleGenerationTimerId) }).catch(error => { - if (sessionId !== this.active.id) { - console.debug('Stop polling title for session ' + sessionId + ' because it is not selected anymore') - clearInterval(this.pollTitleGenerationTimerId) - } // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) if (error.response?.status !== 417) { console.error('checkTaskPolling error', error)