Skip to content

Commit

Permalink
fix polling when switching sessions in the frontend, prevent scheduli…
Browse files Browse the repository at this point in the history
…ng multiple llm tasks for one session

Signed-off-by: Julien Veyssier <[email protected]>
  • Loading branch information
julien-nc committed Nov 20, 2024
1 parent 039a110 commit 06201f6
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 21 deletions.
1 change: 1 addition & 0 deletions appinfo/routes.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
['name' => 'chattyLLM#getMessages', 'url' => '/chat/messages', 'verb' => 'GET'],
['name' => 'chattyLLM#generateForSession', 'url' => '/chat/generate', 'verb' => 'GET'],
['name' => 'chattyLLM#regenerateForSession', 'url' => '/chat/regenerate', 'verb' => 'GET'],
['name' => 'chattyLLM#checkMessageGenerationSession', 'url' => '/chat/check_session', 'verb' => 'GET'],
['name' => 'chattyLLM#checkMessageGenerationTask', 'url' => '/chat/check_generation', 'verb' => 'GET'],
['name' => 'chattyLLM#generateTitle', 'url' => '/chat/generate_title', 'verb' => 'GET'],
['name' => 'chattyLLM#checkTitleGenerationTask', 'url' => '/chat/check_title_generation', 'verb' => 'GET'],
Expand Down
2 changes: 2 additions & 0 deletions lib/AppInfo/Application.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use OCA\Assistant\Capabilities;
use OCA\Assistant\Listener\BeforeTemplateRenderedListener;
use OCA\Assistant\Listener\ChattyLLMTaskListener;
use OCA\Assistant\Listener\CSPListener;
use OCA\Assistant\Listener\FreePrompt\FreePromptReferenceListener;
use OCA\Assistant\Listener\SpeechToText\SpeechToTextReferenceListener;
Expand Down Expand Up @@ -55,6 +56,7 @@ public function register(IRegistrationContext $context): void {

$context->registerEventListener(TaskSuccessfulEvent::class, TaskSuccessfulListener::class);
$context->registerEventListener(TaskFailedEvent::class, TaskFailedListener::class);
$context->registerEventListener(TaskSuccessfulEvent::class, ChattyLLMTaskListener::class);

$context->registerNotifierService(Notifier::class);

Expand Down
76 changes: 71 additions & 5 deletions lib/Controller/ChattyLLMController.php
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,11 @@ public function generateForSession(int $sessionId): JSONResponse {
. PHP_EOL
. 'assistant: ';

$taskId = $this->scheduleLLMTask($stichedPrompt);
try {
$taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId);
} catch (\Exception $e) {
return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST);
}

return new JSONResponse(['taskId' => $taskId]);
}
Expand Down Expand Up @@ -374,7 +378,7 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes
$message->setRole('assistant');
$message->setContent(trim($task->getOutput()['output'] ?? ''));
$message->setTimestamp(time());
$this->messageMapper->insert($message);
// do not insert here, it is done by the listener
return new JSONResponse($message);
} catch (\OCP\DB\Exception $e) {
$this->logger->warning('Failed to add a chat message into DB', ['exception' => $e]);
Expand All @@ -388,6 +392,46 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes
return new JSONResponse(['error' => 'unknown_error', 'task_status' => $task->getstatus()], Http::STATUS_BAD_REQUEST);
}

/**
* Check the status of a session
*
* Used by the frontend to determine if it should poll a generation task status.
*
* @param int $sessionId
* @return JSONResponse
* @throws \JsonException
* @throws \OCP\DB\Exception
*/
#[NoAdminRequired]
public function checkMessageGenerationSession(int $sessionId): JSONResponse {
if ($this->userId === null) {
return new JSONResponse(['error' => $this->l10n->t('User not logged in')], Http::STATUS_UNAUTHORIZED);
}

$sessionExists = $this->sessionMapper->exists($this->userId, $sessionId);
if (!$sessionExists) {
return new JSONResponse(['error' => $this->l10n->t('Session not found')], Http::STATUS_NOT_FOUND);
}

try {
$tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId);
} catch (\OCP\TaskProcessing\Exception\Exception $e) {
return new JSONResponse(['error' => 'task_query_failed'], Http::STATUS_BAD_REQUEST);
}
$tasks = array_filter($tasks, static function (Task $task) {
return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED;
});
if (empty($tasks)) {
return new JSONResponse([
'taskId' => null,
]);
}
$task = array_pop($tasks);
return new JSONResponse([
'taskId' => $task->getId(),
]);
}

