Skip to content

Commit

Permalink
Merge pull request #22 from krystianbajno/feature/shared-cache-memory…
Browse files Browse the repository at this point in the history
…-async

feature/shared-cache-memory-async
  • Loading branch information
krystianbajno authored Nov 24, 2024
2 parents a970a8c + 5deb025 commit de848fe
Show file tree
Hide file tree
Showing 22 changed files with 523 additions and 451 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Byte-compiled / optimized / DLL files
.DS_Store
cache/
dataset/
cveseeker_*_report.csv
cveseeker_*_report.json
cveseeker_*_report.html
Expand Down
4 changes: 2 additions & 2 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ providers:
enrichment:
sources:
vulners: true
github: false
github_cached: true
trickest_cve_github: false
trickest_cve_github_cached: true
cisa_kev: true
github_poc: false
github_poc_cached: true
Expand Down
19 changes: 14 additions & 5 deletions providers/search_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from typing import Dict

from services.cache.cache_manager import CacheManager
from services.search.engine.progress_factory import ProgressManagerFactory
from services.search.search_manager import SearchManager
from services.search.engine.progress_manager import ProgressManager
Expand Down Expand Up @@ -43,14 +44,22 @@ def boot(self):
config = self.load_config()
providers_config = config.get('providers', {})
enrichment_config = config.get("enrichment", {})


cache_manager = CacheManager(config)

providers = []

for provider_name, enabled in providers_config.items():
if enabled:
provider_class = self.provider_registry.get(provider_name)
if provider_class:
providers.append(provider_class())
if provider_name in [
'NistCachedAPI',
'CISAKEVAPI'
]:
providers.append(provider_class(cache_manager))
else:
providers.append(provider_class())
else:
print(f"[!] Provider '{provider_name}' not found in registry.")
else:
Expand All @@ -59,9 +68,9 @@ def boot(self):
if self.playwright_enabled:
playwright_providers = []
providers.extend(playwright_providers)

progress_manager_factory = ProgressManagerFactory()
self.search_service = SearchManager(providers, enrichment_config, progress_manager_factory=progress_manager_factory)
self.search_service = SearchManager(providers, enrichment_config, progress_manager_factory=progress_manager_factory, cache_manager=cache_manager)

def load_config(self) -> Dict:
try:
Expand Down
93 changes: 20 additions & 73 deletions services/api/sources/cisa_kev.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,35 @@
import os
import time
import json
from typing import List, Dict
import httpx
from dateutil import parser as dateutil_parser

from models.vulnerability import Vulnerability
from services.api.source import Source
from services.cache.cache_manager import CacheManager
from services.vulnerabilities.factories.vulnerability_factory import VulnerabilityFactory
from dateutil import parser as dateutil_parser
from typing import List

class CISAKEVAPI(Source):
CACHE_DIR = "cache"
CACHE_FILE = os.path.join(CACHE_DIR, "cisa_kev_cache.json")
CACHE_DURATION = 600 # 10 minutes
CISA_URL = "https://www.cisa.gov/sites/default/files/feeds/known_exploited_vulnerabilities.json"

def __init__(self):
self.ensure_cache_dir()

def ensure_cache_dir(self):
if not os.path.exists(self.CACHE_DIR):
try:
os.makedirs(self.CACHE_DIR)
print(f"Created cache directory at '{self.CACHE_DIR}'.")
except Exception as e:
print(f"[!] Failed to create cache directory '{self.CACHE_DIR}': {e}")

def is_cache_valid(self) -> bool:
if os.path.exists(self.CACHE_FILE):
cache_mtime = os.path.getmtime(self.CACHE_FILE)
current_time = time.time()
return (current_time - cache_mtime) < self.CACHE_DURATION
return False

def load_cache(self) -> Dict:
try:
with open(self.CACHE_FILE, 'r', encoding='utf-8') as f:
print("[*] Loaded CISA KEV catalog from cache.")
return json.load(f)
except Exception as e:
print(f"[!] Error reading cache file '{self.CACHE_FILE}': {e}")
return {}

