Skip to content

Commit

Permalink
Ak/automation copilot (#50)
Browse files Browse the repository at this point in the history
* try to run in gha

* make ruff happy

* Update integration-local.yaml

* add playwright deps to addon test too
  • Loading branch information
akshaya-a authored Apr 22, 2024
1 parent aada470 commit d3217f2
Show file tree
Hide file tree
Showing 19 changed files with 792 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/integration-addon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ jobs:
- name: Lint with Ruff
run: |
ruff check .
- name: install playwright deps
run: |
playwright install --with-deps chromium
- name: Copy python source + services into build context because builder action doesn't support --build-context
id: copy
run: bash scripts/copy_content_to_addon_context.sh
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/integration-local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ jobs:
- name: Lint with Ruff
run: |
ruff check .
- name: install playwright deps
run: |
playwright install --with-deps chromium
- name: Test with pytest
run: |
pytest -v -s -rA -c tests/pytest.ini --deploy-mode=local --replay-mode=replay
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# LLMOps for Home Assistant
# Home ~~Automation~~ Intelligence

Manage your home automation LLM prompts, available LLMs and evaluate changes with MLflow

Expand Down
Empty file.
229 changes: 229 additions & 0 deletions python/src/mindctrl/homeassistant/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import asyncio
import logging
from typing import Union
import time
import httpx
from httpx_ws import aconnect_ws
from pydantic import ValidationError

from .messages import (
AreasResult,
Auth,
AuthChallenge,
AuthOk,
Automation,
Command,
CreateAutomation,
CreateLabel,
Error,
Label,
LabelsResult,
ListAreas,
ListEntities,
ListLabels,
ManyResponsesWrapper,
Result,
SingleResponseWrapper,
UpdateEntityLabels,
)

_logger = logging.getLogger(__name__)


class HassClientError(Exception):
pass


class HassClient(object):
def __init__(self, id: str, hass_url: httpx.URL, token: str):
self.hass_url = hass_url
self._token = token
if not (str(self.hass_url).endswith("api")):
raise ValueError(
f"hass_url must end with 'api'. For example, 'http://homeassistant.local:8123/api'. Got: {self.hass_url}"
)
self.headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
self.websocket_message_id = 0
self._client = httpx.AsyncClient(base_url=self.hass_url, headers=self.headers)
self._ws_session = aconnect_ws(
f"{self._client.base_url}websocket", client=self._client
)
self._authenticated_session = None

# is there a better way to process these?
self._command_results: dict[int, Union[Result, Error]] = {}

@property
def authenticated_session(self):
if not self._authenticated_session:
raise ValueError("Session not authenticated")
return self._authenticated_session

async def __aenter__(self):
session = await self._ws_session.__aenter__()

auth_required_msg = AuthChallenge.model_validate(await session.receive_json())
_logger.info(auth_required_msg)

await session.send_json(Auth(access_token=self._token).model_dump())

auth_ok_msg = AuthOk.model_validate(await session.receive_json())
_logger.info(auth_ok_msg)

self._authenticated_session = session
return self

async def __aexit__(self, exc_type, exc, tb):
if len(self._command_results.keys()) > 0:
_logger.warning(
f"Exiting with {len(self._command_results.keys())} unprocessed messages"
)
await self._ws_session.__aexit__(exc_type, exc, tb)
await self._client.aclose()

async def _receive_message(self, message_id: int) -> Union[Result, None]:
if message_id in self._command_results.keys():
val = self._command_results.pop(message_id)
if not val.success:
assert isinstance(val, Error)
raise HassClientError(f"Error: {val.code} - {val.message}")
assert isinstance(val, Result)
return val

response_json = await self.authenticated_session.receive_json()
responses: list[Union[Result, Error]]
try:
response = SingleResponseWrapper.model_validate(
{"response": response_json}, strict=False
)
responses = [response.response]
except ValidationError as ve:
_logger.info(f"receive: {ve}")
response = ManyResponsesWrapper.model_validate(
{"responses": response_json}, strict=False
)
responses = response.responses

return_response: Union[Result, Error, None] = None
for response in responses:
if response.id != message_id:
self._command_results[response.id] = response
else:
return_response = response

if isinstance(return_response, Error):
raise HassClientError(
f"Error: {return_response.code} - {return_response.message}"
)

return return_response

async def _send_message(self, message: Command) -> Result:
# The atomicity of this is sketchy - add tests
self.websocket_message_id += 1
message.id = self.websocket_message_id
await self.authenticated_session.send_json(message.model_dump())
result = None
while result is None:
result = await self._receive_message(message.id)
if result is None:
_logger.info("Recent receive batch didn't have the response")
await asyncio.sleep(0.1)
return result

@staticmethod
def _current_milli_time():
return round(time.time() * 1000)

async def list_entities(self):
entities = await self._send_message(ListEntities(id=-1))
if entities.result is None:
_logger.warning(f"Unexpected null entities result: {entities}")
return []
return entities.result

# TODO: bad api call pattern, need to revisit
async def list_automations(self) -> list[Automation]:
entities = await self.list_entities()
_logger.debug(entities)
automation_entities = [
entity
for entity in entities
if entity["platform"] == "automation"
]

_logger.info(f"Fetching {len(automation_entities)} automations")
fetch_automation_tasks = []
for entity in automation_entities:
fetch_automation_tasks.append(self.get_automation(entity["unique_id"]))
return await asyncio.gather(*fetch_automation_tasks)

async def list_labels(self):
any_result = await self._send_message(ListLabels(id=-1))
labels = LabelsResult.model_validate_json(any_result.model_dump_json())
return labels.result

async def list_areas(self):
any_result = await self._send_message(ListAreas(id=-1))
areas = AreasResult.model_validate_json(any_result.model_dump_json())
return areas.result

async def create_label(self, label: Label):
await self._send_message(
CreateLabel(
color=label.color,
description=label.description,
icon=label.icon,
name=label.name,
id=-1,
)
)
return

async def add_labels(self, entity_id: str, labels: list[str]):
await self._send_message(
UpdateEntityLabels(
entity_id=entity_id,
labels=labels,
id=-1,
)
)

async def get_automation(self, unique_id: str):
# Creating an automation is REST POST to a unix timestamp
# http://hass-dev.ak:8123/api/config/automation/config/1713577351529
# Request Method:
# POST
# Status Code:
# 200 OK
path = f"config/automation/config/{unique_id}"
get_response = await self._client.get(path)
get_response.raise_for_status()
return Automation.model_validate(get_response.json())

async def create_automation(self, name: str, description: str):
# Creating an automation is REST POST to a unix timestamp
# http://hass-dev.ak:8123/api/config/automation/config/1713577351529
# Request Method:
# POST
# Status Code:
# 200 OK
current_milli_time = HassClient._current_milli_time()
path = f"config/automation/config/{current_milli_time}"
response = await self._client.post(
path,
json=CreateAutomation(
alias=name,
description=description,
mode="single",
action=[],
condition=[],
trigger=[],
).model_dump(),
)
response.raise_for_status()
# response.json() is just {'result': 'ok'}, need to do a get (why?)
return await self.get_automation(str(current_milli_time))
117 changes: 117 additions & 0 deletions python/src/mindctrl/homeassistant/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Any, Optional, Union
from pydantic import BaseModel


class Message(BaseModel):
type: str


class AuthChallenge(Message):
type: str = "auth_required"
ha_version: str


class AuthOk(Message):
type: str = "auth_ok"
ha_version: str


class Auth(Message):
type: str = "auth"
access_token: str


class Command(Message):
id: int


class CommandResponse(Command):
type: str = "result"
success: bool


class Error(CommandResponse):
success: bool = False
code: str
message: str


class Result(CommandResponse):
success: bool = True
result: Optional[Any]


class ManyResponsesWrapper(BaseModel):
responses: list[Union[Error, Result]]


class SingleResponseWrapper(BaseModel):
response: Union[Error, Result]


class ListEntities(Command):
type: str = "config/entity_registry/list"


class ListLabels(Command):
type: str = "config/label_registry/list"


class ListAreas(Command):
type: str = "config/area_registry/list"


# {"color":"indigo","description":null,"icon":"mdi:account","label_id":"test","name":"test"}
class Label(BaseModel):
color: str
description: Optional[str]
icon: str
label_id: str
name: str


class LabelsResult(Result):
result: list[Label]


class Area(BaseModel):
area_id: str
name: str
aliases: list[str]
floor_id: Optional[str]
icon: Optional[str]
labels: list[str]
picture: Optional[str]


class AreasResult(Result):
result: list[Area]


# {"type":"config/label_registry/create","name":"test","icon":"mdi:account","color":"indigo","id":62}
class CreateLabel(Command):
type: str = "config/label_registry/create"
description: Optional[str]
name: str
icon: str
color: str


# {"type":"config/entity_registry/update","entity_id":"automation.test_zone_automation","labels":["test"],"id":51}
class UpdateEntityLabels(Command):
type: str = "config/entity_registry/update"
entity_id: str
labels: list[str]


class CreateAutomation(BaseModel):
action: list[Any]
alias: str
condition: list[Any]
description: str
mode: str
trigger: list[Any]


class Automation(CreateAutomation):
id: str
Loading

0 comments on commit d3217f2

Please sign in to comment.