From 06201f6ce2536dd09bd5a74ea4b83ef02e297f4d Mon Sep 17 00:00:00 2001 From: Julien Veyssier Date: Wed, 20 Nov 2024 16:29:10 +0100 Subject: [PATCH] fix polling when switching sessions in the frontend, prevent scheduling multiple llm tasks for one session Signed-off-by: Julien Veyssier --- appinfo/routes.php | 1 + lib/AppInfo/Application.php | 2 + lib/Controller/ChattyLLMController.php | 76 +++++++++++++++++-- lib/Listener/ChattyLLMTaskListener.php | 53 +++++++++++++ .../ChattyLLM/ChattyLLMInputForm.vue | 63 +++++++++++---- tests/psalm-baseline.xml | 4 +- 6 files changed, 178 insertions(+), 21 deletions(-) create mode 100644 lib/Listener/ChattyLLMTaskListener.php diff --git a/appinfo/routes.php b/appinfo/routes.php index 19443cf5..3488d40f 100644 --- a/appinfo/routes.php +++ b/appinfo/routes.php @@ -24,6 +24,7 @@ ['name' => 'chattyLLM#getMessages', 'url' => '/chat/messages', 'verb' => 'GET'], ['name' => 'chattyLLM#generateForSession', 'url' => '/chat/generate', 'verb' => 'GET'], ['name' => 'chattyLLM#regenerateForSession', 'url' => '/chat/regenerate', 'verb' => 'GET'], + ['name' => 'chattyLLM#checkMessageGenerationSession', 'url' => '/chat/check_session', 'verb' => 'GET'], ['name' => 'chattyLLM#checkMessageGenerationTask', 'url' => '/chat/check_generation', 'verb' => 'GET'], ['name' => 'chattyLLM#generateTitle', 'url' => '/chat/generate_title', 'verb' => 'GET'], ['name' => 'chattyLLM#checkTitleGenerationTask', 'url' => '/chat/check_title_generation', 'verb' => 'GET'], diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index 8005fb48..c0a0ff3a 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -4,6 +4,7 @@ use OCA\Assistant\Capabilities; use OCA\Assistant\Listener\BeforeTemplateRenderedListener; +use OCA\Assistant\Listener\ChattyLLMTaskListener; use OCA\Assistant\Listener\CSPListener; use OCA\Assistant\Listener\FreePrompt\FreePromptReferenceListener; use OCA\Assistant\Listener\SpeechToText\SpeechToTextReferenceListener; @@ -55,6 +56,7 @@ public function register(IRegistrationContext $context): void { $context->registerEventListener(TaskSuccessfulEvent::class, TaskSuccessfulListener::class); $context->registerEventListener(TaskFailedEvent::class, TaskFailedListener::class); + $context->registerEventListener(TaskSuccessfulEvent::class, ChattyLLMTaskListener::class); $context->registerNotifierService(Notifier::class); diff --git a/lib/Controller/ChattyLLMController.php b/lib/Controller/ChattyLLMController.php index 32183246..071211dd 100644 --- a/lib/Controller/ChattyLLMController.php +++ b/lib/Controller/ChattyLLMController.php @@ -297,7 +297,11 @@ public function generateForSession(int $sessionId): JSONResponse { . PHP_EOL . 'assistant: '; - $taskId = $this->scheduleLLMTask($stichedPrompt); + try { + $taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId); + } catch (\Exception $e) { + return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST); + } return new JSONResponse(['taskId' => $taskId]); } @@ -374,7 +378,7 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes $message->setRole('assistant'); $message->setContent(trim($task->getOutput()['output'] ?? '')); $message->setTimestamp(time()); - $this->messageMapper->insert($message); + // do not insert here, it is done by the listener return new JSONResponse($message); } catch (\OCP\DB\Exception $e) { $this->logger->warning('Failed to add a chat message into DB', ['exception' => $e]); @@ -388,6 +392,46 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes return new JSONResponse(['error' => 'unknown_error', 'task_status' => $task->getstatus()], Http::STATUS_BAD_REQUEST); } + /** + * Check the status of a session + * + * Used by the frontend to determine if it should poll a generation task status. + * + * @param int $sessionId + * @return JSONResponse + * @throws \JsonException + * @throws \OCP\DB\Exception + */ + #[NoAdminRequired] + public function checkMessageGenerationSession(int $sessionId): JSONResponse { + if ($this->userId === null) { + return new JSONResponse(['error' => $this->l10n->t('User not logged in')], Http::STATUS_UNAUTHORIZED); + } + + $sessionExists = $this->sessionMapper->exists($this->userId, $sessionId); + if (!$sessionExists) { + return new JSONResponse(['error' => $this->l10n->t('Session not found')], Http::STATUS_NOT_FOUND); + } + + try { + $tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId); + } catch (\OCP\TaskProcessing\Exception\Exception $e) { + return new JSONResponse(['error' => 'task_query_failed'], Http::STATUS_BAD_REQUEST); + } + $tasks = array_filter($tasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + if (empty($tasks)) { + return new JSONResponse([ + 'taskId' => null, + ]); + } + $task = array_pop($tasks); + return new JSONResponse([ + 'taskId' => $task->getId(), + ]); + } + /** * Schedule a task to generate a title for the chat session * @@ -430,7 +474,11 @@ public function generateTitle(int $sessionId): JSONResponse { . PHP_EOL . PHP_EOL . $userInstructions; - $taskId = $this->scheduleLLMTask($stichedPrompt); + try { + $taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId, false); + } catch (\Exception $e) { + return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST); + } return new JSONResponse(['taskId' => $taskId]); } catch (\OCP\DB\Exception $e) { $this->logger->warning('Failed to generate a title for the chat session', ['exception' => $e]); @@ -525,14 +573,32 @@ private function getStichedMessages(int $sessionId): string { * Schedule the LLM task * * @param string $content + * @param int $sessionId + * @param bool $isMessage * @return int|null * @throws Exception * @throws PreConditionNotMetException * @throws UnauthorizedException * @throws ValidationException + * @throws \JsonException */ - private function scheduleLLMTask(string $content): ?int { - $task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId); + private function scheduleLLMTask(string $content, int $sessionId, bool $isMessage = true): ?int { + $customId = ($isMessage + ? 'chatty-llm:' + : 'chatty-title:') . $sessionId; + try { + $tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', $customId); + } catch (\OCP\TaskProcessing\Exception\Exception $e) { + throw new \Exception('task_query_failed'); + } + $tasks = array_filter($tasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + // prevent scheduling multiple llm tasks simultaneously for one session + if (!empty($tasks)) { + throw new \Exception('session_already_thinking'); + } + $task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId); $this->taskProcessingManager->scheduleTask($task); return $task->getId(); } diff --git a/lib/Listener/ChattyLLMTaskListener.php b/lib/Listener/ChattyLLMTaskListener.php new file mode 100644 index 00000000..a5fb1ea8 --- /dev/null +++ b/lib/Listener/ChattyLLMTaskListener.php @@ -0,0 +1,53 @@ + + */ +class ChattyLLMTaskListener implements IEventListener { + + public function __construct( + private MessageMapper $messageMapper, + private LoggerInterface $logger, + ) { + } + + public function handle(Event $event): void { + if (!($event instanceof TaskSuccessfulEvent)) { + return; + } + + $task = $event->getTask(); + $customId = $task->getCustomId(); + $appId = $task->getAppId(); + if ($appId !== (Application::APP_ID . ':chatty-llm') + || $customId === null + || !preg_match('/^chatty-llm:(\d+)$/', $customId, $matches) + ) { + return; + } + $sessionId = (int)$matches[1]; + + $message = new Message(); + $message->setSessionId($sessionId); + $message->setRole('assistant'); + $message->setContent(trim($task->getOutput()['output'] ?? '')); + $message->setTimestamp(time()); + try { + $this->messageMapper->insert($message); + } catch (\OCP\DB\Exception $e) { + $this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]); + } + } +} diff --git a/src/components/ChattyLLM/ChattyLLMInputForm.vue b/src/components/ChattyLLM/ChattyLLMInputForm.vue index 9c33964a..6dd099e5 100644 --- a/src/components/ChattyLLM/ChattyLLMInputForm.vue +++ b/src/components/ChattyLLM/ChattyLLMInputForm.vue @@ -219,6 +219,7 @@ export default { watch: { async active() { + this.loading.llmGeneration = false this.allMessagesLoaded = false this.chatContent = '' this.msgCursor = 0 @@ -226,13 +227,37 @@ export default { this.editingTitle = false this.$refs.inputComponent.focus() - if (this.active != null && !this.loading.newSession) { + if (this.active !== null && !this.loading.newSession) { await this.fetchMessages() this.scrollToBottom() } else { // when no active session or creating a new session this.allMessagesLoaded = true this.loading.newSession = false + return + } + // start polling in case a message is currently being generated + try { + const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId: this.active.id } }) + console.debug('check session response:', checkSessionResponse) + if (checkSessionResponse.data.taskId === null) { + return + } + try { + this.loading.llmGeneration = true + const message = await this.pollGenerationTask(checkSessionResponse.data.taskId, this.active.id) + console.debug('checkTaskPolling result:', message) + this.messages.push(message) + this.scrollToBottom() + } catch (error) { + console.error('checkGenerationTask error:', error) + showError(t('assistant', 'Error generating a response')) + } finally { + this.loading.llmGeneration = false + } + } catch (error) { + console.error('check session error:', error) + showError(t('assistant', 'Error checking if the session is thinking')) } }, }, @@ -328,7 +353,7 @@ export default { this.messages.push({ role, content, timestamp }) this.chatContent = '' this.scrollToBottom() - await this.newMessage(role, content, timestamp) + await this.newMessage(role, content, timestamp, this.active.id) }, onLoadOlderMessages() { @@ -466,13 +491,13 @@ export default { } }, - async newMessage(role, content, timestamp) { + async newMessage(role, content, timestamp, sessionId) { try { this.loading.newHumanMessage = true const firstHumanMessage = this.messages.length === 1 && this.messages[0].role === Roles.HUMAN const response = await axios.put(getChatURL('/new_message'), { - sessionId: this.active.id, + sessionId, role, content, timestamp, @@ -485,11 +510,11 @@ export default { this.messages[this.messages.length - 1] = response.data if (firstHumanMessage) { - const session = this.sessions.find((session) => session.id === this.active.id) + const session = this.sessions.find((session) => session.id === sessionId) session.title = content } - await this.runGenerationTask() + await this.runGenerationTask(sessionId) } catch (error) { this.loading.newHumanMessage = false console.error('newMessage error:', error) @@ -521,12 +546,12 @@ export default { } }, - async runGenerationTask() { + async runGenerationTask(sessionId) { try { this.loading.llmGeneration = true - const response = await axios.get(getChatURL('/generate'), { params: { sessionId: this.active.id } }) + const response = await axios.get(getChatURL('/generate'), { params: { sessionId } }) console.debug('scheduleGenerationTask response:', response) - const message = await this.pollGenerationTask(response.data.taskId) + const message = await this.pollGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', message) this.messages.push(message) this.scrollToBottom() @@ -540,10 +565,11 @@ export default { async runRegenerationTask(messageId) { try { + const sessionId = this.active.id this.loading.llmGeneration = true - const response = await axios.get(getChatURL('/regenerate'), { params: { messageId, sessionId: this.active.id } }) + const response = await axios.get(getChatURL('/regenerate'), { params: { messageId, sessionId } }) console.debug('scheduleRegenerationTask response:', response) - const message = await this.pollGenerationTask(response.data.taskId) + const message = await this.pollGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', message) this.messages[this.messages.length - 1] = message this.scrollToBottom() @@ -555,16 +581,25 @@ export default { } }, - async pollGenerationTask(taskId) { + async pollGenerationTask(taskId, sessionId) { return new Promise((resolve, reject) => { this.pollMessageGenerationTimerId = setInterval(() => { axios.get( getChatURL('/check_generation'), - { params: { taskId, sessionId: this.active.id } }, + { params: { taskId, sessionId } }, ).then(response => { clearInterval(this.pollMessageGenerationTimerId) - resolve(response.data) + if (sessionId === this.active.id) { + resolve(response.data) + } else { + console.debug('Ignoring received a message for session ' + sessionId + ' that is not selected anymore') + // should we reject here? + } }).catch(error => { + if (sessionId !== this.active.id) { + console.debug('Stop polling session ' + sessionId + ' because it is not selected anymore') + clearInterval(this.pollMessageGenerationTimerId) + } // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) if (error.response?.status !== 417) { console.error('checkTaskPolling error', error) diff --git a/tests/psalm-baseline.xml b/tests/psalm-baseline.xml index 1b3c8054..a1dc3d18 100644 --- a/tests/psalm-baseline.xml +++ b/tests/psalm-baseline.xml @@ -1,5 +1,5 @@ - + - $content], Application::APP_ID . ':chatty-llm', $this->userId)]]> + $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId)]]>