Skip to content

Commit

Permalink
adjust title generation using the same logic as message generation
Browse files Browse the repository at this point in the history
Signed-off-by: Julien Veyssier <[email protected]>
  • Loading branch information
julien-nc committed Nov 20, 2024
1 parent 06201f6 commit 829cc57
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 60 deletions.
2 changes: 1 addition & 1 deletion appinfo/routes.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
['name' => 'chattyLLM#getMessages', 'url' => '/chat/messages', 'verb' => 'GET'],
['name' => 'chattyLLM#generateForSession', 'url' => '/chat/generate', 'verb' => 'GET'],
['name' => 'chattyLLM#regenerateForSession', 'url' => '/chat/regenerate', 'verb' => 'GET'],
['name' => 'chattyLLM#checkMessageGenerationSession', 'url' => '/chat/check_session', 'verb' => 'GET'],
['name' => 'chattyLLM#checkSession', 'url' => '/chat/check_session', 'verb' => 'GET'],
['name' => 'chattyLLM#checkMessageGenerationTask', 'url' => '/chat/check_generation', 'verb' => 'GET'],
['name' => 'chattyLLM#generateTitle', 'url' => '/chat/generate_title', 'verb' => 'GET'],
['name' => 'chattyLLM#checkTitleGenerationTask', 'url' => '/chat/check_title_generation', 'verb' => 'GET'],
Expand Down
37 changes: 23 additions & 14 deletions lib/Controller/ChattyLLMController.php
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes
* @throws \OCP\DB\Exception
*/
#[NoAdminRequired]
public function checkMessageGenerationSession(int $sessionId): JSONResponse {
public function checkSession(int $sessionId): JSONResponse {
if ($this->userId === null) {
return new JSONResponse(['error' => $this->l10n->t('User not logged in')], Http::STATUS_UNAUTHORIZED);
}
Expand All @@ -414,22 +414,32 @@ public function checkMessageGenerationSession(int $sessionId): JSONResponse {
}

try {
$tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId);
$messageTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId);
$titleTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-title:' . $sessionId);
} catch (\OCP\TaskProcessing\Exception\Exception $e) {
return new JSONResponse(['error' => 'task_query_failed'], Http::STATUS_BAD_REQUEST);
}
$tasks = array_filter($tasks, static function (Task $task) {
$messageTasks = array_filter($messageTasks, static function (Task $task) {
return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED;
});
if (empty($tasks)) {
return new JSONResponse([
'taskId' => null,
]);
}
$task = array_pop($tasks);
return new JSONResponse([
'taskId' => $task->getId(),
]);
$titleTasks = array_filter($titleTasks, static function (Task $task) {
return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED;
});
$session = $this->sessionMapper->getUserSession($this->userId, $sessionId);
$responseData = [
'messageTaskId' => null,
'titleTaskId' => null,
'sessionTitle' => $session->getTitle(),
];
if (!empty($messageTasks)) {
$task = array_pop($messageTasks);
$responseData['messageTaskId'] = $task->getId();
}
if (!empty($titleTasks)) {
$task = array_pop($titleTasks);
$responseData['titleTaskId'] = $task->getId();
}
return new JSONResponse($responseData);
}

/**
Expand Down Expand Up @@ -523,8 +533,7 @@ public function checkTitleGenerationTask(int $taskId, int $sessionId): JSONRespo
$title = str_replace('"', '', $title);
$title = explode(PHP_EOL, $title)[0];
$title = trim($title);

$this->sessionMapper->updateSessionTitle($this->userId, $sessionId, $title);
// do not write the title here since it's done in the listener

return new JSONResponse(['result' => $title]);
} catch (\OCP\DB\Exception $e) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Db/ChattyLLM/Session.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
/**
* @method \string getUserId()
* @method \void setUserId(string $userId)
* @method \?string getTitle()
* @method \string|null getTitle()
* @method \void setTitle(?string $title)
* @method \int|null getTimestamp()
* @method \void setTimestamp(?int $timestamp)
Expand Down
21 changes: 21 additions & 0 deletions lib/Db/ChattyLLM/SessionMapper.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@

namespace OCA\Assistant\Db\ChattyLLM;

use OCP\AppFramework\Db\DoesNotExistException;
use OCP\AppFramework\Db\MultipleObjectsReturnedException;
use OCP\AppFramework\Db\QBMapper;
use OCP\DB\Exception;
use OCP\DB\QueryBuilder\IQueryBuilder;
use OCP\IDBConnection;

Expand Down Expand Up @@ -59,6 +62,24 @@ public function exists(string $userId, int $sessionId): bool {
}
}

/**
* @param string $userId
* @param int $sessionId
* @return Session
* @throws DoesNotExistException
* @throws MultipleObjectsReturnedException
* @throws Exception
*/
public function getUserSession(string $userId, int $sessionId): Session {
$qb = $this->db->getQueryBuilder();
$qb->select('id', 'title', 'timestamp')
->from($this->getTableName())
->where($qb->expr()->eq('id', $qb->createPositionalParameter($sessionId, IQueryBuilder::PARAM_INT)))
->andWhere($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId, IQueryBuilder::PARAM_STR)));

