diff --git a/appinfo/routes.php b/appinfo/routes.php index 19443cf5..7588993c 100644 --- a/appinfo/routes.php +++ b/appinfo/routes.php @@ -24,6 +24,7 @@ ['name' => 'chattyLLM#getMessages', 'url' => '/chat/messages', 'verb' => 'GET'], ['name' => 'chattyLLM#generateForSession', 'url' => '/chat/generate', 'verb' => 'GET'], ['name' => 'chattyLLM#regenerateForSession', 'url' => '/chat/regenerate', 'verb' => 'GET'], + ['name' => 'chattyLLM#checkSession', 'url' => '/chat/check_session', 'verb' => 'GET'], ['name' => 'chattyLLM#checkMessageGenerationTask', 'url' => '/chat/check_generation', 'verb' => 'GET'], ['name' => 'chattyLLM#generateTitle', 'url' => '/chat/generate_title', 'verb' => 'GET'], ['name' => 'chattyLLM#checkTitleGenerationTask', 'url' => '/chat/check_title_generation', 'verb' => 'GET'], diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index 8005fb48..c0a0ff3a 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -4,6 +4,7 @@ use OCA\Assistant\Capabilities; use OCA\Assistant\Listener\BeforeTemplateRenderedListener; +use OCA\Assistant\Listener\ChattyLLMTaskListener; use OCA\Assistant\Listener\CSPListener; use OCA\Assistant\Listener\FreePrompt\FreePromptReferenceListener; use OCA\Assistant\Listener\SpeechToText\SpeechToTextReferenceListener; @@ -55,6 +56,7 @@ public function register(IRegistrationContext $context): void { $context->registerEventListener(TaskSuccessfulEvent::class, TaskSuccessfulListener::class); $context->registerEventListener(TaskFailedEvent::class, TaskFailedListener::class); + $context->registerEventListener(TaskSuccessfulEvent::class, ChattyLLMTaskListener::class); $context->registerNotifierService(Notifier::class); diff --git a/lib/Controller/ChattyLLMController.php b/lib/Controller/ChattyLLMController.php index 32183246..e7e87f79 100644 --- a/lib/Controller/ChattyLLMController.php +++ b/lib/Controller/ChattyLLMController.php @@ -297,7 +297,11 @@ public function generateForSession(int $sessionId): JSONResponse { . PHP_EOL . 'assistant: '; - $taskId = $this->scheduleLLMTask($stichedPrompt); + try { + $taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId); + } catch (\Exception $e) { + return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST); + } return new JSONResponse(['taskId' => $taskId]); } @@ -374,7 +378,7 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes $message->setRole('assistant'); $message->setContent(trim($task->getOutput()['output'] ?? '')); $message->setTimestamp(time()); - $this->messageMapper->insert($message); + // do not insert here, it is done by the listener return new JSONResponse($message); } catch (\OCP\DB\Exception $e) { $this->logger->warning('Failed to add a chat message into DB', ['exception' => $e]); @@ -388,6 +392,56 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes return new JSONResponse(['error' => 'unknown_error', 'task_status' => $task->getstatus()], Http::STATUS_BAD_REQUEST); } + /** + * Check the status of a session + * + * Used by the frontend to determine if it should poll a generation task status. + * + * @param int $sessionId + * @return JSONResponse + * @throws \JsonException + * @throws \OCP\DB\Exception + */ + #[NoAdminRequired] + public function checkSession(int $sessionId): JSONResponse { + if ($this->userId === null) { + return new JSONResponse(['error' => $this->l10n->t('User not logged in')], Http::STATUS_UNAUTHORIZED); + } + + $sessionExists = $this->sessionMapper->exists($this->userId, $sessionId); + if (!$sessionExists) { + return new JSONResponse(['error' => $this->l10n->t('Session not found')], Http::STATUS_NOT_FOUND); + } + + try { + $messageTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId); + $titleTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-title:' . $sessionId); + } catch (\OCP\TaskProcessing\Exception\Exception $e) { + return new JSONResponse(['error' => 'task_query_failed'], Http::STATUS_BAD_REQUEST); + } + $messageTasks = array_filter($messageTasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + $titleTasks = array_filter($titleTasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + $session = $this->sessionMapper->getUserSession($this->userId, $sessionId); + $responseData = [ + 'messageTaskId' => null, + 'titleTaskId' => null, + 'sessionTitle' => $session->getTitle(), + ]; + if (!empty($messageTasks)) { + $task = array_pop($messageTasks); + $responseData['messageTaskId'] = $task->getId(); + } + if (!empty($titleTasks)) { + $task = array_pop($titleTasks); + $responseData['titleTaskId'] = $task->getId(); + } + return new JSONResponse($responseData); + } + /** * Schedule a task to generate a title for the chat session * @@ -430,7 +484,11 @@ public function generateTitle(int $sessionId): JSONResponse { . PHP_EOL . PHP_EOL . $userInstructions; - $taskId = $this->scheduleLLMTask($stichedPrompt); + try { + $taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId, false); + } catch (\Exception $e) { + return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST); + } return new JSONResponse(['taskId' => $taskId]); } catch (\OCP\DB\Exception $e) { $this->logger->warning('Failed to generate a title for the chat session', ['exception' => $e]); @@ -475,8 +533,7 @@ public function checkTitleGenerationTask(int $taskId, int $sessionId): JSONRespo $title = str_replace('"', '', $title); $title = explode(PHP_EOL, $title)[0]; $title = trim($title); - - $this->sessionMapper->updateSessionTitle($this->userId, $sessionId, $title); + // do not write the title here since it's done in the listener return new JSONResponse(['result' => $title]); } catch (\OCP\DB\Exception $e) { @@ -525,14 +582,32 @@ private function getStichedMessages(int $sessionId): string { * Schedule the LLM task * * @param string $content + * @param int $sessionId + * @param bool $isMessage * @return int|null * @throws Exception * @throws PreConditionNotMetException * @throws UnauthorizedException * @throws ValidationException + * @throws \JsonException */ - private function scheduleLLMTask(string $content): ?int { - $task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId); + private function scheduleLLMTask(string $content, int $sessionId, bool $isMessage = true): ?int { + $customId = ($isMessage + ? 'chatty-llm:' + : 'chatty-title:') . $sessionId; + try { + $tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', $customId); + } catch (\OCP\TaskProcessing\Exception\Exception $e) { + throw new \Exception('task_query_failed'); + } + $tasks = array_filter($tasks, static function (Task $task) { + return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED; + }); + // prevent scheduling multiple llm tasks simultaneously for one session + if (!empty($tasks)) { + throw new \Exception('session_already_thinking'); + } + $task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId); $this->taskProcessingManager->scheduleTask($task); return $task->getId(); } diff --git a/lib/Db/ChattyLLM/Session.php b/lib/Db/ChattyLLM/Session.php index a6c19a1a..14a5e7c0 100644 --- a/lib/Db/ChattyLLM/Session.php +++ b/lib/Db/ChattyLLM/Session.php @@ -31,7 +31,7 @@ /** * @method \string getUserId() * @method \void setUserId(string $userId) - * @method \?string getTitle() + * @method \string|null getTitle() * @method \void setTitle(?string $title) * @method \int|null getTimestamp() * @method \void setTimestamp(?int $timestamp) diff --git a/lib/Db/ChattyLLM/SessionMapper.php b/lib/Db/ChattyLLM/SessionMapper.php index 60418d03..276fca67 100644 --- a/lib/Db/ChattyLLM/SessionMapper.php +++ b/lib/Db/ChattyLLM/SessionMapper.php @@ -25,7 +25,10 @@ namespace OCA\Assistant\Db\ChattyLLM; +use OCP\AppFramework\Db\DoesNotExistException; +use OCP\AppFramework\Db\MultipleObjectsReturnedException; use OCP\AppFramework\Db\QBMapper; +use OCP\DB\Exception; use OCP\DB\QueryBuilder\IQueryBuilder; use OCP\IDBConnection; @@ -59,6 +62,24 @@ public function exists(string $userId, int $sessionId): bool { } } + /** + * @param string $userId + * @param int $sessionId + * @return Session + * @throws DoesNotExistException + * @throws MultipleObjectsReturnedException + * @throws Exception + */ + public function getUserSession(string $userId, int $sessionId): Session { + $qb = $this->db->getQueryBuilder(); + $qb->select('id', 'title', 'timestamp') + ->from($this->getTableName()) + ->where($qb->expr()->eq('id', $qb->createPositionalParameter($sessionId, IQueryBuilder::PARAM_INT))) + ->andWhere($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId, IQueryBuilder::PARAM_STR))); + + return $this->findEntity($qb); + } + /** * @param string $userId * @return array diff --git a/lib/Listener/ChattyLLMTaskListener.php b/lib/Listener/ChattyLLMTaskListener.php new file mode 100644 index 00000000..b79deade --- /dev/null +++ b/lib/Listener/ChattyLLMTaskListener.php @@ -0,0 +1,64 @@ + + */ +class ChattyLLMTaskListener implements IEventListener { + + public function __construct( + private MessageMapper $messageMapper, + private SessionMapper $sessionMapper, + private LoggerInterface $logger, + ) { + } + + public function handle(Event $event): void { + if (!($event instanceof TaskSuccessfulEvent)) { + return; + } + + $task = $event->getTask(); + $customId = $task->getCustomId(); + $appId = $task->getAppId(); + + if ($customId === null || $appId !== (Application::APP_ID . ':chatty-llm')) { + return; + } + + // title generation + if (preg_match('/^chatty-title:(\d+)$/', $customId, $matches)) { + $sessionId = (int)$matches[1]; + $title = trim($task->getOutput()['output'] ?? ''); + $this->sessionMapper->updateSessionTitle($task->getUserId(), $sessionId, $title); + } + + // message generation + if (preg_match('/^chatty-llm:(\d+)$/', $customId, $matches)) { + $sessionId = (int)$matches[1]; + + $message = new Message(); + $message->setSessionId($sessionId); + $message->setRole('assistant'); + $message->setContent(trim($task->getOutput()['output'] ?? '')); + $message->setTimestamp(time()); + try { + $this->messageMapper->insert($message); + } catch (\OCP\DB\Exception $e) { + $this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]); + } + } + } +} diff --git a/src/components/ChattyLLM/ChattyLLMInputForm.vue b/src/components/ChattyLLM/ChattyLLMInputForm.vue index 9c33964a..0f5fb028 100644 --- a/src/components/ChattyLLM/ChattyLLMInputForm.vue +++ b/src/components/ChattyLLM/ChattyLLMInputForm.vue @@ -219,6 +219,9 @@ export default { watch: { async active() { + // set loading to true since we know we check that + this.loading.llmGeneration = true + this.loading.titleGeneration = true this.allMessagesLoaded = false this.chatContent = '' this.msgCursor = 0 @@ -226,13 +229,58 @@ export default { this.editingTitle = false this.$refs.inputComponent.focus() - if (this.active != null && !this.loading.newSession) { + if (this.active !== null && !this.loading.newSession) { await this.fetchMessages() this.scrollToBottom() } else { // when no active session or creating a new session this.allMessagesLoaded = true this.loading.newSession = false + return + } + // start polling in case a message is currently being generated + try { + const sessionId = this.active.id + const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId } }) + if (checkSessionResponse.data?.sessionTitle && checkSessionResponse.data?.sessionTitle !== this.active.title) { + this.active.title = checkSessionResponse.data?.sessionTitle + console.debug('update session title with check result') + } + console.debug('check session response:', checkSessionResponse.data) + if (checkSessionResponse.data.messageTaskId !== null) { + try { + const message = await this.pollGenerationTask(checkSessionResponse.data.messageTaskId, sessionId) + console.debug('checkTaskPolling result:', message) + this.messages.push(message) + this.scrollToBottom() + } catch (error) { + console.error('checkGenerationTask error:', error) + showError(t('assistant', 'Error generating a response')) + } + } + if (checkSessionResponse.data.titleTaskId !== null) { + try { + const titleResponse = await this.pollTitleGenerationTask(checkSessionResponse.data.titleTaskId, sessionId) + console.debug('checkTaskPolling result:', titleResponse) + if (titleResponse?.data?.result == null) { + throw new Error('No title generated, response:', titleResponse) + } + + const session = this.sessions.find(s => s.id === sessionId) + if (session) { + session.title = titleResponse?.data?.result + } + } catch (error) { + console.error('onCheckSessionTitle error:', error) + showError(error?.response?.data?.error ?? t('assistant', 'Error getting the generated title for the conversation')) + } + } + } catch (error) { + console.error('check session error:', error) + showError(t('assistant', 'Error checking if the session is thinking')) + } finally { + this.loading.llmGeneration = false + this.loading.titleGeneration = false } }, }, @@ -328,7 +376,7 @@ export default { this.messages.push({ role, content, timestamp }) this.chatContent = '' this.scrollToBottom() - await this.newMessage(role, content, timestamp) + await this.newMessage(role, content, timestamp, this.active.id) }, onLoadOlderMessages() { @@ -342,18 +390,17 @@ export default { async onGenerateSessionTitle() { try { this.loading.titleGeneration = true - const response = await axios.get(getChatURL('/generate_title'), { params: { sessionId: this.active.id } }) - const titleResponse = await this.pollTitleGenerationTask(response.data.taskId) + const sessionId = this.active.id + const response = await axios.get(getChatURL('/generate_title'), { params: { sessionId } }) + const titleResponse = await this.pollTitleGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', titleResponse) if (titleResponse?.data?.result == null) { throw new Error('No title generated, response:', response) } - for (const session of this.sessions) { - if (session.id === this.active.id) { - session.title = titleResponse?.data?.result - break - } + const session = this.sessions.find(s => s.id === sessionId) + if (session) { + session.title = titleResponse?.data?.result } } catch (error) { console.error('onGenerateSessionTitle error:', error) @@ -466,13 +513,13 @@ export default { } }, - async newMessage(role, content, timestamp) { + async newMessage(role, content, timestamp, sessionId) { try { this.loading.newHumanMessage = true const firstHumanMessage = this.messages.length === 1 && this.messages[0].role === Roles.HUMAN const response = await axios.put(getChatURL('/new_message'), { - sessionId: this.active.id, + sessionId, role, content, timestamp, @@ -485,11 +532,11 @@ export default { this.messages[this.messages.length - 1] = response.data if (firstHumanMessage) { - const session = this.sessions.find((session) => session.id === this.active.id) + const session = this.sessions.find((session) => session.id === sessionId) session.title = content } - await this.runGenerationTask() + await this.runGenerationTask(sessionId) } catch (error) { this.loading.newHumanMessage = false console.error('newMessage error:', error) @@ -521,12 +568,12 @@ export default { } }, - async runGenerationTask() { + async runGenerationTask(sessionId) { try { this.loading.llmGeneration = true - const response = await axios.get(getChatURL('/generate'), { params: { sessionId: this.active.id } }) + const response = await axios.get(getChatURL('/generate'), { params: { sessionId } }) console.debug('scheduleGenerationTask response:', response) - const message = await this.pollGenerationTask(response.data.taskId) + const message = await this.pollGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', message) this.messages.push(message) this.scrollToBottom() @@ -540,10 +587,11 @@ export default { async runRegenerationTask(messageId) { try { + const sessionId = this.active.id this.loading.llmGeneration = true - const response = await axios.get(getChatURL('/regenerate'), { params: { messageId, sessionId: this.active.id } }) + const response = await axios.get(getChatURL('/regenerate'), { params: { messageId, sessionId } }) console.debug('scheduleRegenerationTask response:', response) - const message = await this.pollGenerationTask(response.data.taskId) + const message = await this.pollGenerationTask(response.data.taskId, sessionId) console.debug('checkTaskPolling result:', message) this.messages[this.messages.length - 1] = message this.scrollToBottom() @@ -555,15 +603,25 @@ export default { } }, - async pollGenerationTask(taskId) { + async pollGenerationTask(taskId, sessionId) { return new Promise((resolve, reject) => { this.pollMessageGenerationTimerId = setInterval(() => { + if (sessionId !== this.active.id) { + console.debug('Stop polling messages for session ' + sessionId + ' because it is not selected anymore') + clearInterval(this.pollMessageGenerationTimerId) + return + } axios.get( getChatURL('/check_generation'), - { params: { taskId, sessionId: this.active.id } }, + { params: { taskId, sessionId } }, ).then(response => { clearInterval(this.pollMessageGenerationTimerId) - resolve(response.data) + if (sessionId === this.active.id) { + resolve(response.data) + } else { + console.debug('Ignoring received message for session ' + sessionId + ' that is not selected anymore') + // should we reject here? + } }).catch(error => { // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) if (error.response?.status !== 417) { @@ -571,22 +629,32 @@ export default { clearInterval(this.pollMessageGenerationTimerId) reject(new Error('Message generation task check failed')) } else { - console.debug('checkTaskPolling, task is still scheduled or running', error) + console.debug('checkTaskPolling, task is still scheduled or running') } }) }, 2000) }) }, - async pollTitleGenerationTask(taskId) { + async pollTitleGenerationTask(taskId, sessionId) { return new Promise((resolve, reject) => { this.pollTitleGenerationTimerId = setInterval(() => { + if (sessionId !== this.active.id) { + console.debug('Stop polling title for session ' + sessionId + ' because it is not selected anymore') + clearInterval(this.pollTitleGenerationTimerId) + return + } axios.get( getChatURL('/check_title_generation'), - { params: { taskId, sessionId: this.active.id } }, + { params: { taskId, sessionId } }, ).then(response => { + if (sessionId === this.active.id) { + resolve(response) + } else { + console.debug('Ignoring received title for session ' + sessionId + ' that is not selected anymore') + // should we reject here? + } clearInterval(this.pollTitleGenerationTimerId) - resolve(response) }).catch(error => { // do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417) if (error.response?.status !== 417) { @@ -594,7 +662,7 @@ export default { clearInterval(this.pollTitleGenerationTimerId) reject(new Error('Title generation task check failed')) } else { - console.debug('checkTaskPolling, task is still scheduled or running', error) + console.debug('checkTaskPolling, task is still scheduled or running') } }) }, 2000) diff --git a/tests/psalm-baseline.xml b/tests/psalm-baseline.xml index 1b3c8054..a1dc3d18 100644 --- a/tests/psalm-baseline.xml +++ b/tests/psalm-baseline.xml @@ -1,5 +1,5 @@ - + - $content], Application::APP_ID . ':chatty-llm', $this->userId)]]> + $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId)]]>