def update_cache(self, data: Dict):
try:
with open(self.CACHE_FILE, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
print(f"[+] Cache file '{self.CACHE_FILE}' updated successfully.")
except Exception as e:
print(f"[!] Error updating cache file '{self.CACHE_FILE}': {e}")

def fetch_data(self) -> Dict:
try:
print("[*] Downloading CISA KEV catalog...")
response = httpx.get(self.CISA_URL, timeout=15)
if response.status_code == 200:
data = response.json()
self.update_cache(data)
return data
else:
print(f"[!] Failed to fetch CISA KEV catalog. Status code: {response.status_code}")
except Exception as e:
print(f"[!] Error fetching CISA KEV data: {e}")
return {}

def get_data(self) -> Dict:
if self.is_cache_valid():
return self.load_cache()
else:
return self.fetch_data()
def __init__(self, cache_manager: CacheManager):
self.cache_manager = cache_manager

def search(self, keywords: List[str], max_results: int = 10) -> List[Vulnerability]:
vulnerabilities = []

self.cache_manager.wait_for_data('cisa_kev')

data = self.cache_manager.get_data('cisa_kev')
if not data:
print("[!] CISA KEV data is not available.")
return []

try:
data = self.get_data()
kev_vulnerabilities = data.get("vulnerabilities", [])
keyword_set = set(keyword.lower() for keyword in keywords)
keyword_set = {keyword.lower() for keyword in keywords}

for item in kev_vulnerabilities:
cve_id = item.get("cveID")
if not cve_id:
continue

description = item.get("shortDescription", "N/A").lower()
if not any(keyword in description for keyword in keyword_set):
description = item.get("shortDescription", "N/A")
if not any(keyword in description.lower() for keyword in keyword_set):
continue

date_added = item.get("dateAdded")
Expand All @@ -97,6 +42,8 @@ def search(self, keywords: List[str], max_results: int = 10) -> List[Vulnerabili
notes = item.get("notes", "")
reference_urls = [url.strip() for url in notes.split(" ; ") if url.strip()]
weaknesses = item.get("cwes", [])
product = item.get("product", "N/A")
vendor_project = item.get("vendorProject", "N/A")

vulnerabilities.append(
VulnerabilityFactory.make(
Expand All @@ -105,9 +52,9 @@ def search(self, keywords: List[str], max_results: int = 10) -> List[Vulnerabili
url="https://www.cisa.gov/known-exploited-vulnerabilities-catalog",
date=date,
reference_urls=reference_urls,
description=item.get("shortDescription", "N/A"),
vulnerable_components=[item.get("product", "N/A")],
tags=[item.get("vendorProject", "N/A")],
description=description,
vulnerable_components=[product],
tags=[vendor_project],
weaknesses=weaknesses
)
)
Expand Down
14 changes: 9 additions & 5 deletions services/api/sources/nist_cached.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from models.vulnerability import Vulnerability
from services.search.nist_cache_manager import get_cve_data_cache
from services.cache.cache_manager import CacheManager
from services.vulnerabilities.factories.vulnerability_factory import VulnerabilityFactory, DEFAULT_VALUES
from dateutil import parser as dateutil_parser
from typing import List

class NistCachedAPI:
def search(self, keywords: List[str], max_results: int) -> List[Vulnerability]:
cve_data_cache = get_cve_data_cache()
def __init__(self, cache_manager: CacheManager):
self.cache_manager = cache_manager

def search(self, keywords: List[str], max_results: int) -> List[Vulnerability]:
self.cache_manager.wait_for_data('nist_cached')

cve_data_cache = self.cache_manager.get_data('nist_cached')
if not cve_data_cache:
print("[!] CVE data is not available. Returning empty results.")
print("[!] NIST CVE data is not available.")
return []

vulnerabilities = []
Expand Down Expand Up @@ -60,7 +64,7 @@ def search(self, keywords: List[str], max_results: int) -> List[Vulnerability]:
vulnerabilities.append(
VulnerabilityFactory.make(
id=cve_id,
url="https://github.com/fkie-cad/nvd-json-data-feeds/releases/latest/download/CVE-all.json.xz",
url="https://nvd.nist.gov/vuln/detail/" + cve_id,
source=self.__class__.__name__,
date=date,
reference_urls=reference_urls,
Expand Down
71 changes: 71 additions & 0 deletions services/cache/cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import threading
from typing import Dict, Callable, Optional

from services.cache.loaders.github_poc_data_loader import load_github_poc_data
from services.cache.loaders.nist_data_loader import load_nist_data
from services.cache.loaders.cisa_kev_data_loader import load_cisa_kev_data
from services.cache.loaders.trickest_cve_data_loader import load_trickest_cve_data

class CacheManager:
    """Load enabled datasets in background threads and share them in memory.

    On construction, one loader thread is started for every enabled provider
    or enrichment source that has a registered loader. Each dataset gets a
    ``threading.Event`` that is set once its loader has finished (successfully
    or not), so consumers can block on ``wait_for_data`` and then read the
    result with ``get_data``.
    """

    def __init__(self, config: Dict):
        """Store *config* and immediately start loading the enabled caches.

        Args:
            config: Parsed application configuration. The ``providers``
                mapping and the ``enrichment.sources`` mapping are read to
                decide which datasets to load.
        """
        self.config = config
        self.cache_data = {}
        self.cache_events = {}
        self.loading_threads = []
        # A single lock shared by every loader thread. It must exist before
        # load_caches() starts the threads that use it.
        self._lock = threading.Lock()
        self.load_caches()

    def load_caches(self):
        """Start one background loader thread per enabled, known dataset."""
        providers_config = self.config.get('providers', {})
        enrichment_config = self.config.get('enrichment', {}).get('sources', {})

        # Provider class name -> (cache key, loader function).
        provider_loaders = {
            'NistCachedAPI': ('nist_cached', load_nist_data),
            'CISAKEVAPI': ('cisa_kev', load_cisa_kev_data),
            'GitHubCachedAPI': ('github_poc_cached', load_github_poc_data),
        }

        # Enrichment source name (which is also the cache key) -> loader.
        enrichment_loaders = {
            'nist_cached': load_nist_data,
            'cisa_kev': load_cisa_kev_data,
            'github_poc_cached': load_github_poc_data,
            'trickest_cve_github_cached': load_trickest_cve_data,
        }

        loaders_to_use = {}

        for provider_name, enabled in providers_config.items():
            if enabled and provider_name in provider_loaders:
                cache_key, loader_func = provider_loaders[provider_name]
                loaders_to_use[cache_key] = loader_func

        for source_name, enabled in enrichment_config.items():
            if enabled and source_name in enrichment_loaders:
                # A dataset already requested by a provider is loaded once.
                loaders_to_use.setdefault(source_name, enrichment_loaders[source_name])

        for cache_key, loader_func in loaders_to_use.items():
            # Register the event before the thread starts so _load_data can
            # always find it.
            self.cache_events[cache_key] = threading.Event()
            thread = threading.Thread(target=self._load_data, args=(cache_key, loader_func))
            self.loading_threads.append(thread)
            thread.start()

    def _load_data(self, name: str, loader_func: Callable):
        """Run *loader_func* and publish its result under *name*.

        The readiness event is set even when the loader raises; otherwise a
        caller blocked in ``wait_for_data`` without a timeout would hang
        forever. After a failed load ``get_data(name)`` returns ``None``.
        """
        try:
            data = loader_func()
        except Exception as e:
            # Thread boundary: report the failure instead of dying silently.
            print(f"[!] Cache loader for '{name}' failed: {e}")
        else:
            with self._lock:
                self.cache_data[name] = data
        finally:
            self.cache_events[name].set()

    def is_data_ready(self, plugin_name: str) -> bool:
        """Return True if the loader for *plugin_name* has finished.

        Returns False for a dataset that was never scheduled for loading.
        """
        event = self.cache_events.get(plugin_name)
        return event is not None and event.is_set()

    def wait_for_data(self, plugin_name: str, timeout: Optional[float] = None):
        """Block until the loader for *plugin_name* finishes or *timeout* elapses.

        Returns immediately when no loader was scheduled for *plugin_name*.
        """
        event = self.cache_events.get(plugin_name)
        if event:
            event.wait(timeout=timeout)

    def get_data(self, plugin_name: str):
        """Return the loaded dataset for *plugin_name*, or None if absent."""
        return self.cache_data.get(plugin_name)

    def ensure_all_data_loaded(self):
        """Block until every loader thread has terminated."""
        for thread in self.loading_threads:
            thread.join()
56 changes: 56 additions & 0 deletions services/cache/loaders/cisa_kev_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import time
import json
import httpx

from terminal.cli import print_greyed_out

def load_cisa_kev_data():
    """Return the CISA Known Exploited Vulnerabilities catalog as a dict.

    A JSON copy is kept in ``dataset/cisa_kev_cache.json``. If that file was
    modified less than ten minutes ago it is used; otherwise (or when it
    cannot be read) the catalog is downloaded and the cache file rewritten.

    Returns:
        The parsed catalog, or an empty dict when the download fails.
    """
    cache_dir = 'dataset'
    cache_file = os.path.join(cache_dir, 'cisa_kev_cache.json')
    cache_duration = 600  # 10 minutes

    if not os.path.isdir(cache_dir):
        try:
            # exist_ok avoids a spurious failure when another loader thread
            # creates the directory between the check and this call.
            os.makedirs(cache_dir, exist_ok=True)
            print_greyed_out(f"[+] CISA_KEV_DATA_LOADER: Created cache directory at '{cache_dir}'.")
        except OSError as e:
            print_greyed_out(f"[!] CISA_KEV_DATA_LOADER: Failed to create cache directory '{cache_dir}': {e}")

    def is_cache_valid():
        """Return True when the cache file exists and is younger than cache_duration."""
        try:
            cache_mtime = os.path.getmtime(cache_file)
        except OSError:
            # Missing or inaccessible file: treat as no cache.
            return False
        return (time.time() - cache_mtime) < cache_duration

    def download_and_cache():
        """Download the catalog, try to persist it, and return it ({} on fetch failure)."""
        try:
            print_greyed_out("[*] CISA_KEV_DATA_LOADER: Downloading CISA KEV catalog...")
            response = httpx.get(
                "https://www.cisa.gov/sites/default/files/feeds/known_exploited_vulnerabilities.json",
                timeout=15
            )
            if response.status_code != 200:
                print_greyed_out(f"[!] CISA_KEV_DATA_LOADER: Failed to fetch CISA KEV catalog. Status code: {response.status_code}")
                return {}
            data = response.json()
        except Exception as e:
            print_greyed_out(f"[!] CISA_KEV_DATA_LOADER: Error fetching CISA KEV data: {e}")
            return {}

        # Persisting is best-effort: a failed write must not throw away a
        # catalog that was downloaded successfully.
        try:
            with open(cache_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=4)
            print_greyed_out("[+] CISA_KEV_DATA_LOADER: CISA KEV catalog downloaded and cached.")
        except (OSError, ValueError) as e:
            print_greyed_out(f"[!] CISA_KEV_DATA_LOADER: Error writing cache file '{cache_file}': {e}")
        return data

    if is_cache_valid():
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            print_greyed_out("[+] CISA_KEV_DATA_LOADER: Loaded CISA KEV data from cache.")
            return data
        except Exception as e:
            # Unreadable or corrupt cache: fall through to a fresh download.
            print_greyed_out(f"[!] CISA_KEV_DATA_LOADER: Error reading cache file '{cache_file}': {e}")

    return download_and_cache()
Loading

0 comments on commit de848fe

Please sign in to comment.