return $this->findEntity($qb);
}

/**
* @param string $userId
* @return array
Expand Down
41 changes: 26 additions & 15 deletions lib/Listener/ChattyLLMTaskListener.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use OCA\Assistant\AppInfo\Application;
use OCA\Assistant\Db\ChattyLLM\Message;
use OCA\Assistant\Db\ChattyLLM\MessageMapper;
use OCA\Assistant\Db\ChattyLLM\SessionMapper;
use OCP\EventDispatcher\Event;
use OCP\EventDispatcher\IEventListener;
use OCP\TaskProcessing\Events\TaskSuccessfulEvent;
Expand All @@ -19,6 +20,7 @@ class ChattyLLMTaskListener implements IEventListener {

public function __construct(
private MessageMapper $messageMapper,
private SessionMapper $sessionMapper,
private LoggerInterface $logger,
) {
}
Expand All @@ -31,23 +33,32 @@ public function handle(Event $event): void {
$task = $event->getTask();
$customId = $task->getCustomId();
$appId = $task->getAppId();
if ($appId !== (Application::APP_ID . ':chatty-llm')
|| $customId === null
|| !preg_match('/^chatty-llm:(\d+)$/', $customId, $matches)
) {

if ($customId === null || $appId !== (Application::APP_ID . ':chatty-llm')) {
return;
}
$sessionId = (int)$matches[1];

$message = new Message();
$message->setSessionId($sessionId);
$message->setRole('assistant');
$message->setContent(trim($task->getOutput()['output'] ?? ''));
$message->setTimestamp(time());
try {
$this->messageMapper->insert($message);
} catch (\OCP\DB\Exception $e) {
$this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]);

// title generation
if (preg_match('/^chatty-title:(\d+)$/', $customId, $matches)) {
$sessionId = (int)$matches[1];
$title = trim($task->getOutput()['output'] ?? '');
$this->sessionMapper->updateSessionTitle($task->getUserId(), $sessionId, $title);
}

// message generation
if (preg_match('/^chatty-llm:(\d+)$/', $customId, $matches)) {
$sessionId = (int)$matches[1];

$message = new Message();
$message->setSessionId($sessionId);
$message->setRole('assistant');
$message->setContent(trim($task->getOutput()['output'] ?? ''));
$message->setTimestamp(time());
try {
$this->messageMapper->insert($message);
} catch (\OCP\DB\Exception $e) {
$this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]);
}
}
}
}
90 changes: 61 additions & 29 deletions src/components/ChattyLLM/ChattyLLMInputForm.vue
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ export default {
watch: {
async active() {
this.loading.llmGeneration = false
this.loading.titleGeneration = false
this.allMessagesLoaded = false
this.chatContent = ''
this.msgCursor = 0
Expand All @@ -238,22 +239,45 @@ export default {
}
// start polling in case a message is currently being generated
try {
const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId: this.active.id } })
console.debug('check session response:', checkSessionResponse)
if (checkSessionResponse.data.taskId === null) {
return
const sessionId = this.active.id
const checkSessionResponse = await axios.get(getChatURL('/check_session'), { params: { sessionId } })
if (checkSessionResponse.data?.sessionTitle && checkSessionResponse.data?.sessionTitle !== this.active.title) {
this.active.title = checkSessionResponse.data?.sessionTitle
console.debug('update session title with check result')
}
try {
this.loading.llmGeneration = true
const message = await this.pollGenerationTask(checkSessionResponse.data.taskId, this.active.id)
console.debug('checkTaskPolling result:', message)
this.messages.push(message)
this.scrollToBottom()
} catch (error) {
console.error('checkGenerationTask error:', error)
showError(t('assistant', 'Error generating a response'))
} finally {
this.loading.llmGeneration = false
console.debug('check session response:', checkSessionResponse.data)
if (checkSessionResponse.data.messageTaskId !== null) {
try {
this.loading.llmGeneration = true
const message = await this.pollGenerationTask(checkSessionResponse.data.messageTaskId, sessionId)
console.debug('checkTaskPolling result:', message)
this.messages.push(message)
this.scrollToBottom()
} catch (error) {
console.error('checkGenerationTask error:', error)
showError(t('assistant', 'Error generating a response'))
} finally {
this.loading.llmGeneration = false
}
} else if (checkSessionResponse.data.titleTaskId !== null) {
try {
this.loading.titleGeneration = true
const titleResponse = await this.pollTitleGenerationTask(checkSessionResponse.data.titleTaskId, sessionId)
console.debug('checkTaskPolling result:', titleResponse)
if (titleResponse?.data?.result == null) {
throw new Error('No title generated, response:', titleResponse)
}
const session = this.sessions.find(s => s.id === sessionId)
if (session) {
session.title = titleResponse?.data?.result
}
} catch (error) {
console.error('onCheckSessionTitle error:', error)
showError(error?.response?.data?.error ?? t('assistant', 'Error getting the generated title for the conversation'))
} finally {
this.loading.titleGeneration = false
}
}
} catch (error) {
console.error('check session error:', error)
Expand Down Expand Up @@ -367,18 +391,17 @@ export default {
async onGenerateSessionTitle() {
try {
this.loading.titleGeneration = true
const response = await axios.get(getChatURL('/generate_title'), { params: { sessionId: this.active.id } })
const titleResponse = await this.pollTitleGenerationTask(response.data.taskId)
const sessionId = this.active.id
const response = await axios.get(getChatURL('/generate_title'), { params: { sessionId } })
const titleResponse = await this.pollTitleGenerationTask(response.data.taskId, sessionId)
console.debug('checkTaskPolling result:', titleResponse)
if (titleResponse?.data?.result == null) {
throw new Error('No title generated, response:', response)
}
for (const session of this.sessions) {
if (session.id === this.active.id) {
session.title = titleResponse?.data?.result
break
}
const session = this.sessions.find(s => s.id === sessionId)
if (session) {
session.title = titleResponse?.data?.result
}
} catch (error) {
console.error('onGenerateSessionTitle error:', error)
Expand Down Expand Up @@ -592,12 +615,12 @@ export default {
if (sessionId === this.active.id) {
resolve(response.data)
} else {
console.debug('Ignoring received a message for session ' + sessionId + ' that is not selected anymore')
console.debug('Ignoring received message for session ' + sessionId + ' that is not selected anymore')
// should we reject here?
}
}).catch(error => {
if (sessionId !== this.active.id) {
console.debug('Stop polling session ' + sessionId + ' because it is not selected anymore')
console.debug('Stop polling messages for session ' + sessionId + ' because it is not selected anymore')
clearInterval(this.pollMessageGenerationTimerId)
}
// do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417)
Expand All @@ -606,30 +629,39 @@ export default {
clearInterval(this.pollMessageGenerationTimerId)
reject(new Error('Message generation task check failed'))
} else {
console.debug('checkTaskPolling, task is still scheduled or running', error)
console.debug('checkTaskPolling, task is still scheduled or running')
}
})
}, 2000)
})
},
async pollTitleGenerationTask(taskId) {
async pollTitleGenerationTask(taskId, sessionId) {
return new Promise((resolve, reject) => {
this.pollTitleGenerationTimerId = setInterval(() => {
axios.get(
getChatURL('/check_title_generation'),
{ params: { taskId, sessionId: this.active.id } },
{ params: { taskId, sessionId } },
).then(response => {
if (sessionId === this.active.id) {
resolve(response)
} else {
console.debug('Ignoring received title for session ' + sessionId + ' that is not selected anymore')
// should we reject here?
}
clearInterval(this.pollTitleGenerationTimerId)
resolve(response)
}).catch(error => {
if (sessionId !== this.active.id) {
console.debug('Stop polling title for session ' + sessionId + ' because it is not selected anymore')
clearInterval(this.pollTitleGenerationTimerId)
}
// do not reject if response code is Http::STATUS_EXPECTATION_FAILED (417)
if (error.response?.status !== 417) {
console.error('checkTaskPolling error', error)
clearInterval(this.pollTitleGenerationTimerId)
reject(new Error('Title generation task check failed'))
} else {
console.debug('checkTaskPolling, task is still scheduled or running', error)
console.debug('checkTaskPolling, task is still scheduled or running')
}
})
}, 2000)
Expand Down

0 comments on commit 829cc57

Please sign in to comment.