/**
* Schedule a task to generate a title for the chat session
*
Expand Down Expand Up @@ -430,7 +474,11 @@ public function generateTitle(int $sessionId): JSONResponse {
. PHP_EOL . PHP_EOL
. $userInstructions;

$taskId = $this->scheduleLLMTask($stichedPrompt);
try {
$taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId, false);
} catch (\Exception $e) {
return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST);
}
return new JSONResponse(['taskId' => $taskId]);
} catch (\OCP\DB\Exception $e) {
$this->logger->warning('Failed to generate a title for the chat session', ['exception' => $e]);
Expand Down Expand Up @@ -525,14 +573,32 @@ private function getStichedMessages(int $sessionId): string {
* Schedule the LLM task
*
* @param string $content
* @param int $sessionId
* @param bool $isMessage
* @return int|null
* @throws Exception
* @throws PreConditionNotMetException
* @throws UnauthorizedException
* @throws ValidationException
* @throws \JsonException
*/
private function scheduleLLMTask(string $content): ?int {
$task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId);
private function scheduleLLMTask(string $content, int $sessionId, bool $isMessage = true): ?int {
$customId = ($isMessage
? 'chatty-llm:'
: 'chatty-title:') . $sessionId;
try {
$tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', $customId);
} catch (\OCP\TaskProcessing\Exception\Exception $e) {
throw new \Exception('task_query_failed');
}
$tasks = array_filter($tasks, static function (Task $task) {
return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED;
});
// prevent scheduling multiple llm tasks simultaneously for one session
if (!empty($tasks)) {
throw new \Exception('session_already_thinking');
}
$task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId);
$this->taskProcessingManager->scheduleTask($task);
return $task->getId();
}
Expand Down
53 changes: 53 additions & 0 deletions lib/Listener/ChattyLLMTaskListener.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
<?php

declare(strict_types=1);

namespace OCA\Assistant\Listener;

use OCA\Assistant\AppInfo\Application;
use OCA\Assistant\Db\ChattyLLM\Message;
use OCA\Assistant\Db\ChattyLLM\MessageMapper;
use OCP\EventDispatcher\Event;
use OCP\EventDispatcher\IEventListener;
use OCP\TaskProcessing\Events\TaskSuccessfulEvent;
use Psr\Log\LoggerInterface;

