diff --git a/appinfo/routes.php b/appinfo/routes.php index c5319f4..9bca5a8 100644 --- a/appinfo/routes.php +++ b/appinfo/routes.php @@ -13,5 +13,8 @@ 'routes' => [ ['name' => 'config#setConfig', 'url' => '/config', 'verb' => 'PUT'], ['name' => 'config#setAdminConfig', 'url' => '/admin-config', 'verb' => 'PUT'], + + ['name' => 'provider#getProviders', 'url' => '/providers', 'verb' => 'GET'], + ['name' => 'provider#getDefaultProviderKey', 'url' => '/default-provider-key', 'verb' => 'GET'], ], ]; diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index b547e86..44a17f9 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -11,9 +11,9 @@ use OCA\ContextChat\Listener\AppDisableListener; use OCA\ContextChat\Listener\FileListener; +use OCA\ContextChat\Service\ProviderService; use OCA\ContextChat\TextProcessing\ContextChatProvider; use OCA\ContextChat\TextProcessing\FreePromptProvider; -use OCA\ContextChat\TextProcessing\ScopedContextChatProvider; use OCP\App\Events\AppDisableEvent; use OCP\AppFramework\App; use OCP\AppFramework\Bootstrap\IBootContext; @@ -58,8 +58,13 @@ class Application extends App implements IBootstrap { 'text/org', ]; + private ProviderService $providerService; + public function __construct(array $urlParams = []) { parent::__construct(self::APP_ID, $urlParams); + + $container = $this->getContainer(); + $this->providerService = $container->get(ProviderService::class); } public function register(IRegistrationContext $context): void { @@ -73,7 +78,9 @@ public function register(IRegistrationContext $context): void { $context->registerEventListener(AppDisableEvent::class, AppDisableListener::class); $context->registerTextProcessingProvider(ContextChatProvider::class); $context->registerTextProcessingProvider(FreePromptProvider::class); - $context->registerTextProcessingProvider(ScopedContextChatProvider::class); + + /** @psalm-suppress ArgumentTypeCoercion, UndefinedClass */ + $this->providerService->updateProvider('files', 'default', '', true); } public function boot(IBootContext $context): void { diff --git a/lib/BackgroundJobs/IndexerJob.php b/lib/BackgroundJobs/IndexerJob.php index c3f0d0b..2b15e0a 100644 --- a/lib/BackgroundJobs/IndexerJob.php +++ b/lib/BackgroundJobs/IndexerJob.php @@ -9,6 +9,7 @@ use OCA\ContextChat\Db\QueueFile; use OCA\ContextChat\Service\LangRopeService; +use OCA\ContextChat\Service\ProviderService; use OCA\ContextChat\Service\QueueService; use OCA\ContextChat\Service\StorageService; use OCA\ContextChat\Type\Source; @@ -125,7 +126,15 @@ protected function index(array $files): void { $userIds = $this->storageService->getUsersForFileId($queueFile->getFileId()); foreach ($userIds as $userId) { try { - $source = new Source($userId, 'file: ' . $file->getId(), $file->getPath(), $fileHandle, $file->getMtime(), $file->getMimeType(), 'file'); + $source = new Source( + $userId, + ProviderService::getSourceId($file->getId()), + $file->getPath(), + $fileHandle, + $file->getMtime(), + $file->getMimeType(), + ProviderService::getDefaultProviderKey(), + ); } catch (InvalidPathException|NotFoundException $e) { $this->logger->error('Could not find file ' . $file->getPath(), ['exception' => $e]); continue 2; diff --git a/lib/BackgroundJobs/InitialContentImportJob.php b/lib/BackgroundJobs/InitialContentImportJob.php index 50b7f92..e7c6cf4 100644 --- a/lib/BackgroundJobs/InitialContentImportJob.php +++ b/lib/BackgroundJobs/InitialContentImportJob.php @@ -14,7 +14,7 @@ namespace OCA\ContextChat\BackgroundJobs; use OCA\ContextChat\Public\IContentProvider; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCP\App\IAppManager; use OCP\AppFramework\Utility\ITimeFactory; use OCP\BackgroundJob\QueuedJob; @@ -27,7 +27,7 @@ class InitialContentImportJob extends QueuedJob { public function __construct( private IAppManager $appManager, - private ProviderConfigService $configService, + private ProviderService $providerService, private LoggerInterface $logger, private IUserManager $userMan, ITimeFactory $timeFactory, @@ -57,8 +57,8 @@ protected function run($argument): void { return; } - $registeredProviders = $this->configService->getProviders(); - $identifier = ProviderConfigService::getConfigKey($providerObj->getAppId(), $providerObj->getId()); + $registeredProviders = $this->providerService->getProviders(); + $identifier = $this->providerService->getConfigKey($providerObj->getAppId(), $providerObj->getId()); if (!isset($registeredProviders[$identifier]) || $registeredProviders[$identifier]['isInitiated'] ) { @@ -66,6 +66,6 @@ protected function run($argument): void { } $providerObj->triggerInitialImport(); - $this->configService->updateProvider($providerObj->getAppId(), $providerObj->getId(), $argument, true); + $this->providerService->updateProvider($providerObj->getAppId(), $providerObj->getId(), $argument, true); } } diff --git a/lib/BackgroundJobs/SubmitContentJob.php b/lib/BackgroundJobs/SubmitContentJob.php index 8916199..cf2533e 100644 --- a/lib/BackgroundJobs/SubmitContentJob.php +++ b/lib/BackgroundJobs/SubmitContentJob.php @@ -16,7 +16,7 @@ use OCA\ContextChat\Db\QueueContentItem; use OCA\ContextChat\Db\QueueContentItemMapper; use OCA\ContextChat\Service\LangRopeService; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCA\ContextChat\Type\Source; use OCP\AppFramework\Utility\ITimeFactory; use OCP\BackgroundJob\IJobList; @@ -58,8 +58,8 @@ protected function run($argument): void { foreach ($bucketed as $userId => $entities) { $sources = array_map(function (QueueContentItem $item) use ($userId) { - $providerKey = ProviderConfigService::getConfigKey($item->getAppId(), $item->getProviderId()); - $sourceId = $providerKey . ': ' . $item->getItemId(); + $providerKey = ProviderService::getConfigKey($item->getAppId(), $item->getProviderId()); + $sourceId = ProviderService::getSourceId($item->getItemId(), $providerKey); return new Source( $userId, $sourceId, diff --git a/lib/Command/Prompt.php b/lib/Command/Prompt.php index 6924fc9..109bf82 100644 --- a/lib/Command/Prompt.php +++ b/lib/Command/Prompt.php @@ -13,7 +13,6 @@ namespace OCA\ContextChat\Command; use OCA\ContextChat\TextProcessing\ContextChatTaskType; -use OCA\ContextChat\TextProcessing\ScopedContextChatTaskType; use OCA\ContextChat\Type\ScopeType; use OCP\TextProcessing\FreePromptTaskType; use OCP\TextProcessing\IManager; @@ -80,26 +79,30 @@ protected function execute(InputInterface $input, OutputInterface $output) { throw new \InvalidArgumentException('Cannot use --context-sources with --context-provider'); } - if ($noContext) { - $task = new Task(FreePromptTaskType::class, $prompt, 'context_chat', $userId); - } elseif (!empty($contextSources)) { - $contextSources = preg_replace('/\s*,+\s*/', ',', $contextSources); - $contextSourcesArray = array_filter(explode(',', $contextSources), fn ($source) => !empty($source)); - $task = new Task(ScopedContextChatTaskType::class, json_encode([ - 'scopeType' => ScopeType::SOURCE, - 'scopeList' => $contextSourcesArray, - 'prompt' => $prompt, - ]), 'context_chat', $userId); - } elseif (!empty($contextProviders)) { - $contextProviders = preg_replace('/\s*,+\s*/', ',', $contextProviders); - $contextProvidersArray = array_filter(explode(',', $contextProviders), fn ($source) => !empty($source)); - $task = new Task(ScopedContextChatTaskType::class, json_encode([ - 'scopeType' => ScopeType::PROVIDER, - 'scopeList' => $contextProvidersArray, - 'prompt' => $prompt, - ]), 'context_chat', $userId); - } else { - $task = new Task(ContextChatTaskType::class, $prompt, 'context_chat', $userId); + try { + if ($noContext) { + $task = new Task(FreePromptTaskType::class, $prompt, 'context_chat', $userId); + } elseif (!empty($contextSources)) { + $contextSources = preg_replace('/\s*,+\s*/', ',', $contextSources); + $contextSourcesArray = array_filter(explode(',', $contextSources), fn ($source) => !empty($source)); + $task = new Task(ContextChatTaskType::class, json_encode([ + 'scopeType' => ScopeType::SOURCE, + 'scopeList' => $contextSourcesArray, + 'prompt' => $prompt, + ], JSON_THROW_ON_ERROR), 'context_chat', $userId); + } elseif (!empty($contextProviders)) { + $contextProviders = preg_replace('/\s*,+\s*/', ',', $contextProviders); + $contextProvidersArray = array_filter(explode(',', $contextProviders), fn ($source) => !empty($source)); + $task = new Task(ContextChatTaskType::class, json_encode([ + 'scopeType' => ScopeType::PROVIDER, + 'scopeList' => $contextProvidersArray, + 'prompt' => $prompt, + ], JSON_THROW_ON_ERROR), 'context_chat', $userId); + } else { + $task = new Task(ContextChatTaskType::class, json_encode([ 'prompt' => $prompt ], JSON_THROW_ON_ERROR), 'context_chat', $userId); + } + } catch (\JsonException $e) { + throw new \InvalidArgumentException('Invalid input, cannot encode JSON', intval($e->getCode()), $e); } $this->textProcessingManager->runTask($task); diff --git a/lib/Command/ScanFiles.php b/lib/Command/ScanFiles.php index 7c6b6ed..e972410 100644 --- a/lib/Command/ScanFiles.php +++ b/lib/Command/ScanFiles.php @@ -44,6 +44,11 @@ protected function execute(InputInterface $input, OutputInterface $output) { ? explode(',', $input->getOption('mimetype')) : Application::MIMETYPES; + if ($mimeTypeFilter === false) { + $output->writeln('Invalid mime type filter'); + return 1; + } + $userId = $input->getArgument('user_id'); $scan = $this->scanService->scanUserFiles($userId, $mimeTypeFilter); foreach ($scan as $s) { diff --git a/lib/Controller/ProviderController.php b/lib/Controller/ProviderController.php new file mode 100644 index 0000000..2600a9a --- /dev/null +++ b/lib/Controller/ProviderController.php @@ -0,0 +1,58 @@ + + * + * @author Anupam Kumar + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +namespace OCA\ContextChat\Controller; + +use OCA\ContextChat\Service\ProviderService; +use OCP\AppFramework\Controller; +use OCP\AppFramework\Http\Attribute\NoAdminRequired; +use OCP\AppFramework\Http\DataResponse; +use OCP\IRequest; + +class ProviderController extends Controller { + + public function __construct( + string $appName, + IRequest $request, + private ProviderService $providerService, + ) { + parent::__construct($appName, $request); + } + + /** + * @return DataResponse + */ + #[NoAdminRequired] + public function getDefaultProviderKey(): DataResponse { + $providerKey = $this->providerService->getDefaultProviderKey(); + return new DataResponse($providerKey); + } + + /** + * @return DataResponse + */ + #[NoAdminRequired] + public function getProviders(): DataResponse { + $providers = $this->providerService->getEnrichedProviders(); + return new DataResponse($providers); + } +} diff --git a/lib/Listener/AppDisableListener.php b/lib/Listener/AppDisableListener.php index c94c9a8..21184d8 100644 --- a/lib/Listener/AppDisableListener.php +++ b/lib/Listener/AppDisableListener.php @@ -14,7 +14,7 @@ namespace OCA\ContextChat\Listener; use OCA\ContextChat\Service\LangRopeService; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCP\App\Events\AppDisableEvent; use OCP\EventDispatcher\Event; use OCP\EventDispatcher\IEventListener; @@ -25,7 +25,7 @@ */ class AppDisableListener implements IEventListener { public function __construct( - private ProviderConfigService $configService, + private ProviderService $providerService, private LangRopeService $service, private LoggerInterface $logger, ) { @@ -36,7 +36,7 @@ public function handle(Event $event): void { return; } - foreach ($this->configService->getProviders() as $key => $values) { + foreach ($this->providerService->getProviders() as $key => $values) { /** @var string[] */ $identifierValues = explode('__', $key, 2); @@ -51,7 +51,7 @@ public function handle(Event $event): void { continue; } - $this->configService->removeProvider($appId, $providerId); + $this->providerService->removeProvider($appId, $providerId); $this->service->deleteSourcesByProviderForAllUsers($providerId); } } diff --git a/lib/Listener/FileListener.php b/lib/Listener/FileListener.php index 26c39e2..21979ff 100644 --- a/lib/Listener/FileListener.php +++ b/lib/Listener/FileListener.php @@ -10,6 +10,7 @@ use OCA\ContextChat\AppInfo\Application; use OCA\ContextChat\Db\QueueFile; use OCA\ContextChat\Service\LangRopeService; +use OCA\ContextChat\Service\ProviderService; use OCA\ContextChat\Service\QueueService; use OCA\ContextChat\Service\StorageService; use OCP\DB\Exception; @@ -120,7 +121,7 @@ public function handle(Event $event): void { if (!$node instanceof File) { continue; } - $fileRefs[] = 'file: ' . $node->getId(); + $fileRefs[] = ProviderService::getSourceId($node->getId()); } $this->langRopeService->deleteSources($userId, $fileRefs); @@ -129,7 +130,7 @@ public function handle(Event $event): void { return; } - $fileRef = 'file: ' . $node->getId(); + $fileRef = ProviderService::getSourceId($node->getId()); foreach ($userIds as $userId) { $this->langRopeService->deleteSources($userId, [$fileRef]); } @@ -194,7 +195,7 @@ public function postDelete(Node $node, bool $recurse = true): void { } foreach ($this->storageService->getUsersForFileId($node->getId()) as $userId) { - $fileRef = 'file: ' . $node->getId(); + $fileRef = ProviderService::getSourceId($node->getId()); $this->langRopeService->deleteSources($userId, [$fileRef]); } } diff --git a/lib/Public/ContentManager.php b/lib/Public/ContentManager.php index ca02c78..1023ec3 100644 --- a/lib/Public/ContentManager.php +++ b/lib/Public/ContentManager.php @@ -16,7 +16,7 @@ use OCA\ContextChat\Db\QueueContentItem; use OCA\ContextChat\Db\QueueContentItemMapper; use OCA\ContextChat\Service\LangRopeService; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCP\BackgroundJob\IJobList; use OCP\Server; use Psr\Container\ContainerExceptionInterface; @@ -26,7 +26,7 @@ class ContentManager { public function __construct( private IJobList $jobList, - private ProviderConfigService $configService, + private ProviderService $providerService, private LangRopeService $service, private QueueContentItemMapper $mapper, private LoggerInterface $logger, @@ -47,11 +47,11 @@ public function registerContentProvider(string $providerClass): void { return; } - if ($this->configService->hasProvider($providerObj->getAppId(), $providerObj->getId())) { + if ($this->providerService->hasProvider($providerObj->getAppId(), $providerObj->getId())) { return; } - $this->configService->updateProvider($providerObj->getAppId(), $providerObj->getId(), $providerClass); + $this->providerService->updateProvider($providerObj->getAppId(), $providerObj->getId(), $providerClass); if (!$this->jobList->has(InitialContentImportJob::class, $providerClass)) { $this->jobList->add(InitialContentImportJob::class, $providerClass); @@ -98,7 +98,9 @@ public function submitContent(string $appId, array $items): void { */ public function removeContentForUsers(string $appId, string $providerId, string $itemId, array $users): void { foreach ($users as $userId) { - $this->service->deleteSources($userId, [$this->configService->getConfigKey($appId, $providerId) . ": $itemId"]); + $this->service->deleteSources($userId, [ + $this->providerService->getSourceId($itemId, $this->providerService->getConfigKey($appId, $providerId)) + ]); } } @@ -112,7 +114,7 @@ public function removeContentForUsers(string $appId, string $providerId, string */ public function removeAllContentForUsers(string $appId, string $providerId, array $users): void { foreach ($users as $userId) { - $this->service->deleteSourcesByProvider($userId, $this->configService->getConfigKey($appId, $providerId)); + $this->service->deleteSourcesByProvider($userId, $this->providerService->getConfigKey($appId, $providerId)); } } } diff --git a/lib/Service/LangRopeService.php b/lib/Service/LangRopeService.php index 82517f1..2d6a073 100644 --- a/lib/Service/LangRopeService.php +++ b/lib/Service/LangRopeService.php @@ -35,11 +35,18 @@ public function __construct( private IAppManager $appManager, private IURLGenerator $urlGenerator, private IUserManager $userMan, - private ProviderConfigService $configService, private ?string $userId, ) { } + /** + * @param string $route + * @param string $method + * @param array $params + * @param string|null $contentType + * @return array + * @throws RuntimeException + */ private function requestToExApp( string $route, string $method = 'POST', @@ -121,6 +128,7 @@ private function requestToExApp( * @param string $userId * @param string $providerKey * @return void + * @throws RuntimeException */ public function deleteSourcesByProvider(string $userId, string $providerKey): void { $params = [ @@ -134,6 +142,7 @@ public function deleteSourcesByProvider(string $userId, string $providerKey): vo /** * @param string $providerKey * @return void + * @throws RuntimeException */ public function deleteSourcesByProviderForAllUsers(string $providerKey): void { $params = [ @@ -147,6 +156,7 @@ public function deleteSourcesByProviderForAllUsers(string $providerKey): void { * @param string $userId * @param string[] $sourceNames * @return void + * @throws RuntimeException */ public function deleteSources(string $userId, array $sourceNames): void { if (count($sourceNames) === 0) { @@ -164,6 +174,7 @@ public function deleteSources(string $userId, array $sourceNames): void { /** * @param Source[] $sources * @return void + * @throws RuntimeException */ public function indexSources(array $sources): void { if (count($sources) === 0) { @@ -173,14 +184,14 @@ public function indexSources(array $sources): void { $params = array_map(function (Source $source) { return [ 'name' => 'sources', - 'filename' => $source->reference, // eg. 'file: 555' + 'filename' => $source->reference, // eg. 'files__default: 555' 'contents' => $source->content, 'headers' => [ 'userId' => $source->userId, 'title' => $source->title, 'type' => $source->type, 'modified' => $source->modified, - 'provider' => $source->provider, // eg. 'file' + 'provider' => $source->provider, // eg. 'files__default' ], ]; }, $sources); @@ -188,33 +199,28 @@ public function indexSources(array $sources): void { $this->requestToExApp('/loadSources', 'PUT', $params, 'multipart/form-data'); } - public function query(string $userId, string $prompt, bool $useContext = true): array { - $params = [ - 'query' => $prompt, - 'userId' => $userId, - 'useContext' => $useContext, - ]; - - $response = $this->requestToExApp('/query', 'GET', $params); - return ['message' => $this->getWithPresentableSources($response['output'] ?? '', ...($response['sources'] ?? []))]; - } - /** * @param string $userId * @param string $prompt - * @param string $scopeType - * @param array $scopeList + * @param bool $useContext + * @param ?string $scopeType + * @param ?array $scopeList * @return array + * @throws RuntimeException */ - public function scopedQuery(string $userId, string $prompt, string $scopeType, array $scopeList): array { + public function query(string $userId, string $prompt, bool $useContext = true, ?string $scopeType = null, ?array $scopeList = null): array { $params = [ 'query' => $prompt, 'userId' => $userId, - 'scopeType' => $scopeType, - 'scopeList' => $scopeList, + 'useContext' => $useContext, ]; + if ($scopeType !== null && $scopeList !== null) { + $params['useContext'] = true; + $params['scopeType'] = $scopeType; + $params['scopeList'] = $scopeList; + } - $response = $this->requestToExApp('/scopedQuery', 'POST', $params); + $response = $this->requestToExApp('/query', 'POST', $params); return ['message' => $this->getWithPresentableSources($response['output'] ?? '', ...($response['sources'] ?? []))]; } @@ -225,8 +231,9 @@ public function getWithPresentableSources(string $llmResponse, string ...$source $output = str_repeat(PHP_EOL, 3) . $this->l10n->t('Sources referenced to generate the above response:') . PHP_EOL; + $emptyFilesSourceId = ProviderService::getSourceId(''); foreach ($sourceRefs as $source) { - if (str_starts_with($source, 'file: ') && is_numeric($fileId = substr($source, 6))) { + if (str_starts_with($source, $emptyFilesSourceId) && is_numeric($fileId = substr($source, strlen($emptyFilesSourceId)))) { // use `overwritehost` setting in config.php to overwrite the host $output .= $this->urlGenerator->linkToRouteAbsolute('files.View.showFile', ['fileid' => $fileId]) . PHP_EOL; } elseif (str_contains($source, '__')) { diff --git a/lib/Service/ProviderConfigService.php b/lib/Service/ProviderConfigService.php index 6564c3f..719b4e0 100644 --- a/lib/Service/ProviderConfigService.php +++ b/lib/Service/ProviderConfigService.php @@ -55,6 +55,7 @@ public function getProviders(): array { if ($providers === null || !$this->validateProvidersArray($providers)) { $providers = []; + $this->config->setAppValue(Application::APP_ID, 'providers', ''); } } @@ -92,7 +93,7 @@ public function removeProvider(string $appId, ?string $providerId = null): void unset($providers[$this->getConfigKey($appId, $providerId)]); } elseif ($providerId === null) { foreach ($providers as $k => $v) { - if (str_starts_with($k, $appId)) { + if (str_starts_with($k, "{$appId}__")) { unset($providers[$k]); } } diff --git a/lib/Service/ProviderService.php b/lib/Service/ProviderService.php new file mode 100644 index 0000000..7237af6 --- /dev/null +++ b/lib/Service/ProviderService.php @@ -0,0 +1,91 @@ + + */ + public function getEnrichedProviders(): array { + $providers = $this->providerService->getProviders(); + $sanitizedProviders = []; + + foreach ($providers as $providerKey => $metadata) { + // providerKey ($appId__$providerId) + /** @var string[] */ + $providerValues = explode('__', $providerKey, 2); + + if (count($providerValues) !== 2) { + $this->logger->info("Invalid provider key $providerKey, skipping"); + continue; + } + + [$appId, $providerId] = $providerValues; + + $user = $this->userId === null ? null : $this->userManager->get($this->userId); + if (!$this->appManager->isEnabledForUser($appId, $user)) { + $this->logger->info("App $appId is not enabled for user {$this->userId}, skipping"); + continue; + } + + $appInfo = $this->appManager->getAppInfo($appId); + if ($appInfo === null) { + $this->logger->info("Could not get app info for $appId, skipping"); + continue; + } + + try { + $icon = $this->urlGenerator->imagePath($appId, 'app-dark.svg'); + } catch (\RuntimeException $e) { + $this->logger->info("Could not get app image for $appId"); + $icon = ''; + } + + if (!isset($appInfo['name'])) { + $this->logger->info("App $appId does not have a name, skipping"); + continue; + } + + $sanitizedProviders[] = [ + 'id' => $providerKey, + 'label' => ucfirst($providerId) . ' - ' . $appInfo['name'], + 'icon' => $icon, + ]; + } + return $sanitizedProviders; + } +} diff --git a/lib/Service/ScanService.php b/lib/Service/ScanService.php index 76178a2..3b0a448 100644 --- a/lib/Service/ScanService.php +++ b/lib/Service/ScanService.php @@ -48,10 +48,6 @@ public function scanDirectory(string $userId, array $mimeTypeFilter, Folder $dir $size = 0; foreach ($directory->getDirectoryListing() as $node) { if ($node instanceof File) { - if (!in_array($node->getMimeType(), $mimeTypeFilter)) { - continue; - } - $node_size = $node->getSize(); if ($size + $node_size > Application::CC_MAX_SIZE || count($sources) >= Application::CC_MAX_FILES) { @@ -60,22 +56,11 @@ public function scanDirectory(string $userId, array $mimeTypeFilter, Folder $dir $size = 0; } - try { - $fileHandle = $node->fopen('r'); - } catch (\Exception $e) { - $this->logger->error('Could not open file ' . $node->getPath() . ' for reading: ' . $e->getMessage()); + $source = $this->getSourceFromFile($userId, $mimeTypeFilter, $node); + if ($source === null) { continue; } - $source = new Source( - $userId, - 'file: ' . $node->getId(), - $node->getPath(), - $fileHandle, - $node->getMTime(), - $node->getMimeType(), - 'file' - ); $sources[] = $source; $size += $node_size; @@ -97,6 +82,30 @@ public function scanDirectory(string $userId, array $mimeTypeFilter, Folder $dir return []; } + public function getSourceFromFile(string $userId, array $mimeTypeFilter, File $node): Source | null { + if (!in_array($node->getMimeType(), $mimeTypeFilter)) { + return null; + } + + try { + $fileHandle = $node->fopen('r'); + } catch (\Exception $e) { + $this->logger->error('Could not open file ' . $node->getPath() . ' for reading: ' . $e->getMessage()); + return null; + } + + $providerKey = ProviderService::getDefaultProviderKey(); + return new Source( + $userId, + $providerKey . ': ' . $node->getId(), + $node->getPath(), + $fileHandle, + $node->getMTime(), + $node->getMimeType(), + $providerKey, + ); + } + public function indexSources(array $sources): void { $this->langRopeService->indexSources($sources); } diff --git a/lib/TextProcessing/ContextChatProvider.php b/lib/TextProcessing/ContextChatProvider.php index 13521ab..74311e7 100644 --- a/lib/TextProcessing/ContextChatProvider.php +++ b/lib/TextProcessing/ContextChatProvider.php @@ -3,10 +3,20 @@ declare(strict_types=1); namespace OCA\ContextChat\TextProcessing; +use OCA\ContextChat\AppInfo\Application; use OCA\ContextChat\Service\LangRopeService; +use OCA\ContextChat\Service\ProviderService; +use OCA\ContextChat\Service\ScanService; +use OCA\ContextChat\Type\ScopeType; +use OCP\Files\File; +use OCP\Files\Folder; +use OCP\Files\IRootFolder; +use OCP\Files\NotPermittedException; use OCP\IL10N; use OCP\TextProcessing\IProvider; use OCP\TextProcessing\IProviderWithUserId; +use Psr\Log\LoggerInterface; +use RuntimeException; /** * @template-implements IProviderWithUserId @@ -17,8 +27,11 @@ class ContextChatProvider implements IProvider, IProviderWithUserId { private ?string $userId = null; public function __construct( - private LangRopeService $langRopeService, private IL10N $l10n, + private IRootFolder $rootFolder, + private LoggerInterface $logger, + private LangRopeService $langRopeService, + private ScanService $scanService, ) { } @@ -26,18 +39,158 @@ public function getName(): string { return $this->l10n->t('Nextcloud Assistant Context Chat Provider'); } + /** + * Accepted scopeList formats: + * - "files__default: $fileId" + * - "$appId__$providerId" + * + * @param string $prompt JSON string with the following structure: + * { + * "scopeType": string, (optional key) + * "scopeList": list[string], (optional key) + * "prompt": string + * } + * + * @return string + */ public function process(string $prompt): string { if ($this->userId === null) { throw new \RuntimeException('User ID is required to process the prompt.'); } - $response = $this->langRopeService->query($this->userId, $prompt); + try { + $parsedData = json_decode($prompt, true, flags: JSON_THROW_ON_ERROR | JSON_INVALID_UTF8_IGNORE); + } catch (\JsonException $e) { + throw new \RuntimeException( + 'Invalid JSON string, expected { "prompt": string } or { "scopeType": string, "scopeList": list[string], "prompt": string }', + intval($e->getCode()), $e, + ); + } + + if (!isset($parsedData['prompt']) || !is_string($parsedData['prompt'])) { + throw new \RuntimeException('Invalid JSON string, expected "prompt" key with string value'); + } + + if (!isset($parsedData['scopeType']) || !isset($parsedData['scopeList'])) { + $response = $this->langRopeService->query($this->userId, $prompt); + if (isset($response['error'])) { + throw new \RuntimeException('No result in ContextChat response. ' . $response['error']); + } + return $response['message'] ?? ''; + } + + if (!is_string($parsedData['scopeType']) || !is_array($parsedData['scopeList'])) { + throw new \RuntimeException('Invalid JSON string, expected "scopeType" key with string value and "scopeList" key with array value'); + } + + try { + ScopeType::validate($parsedData['scopeType']); + } catch (\InvalidArgumentException $e) { + throw new \RuntimeException($e->getMessage(), intval($e->getCode()), $e); + } + + $scopeList = array_unique($parsedData['scopeList']); + if (count($scopeList) === 0) { + throw new \RuntimeException('No sources found'); + } + + // index sources before the query, not needed for providers + if ($parsedData['scopeType'] === ScopeType::SOURCE) { + $processedScopes = $this->indexFiles(...$parsedData['scopeList']); + $this->logger->debug('All valid files indexed, querying ContextChat', ['scopeType' => $parsedData['scopeType'], 'scopeList' => $processedScopes]); + } else { + $processedScopes = $scopeList; + $this->logger->debug('No need to index sources, querying ContextChat', ['scopeType' => $parsedData['scopeType'], 'scopeList' => $processedScopes]); + } + + $response = $this->langRopeService->query( + $this->userId, + $parsedData['prompt'], + true, + $parsedData['scopeType'], + $processedScopes, + ); + if (isset($response['error'])) { - throw new \RuntimeException('No result in ContextChat response. ' . $response['error']); + throw new \RuntimeException('No result in ContextChat response: ' . $response['error']); } + return $response['message'] ?? ''; } + /** + * @param array scopeList + * @return array List of indexed files + */ + private function indexFiles(string ...$scopeList): array { + $nodes = []; + $indexedFiles = []; + $providerKey = ProviderService::getDefaultProviderKey(); + + foreach ($scopeList as $scope) { + if (!str_contains($scope, $providerKey . ': ')) { + $this->logger->warning('Invalid source format, expected "sourceId: itemId"'); + continue; + } + + $nodeId = substr($scope, strlen($providerKey . ': ')); + + try { + $userFolder = $this->rootFolder->getUserFolder($this->userId); + } catch (NotPermittedException $e) { + $this->logger->warning('Could not get user folder for user ' . $this->userId . ': ' . $e->getMessage()); + continue; + } + $node = $userFolder->getById(intval($nodeId)); + if (count($node) === 0) { + $this->logger->warning('Could not find file/folder with ID ' . $nodeId . ', skipping'); + continue; + } + $node = $node[0]; + + if (!$node instanceof File && !$node instanceof Folder) { + $this->logger->warning('Invalid source type, expected file/folder'); + continue; + } + + $nodes[] = [ + 'scope' => $scope, + 'node' => $node, + 'path' => $node->getPath(), + ]; + } + + // remove subfolders + $filteredNodes = $nodes; + foreach ($nodes as $node) { + if ($node['node'] instanceof Folder) { + $filteredNodes = array_filter($filteredNodes, function ($n) use ($node) { + return !str_starts_with($n['path'], $node['path'] . DIRECTORY_SEPARATOR); + }); + $filteredNodes[] = $node; + } + } + + foreach ($filteredNodes as $node) { + try { + if ($node['node'] instanceof File) { + $source = $this->scanService->getSourceFromFile($this->userId, Application::MIMETYPES, $node['node']); + $this->scanService->indexSources([$source]); + $indexedFiles[] = $node['scope']; + } elseif ($node['node'] instanceof Folder) { + $indexedFiles = array_merge( + $indexedFiles, + iterator_to_array($this->scanService->scanDirectory($this->userId, Application::MIMETYPES, $node['node'])), + ); + } + } catch (RuntimeException $e) { + $this->logger->warning('Could not index file/folder with ID ' . $node['node']->getId() . ': ' . $e->getMessage()); + } + } + + return $indexedFiles; + } + public function getTaskType(): string { return ContextChatTaskType::class; } diff --git a/lib/TextProcessing/ScopedContextChatProvider.php b/lib/TextProcessing/ScopedContextChatProvider.php deleted file mode 100644 index 3768e92..0000000 --- a/lib/TextProcessing/ScopedContextChatProvider.php +++ /dev/null @@ -1,93 +0,0 @@ - - * @template-implements IProvider - */ -class ScopedContextChatProvider implements IProvider, IProviderWithUserId { - - private ?string $userId = null; - - public function __construct( - private LangRopeService $langRopeService, - private IL10N $l10n, - ) { - } - - public function getName(): string { - return $this->l10n->t('Nextcloud Assistant Scoped Context Chat Provider'); - } - - /** - * @param string $prompt JSON string with the following structure: - * { - * "scopeType": string, - * "scopeList": list[string], - * "prompt": string, - * } - * - * @return string - */ - public function process(string $prompt): string { - if ($this->userId === null) { - throw new \RuntimeException('User ID is required to process the prompt.'); - } - - try { - $parsedData = json_decode($prompt, true, flags: JSON_THROW_ON_ERROR | JSON_INVALID_UTF8_IGNORE); - } catch (\JsonException $e) { - throw new \RuntimeException( - 'Invalid JSON string, expected { "scopeType": string, "scopeList": list[string], "prompt": string }', - intval($e->getCode()), $e, - ); - } - - if ( - !is_array($parsedData) - || !isset($parsedData['scopeType']) - || !is_string($parsedData['scopeType']) - || !isset($parsedData['scopeList']) - || !is_array($parsedData['scopeList']) - || !isset($parsedData['prompt']) - || !is_string($parsedData['prompt']) - ) { - throw new \RuntimeException('Invalid JSON string, expected { "scopeType": string, "scopeList": list[string], "prompt": string }'); - } - - try { - ScopeType::validate($parsedData['scopeType']); - } catch (\InvalidArgumentException $e) { - throw new \RuntimeException($e->getMessage(), intval($e->getCode()), $e); - } - - $response = $this->langRopeService->scopedQuery( - $this->userId, - $parsedData['prompt'], - $parsedData['scopeType'], - $parsedData['scopeList'], - ); - - if (isset($response['error'])) { - throw new \RuntimeException('No result in ContextChat response. ' . $response['error']); - } - - return $response['message'] ?? ''; - } - - public function getTaskType(): string { - return ScopedContextChatTaskType::class; - } - - public function setUserId(?string $userId): void { - $this->userId = $userId; - } -} diff --git a/lib/TextProcessing/ScopedContextChatTaskType.php b/lib/TextProcessing/ScopedContextChatTaskType.php deleted file mode 100644 index fc204de..0000000 --- a/lib/TextProcessing/ScopedContextChatTaskType.php +++ /dev/null @@ -1,52 +0,0 @@ - - * - * @author Julien Veyssier - * - * @license GNU AGPL version 3 or any later version - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of the - * License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -namespace OCA\ContextChat\TextProcessing; - -use OCP\IL10N; -use OCP\TextProcessing\ITaskType; - -class ScopedContextChatTaskType implements ITaskType { - public function __construct( - private IL10N $l, - ) { - } - - /** - * @inheritDoc - * @since 27.1.0 - */ - public function getName(): string { - return $this->l->t('Scoped Context Chat'); - } - - /** - * @inheritDoc - * @since 27.1.0 - */ - public function getDescription(): string { - return $this->l->t('Ask a question about the data selected by you.'); - } -}