/**
* @template-implements IEventListener<TaskSuccessfulEvent>
*/
class ChattyLLMTaskListener implements IEventListener {

public function __construct(
private MessageMapper $messageMapper,
private LoggerInterface $logger,
) {
}

public function handle(Event $event): void {
if (!($event instanceof TaskSuccessfulEvent)) {
return;
}

$task = $event->getTask();
$customId = $task->getCustomId();
$appId = $task->getAppId();
if ($appId !== (Application::APP_ID . ':chatty-llm')
|| $customId === null
|| !preg_match('/^chatty-llm:(\d+)$/', $customId, $matches)
) {
return;
}
$sessionId = (int)$matches[1];

$message = new Message();
$message->setSessionId($sessionId);
$message->setRole('assistant');
$message->setContent(trim($task->getOutput()['output'] ?? ''));
$message->setTimestamp(time());
try {
$this->messageMapper->insert($message);
} catch (\OCP\DB\Exception $e) {
$this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]);
}
}
}
63 changes: 49 additions & 14 deletions src/components/ChattyLLM/ChattyLLMInputForm.vue
Original file line number Diff line number Diff line change
Expand Up @@ -219,20 +219,45 @@ export default {
watch: {
async active() {
this.loading.llmGeneration = false
this.allMessagesLoaded = false
this.chatContent = ''
this.msgCursor = 0
this.messages = []
this.editingTitle = false
this.$refs.inputComponent.focus()
if (this.active != null && !this.loading.newSession) {
if (this.active !== null && !this.loading.newSession) {
await this.fetchMessages()
this.scrollToBottom()
} else {
// when no active session or creating a new session
this.allMessagesLoaded = true
this.loading.newSession = false
return
}
// start polling in case a message is currently being generated
try {
const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId: this.active.id } })
console.debug('check session response:', checkSessionResponse)
if (checkSessionResponse.data.taskId === null) {
return
}
try {
this.loading.llmGeneration = true
const message = await this.pollGenerationTask(checkSessionResponse.data.taskId, this.active.id)
console.debug('checkTaskPolling result:', message)
this.messages.push(message)
this.scrollToBottom()
} catch (error) {
console.error('checkGenerationTask error:', error)
showError(t('assistant', 'Error generating a response'))
} finally {
this.loading.llmGeneration = false
}
} catch (error) {
console.error('check session error:', error)
showError(t('assistant', 'Error checking if the session is thinking'))
}
},
},
Expand Down Expand Up @@ -328,7 +353,7 @@ export default {
this.messages.push({ role, content, timestamp })
this.chatContent = ''
this.scrollToBottom()
await this.newMessage(role, content, timestamp)
await this.newMessage(role, content, timestamp, this.active.id)
},
onLoadOlderMessages() {
Expand Down Expand Up @@ -466,13 +491,13 @@ export default {
}
},
async newMessage(role, content, timestamp) {
async newMessage(role, content, timestamp, sessionId) {
try {
this.loading.newHumanMessage = true
const firstHumanMessage = this.messages.length === 1 && this.messages[0].role === Roles.HUMAN
const response = await axios.put(getChatURL('/new_message'), {
sessionId: this.active.id,
sessionId,
role,
content,
timestamp,
Expand All @@ -485,11 +510,11 @@ export default {
this.messages[this.messages.length - 1] = response.data
if (firstHumanMessage) {
const session = this.sessions.find((session) => session.id === this.active.id)
const session = this.sessions.find((session) => session.id === sessionId)
session.title = content
}
await this.runGenerationTask()
await this.runGenerationTask(sessionId)
} catch (error) {
this.loading.newHumanMessage = false
console.error('newMessage error:', error)
Expand Down Expand Up @@ -521,12 +546,12 @@ export default {
}
},
async runGenerationTask() {
async runGenerationTask(sessionId) {
try {
this.loading.llmGeneration = true
const response = await axios.get(getChatURL('/generate'), { params: { sessionId: this.active.id } })
const response = await axios.get(getChatURL('/generate'), { params: { sessionId } })
console.debug('scheduleGenerationTask response:', response)
const message = await this.pollGenerationTask(response.data.taskId)
const message = await this.pollGenerationTask(response.data.taskId, sessionId)
console.debug('checkTaskPolling result:', message)
this.messages.push(message)
this.scrollToBottom()
Expand All @@ -540,10 +565,11 @@ export default {
async runRegenerationTask(messageId) {
try {
const sessionId = this.active.id
this.loading.llmGeneration = true
const response = await axios.get(getChatURL('/regenerate'), { params: { messageId, sessionId: this.active.id } })
const response = await axios.get(getChatURL('/regenerate'), { params: { messageId, sessionId } })
console.debug('scheduleRegenerationTask response:', response)
const message = await this.pollGenerationTask(response.data.taskId)
const message = await this.pollGenerationTask(response.data.taskId, sessionId)
console.debug('checkTaskPolling result:', message)
this.messages[this.messages.length - 1] = message
this.scrollToBottom()
Expand All @@ -555,16 +581,25 @@ export default {
}
},
async pollGenerationTask(taskId) {
async pollGenerationTask(taskId, sessionId) {
return new Promise((resolve, reject) => {
this.pollMessageGenerationTimerId = setInterval(() => {
axios.get(
getChatURL('/check_generation'),
{ params: { taskId, sessionId: this.active.id } },
{ params: { taskId, sessionId } },
).then(response => {
clearInterval(this.pollMessageGenerationTimerId)
resolve(response.data)
if (sessionId === this.active.id) {
resolve(response.data)
} else {
console.debug('Ignoring received a message for session ' + sessionId + ' that is not selected anymore')
// should we reject here?
}
}).catch(error => {
if (sessionId !== this.active.id) {
console.debug('Stop polling session ' + sessionId + ' because it is not selected anymore')
clearInterval(this.pollMessageGenerationTimerId)
}
// do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417)
if (error.response?.status !== 417) {
console.error('checkTaskPolling error', error)
Expand Down
4 changes: 2 additions & 2 deletions tests/psalm-baseline.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<?xml version="1.0" encoding="UTF-8"?>
<files psalm-version="5.23.1@8471a896ccea3526b26d082f4461eeea467f10a4">
<files psalm-version="5.25.0@01a8eb06b9e9cc6cfb6a320bf9fb14331919d505">
<file src="lib/Controller/AssistantController.php">
<TooManyArguments>
<code><![CDATA[new Task(
Expand All @@ -13,7 +13,7 @@
</file>
<file src="lib/Controller/ChattyLLMController.php">
<TooManyArguments>
<code><![CDATA[new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId)]]></code>
<code><![CDATA[new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId)]]></code>
</TooManyArguments>
</file>
<file src="lib/Service/AssistantService.php">
Expand Down

0 comments on commit 06201f6

Please sign in to comment.