Skip to content

Commit

Permalink
Merge pull request #22 from statelyai/davidkpiano/simplify-event-choice
Browse files Browse the repository at this point in the history
Simplify event choices
  • Loading branch information
davidkpiano authored Mar 16, 2024
2 parents 3e1d02b + 8a2c34b commit 1424f03
Show file tree
Hide file tree
Showing 13 changed files with 265 additions and 264 deletions.
28 changes: 28 additions & 0 deletions .changeset/soft-readers-attend.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
---
'@statelyai/agent': patch
---

The `createSchemas(…)` function has been removed. The `defineEvents(…)` function should be used instead, as it is a simpler way of defining events and event schemas using Zod:

```ts
import { defineEvents } from '@statelyai/agent';
import { z } from 'zod';
import { setup } from 'xstate';

const events = defineEvents({
inc: z.object({
by: z.number().describe('Increment amount'),
}),
});

const machine = setup({
types: {
events: events.types,
},
schema: {
events: events.schemas,
},
}).createMachine({
// ...
});
```
117 changes: 55 additions & 62 deletions examples/joke.ts
Original file line number Diff line number Diff line change
@@ -1,53 +1,37 @@
import OpenAI from 'openai';
import { assign, fromCallback, fromPromise, log, setup } from 'xstate';
import { createAgent, createOpenAIAdapter, createSchemas } from '../src';
import { createAgent, createOpenAIAdapter, defineEvents } from '../src';
import { loadingAnimation } from './helpers/loader';
import { z } from 'zod';

const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

const schemas = createSchemas({
context: {
type: 'object',
properties: {
topic: { type: 'string' },
jokes: {
type: 'array',
items: {
type: 'string',
},
},
desire: { type: ['string', 'null'] },
lastRating: { type: ['string', 'null'] },
},
required: ['topic', 'jokes', 'desire', 'lastRating'],
},
events: {
askForTopic: {
type: 'object',
properties: {
topic: {
type: 'string',
},
},
},
endJokes: {
type: 'object',
properties: {},
},
},
const events = defineEvents({
askForTopic: z.object({
topic: z.string().describe('The topic for the joke'),
}),
tellJoke: z.object({
joke: z.string().describe('The joke text'),
}),
endJokes: z.object({}).describe('End the jokes'),

rateJoke: z.object({
rating: z.number().min(1).max(10),
explanation: z.string(),
}),
});

const adapter = createOpenAIAdapter(openai, {
model: 'gpt-3.5-turbo-1106',
});

const getJokeCompletion = adapter.fromChat(
const getJokeCompletion = adapter.fromEvent(
(topic: string) => `Tell me a joke about ${topic}.`
);

const rateJoke = adapter.fromChat(
const rateJoke = adapter.fromEvent(
(joke: string) => `Rate this joke on a scale of 1 to 10: ${joke}`
);

Expand All @@ -66,7 +50,7 @@ const getTopic = fromPromise(async () => {
});

const decide = adapter.fromEvent(
(lastRating: string) =>
(lastRating: number) =>
`Choose what to do next, given the previous rating of the joke: ${lastRating}`
);
export function getRandomFunnyPhrase() {
Expand Down Expand Up @@ -109,8 +93,19 @@ const loader = fromCallback(({ input }: { input: string }) => {
});

const jokeMachine = setup({
schemas,
types: schemas.types,
schemas: {
events: events.schemas,
},
types: {
context: {} as {
topic: string;
jokes: string[];
desire: string | null;
lastRating: number | null;
loader: string | null;
},
events: events.types,
},
actors: {
getJokeCompletion,
getTopic,
Expand All @@ -119,6 +114,7 @@ const jokeMachine = setup({
loader,
},
}).createMachine({
id: 'joke',
context: () => ({
topic: '',
jokes: [],
Expand All @@ -144,54 +140,45 @@ const jokeMachine = setup({
{
src: 'getJokeCompletion',
input: ({ context }) => context.topic,
onDone: {
actions: [
assign({
jokes: ({ context, event }) =>
context.jokes.concat(
event.output.choices[0]!.message.content!
),
}),
log((x) => `\n` + x.context.jokes.at(-1)),
],
target: 'rateJoke',
},
},
{
src: 'loader',
input: getRandomFunnyPhrase,
},
],
on: {
tellJoke: {
actions: assign({
jokes: ({ context, event }) => [...context.jokes, event.joke],
}),
target: 'rateJoke',
},
},
},
rateJoke: {
invoke: [
{
src: 'rateJoke',
input: ({ context }) => context.jokes[context.jokes.length - 1]!,
onDone: {
actions: [
assign({
lastRating: ({ event }) =>
event.output.choices[0]!.message.content!,
}),
log(({ context }) => '\n' + context.lastRating),
],
target: 'decide',
},
},
{
src: 'loader',
input: getRandomRatingPhrase,
},
],
on: {
rateJoke: {
actions: assign({
lastRating: ({ event }) => event.rating,
}),
target: 'decide',
},
},
},
decide: {
invoke: {
src: 'decide',
input: ({ context }) => context.lastRating!,
onDone: {
actions: log(({ event }) => event),
},
},
on: {
askForTopic: {
Expand All @@ -216,5 +203,11 @@ const jokeMachine = setup({
},
});

const agent = createAgent(jokeMachine);
const agent = createAgent(jokeMachine, {
inspect: (ev) => {
if (ev.type === '@xstate.event') {
console.log(`\n${ev.actorRef.id}`, ev.event);
}
},
});
agent.start();
47 changes: 10 additions & 37 deletions examples/numberGuesser.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import OpenAI from 'openai';
import { createAgent, createOpenAIAdapter, createSchemas } from '../src';
import { createAgent, createOpenAIAdapter, defineEvents } from '../src';
import { assign, setup } from 'xstate';
import { z } from 'zod';
const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});
Expand All @@ -23,40 +24,10 @@ const guessLogic = adapter.fromEvent(
`
);

const schemas = createSchemas({
context: {
type: 'object',
properties: {
lastGuess: {
type: ['number', 'null'],
description: 'The last guess',
},
previousGuesses: {
type: 'array',
items: {
type: 'number',
},
description: 'The previous guesses',
},
answer: {
type: 'number',
description: 'The answer',
},
},
},
events: {
guess: {
properties: {
number: {
// integer
type: 'number',
minimum: 1,
maximum: 10,
},
},
required: ['number'],
},
},
const events = defineEvents({
guess: z.object({
number: z.number().min(1).max(10).describe('The number guessed'),
}),
});

const machine = setup({
Expand All @@ -66,9 +37,11 @@ const machine = setup({
answer: number;
},
input: {} as { answer: number },
events: schemas.types.events,
events: events.types,
},
schemas: {
events: events.schemas,
},
schemas,
actors: {
guessLogic,
},
Expand Down
Loading

0 comments on commit 1424f03

Please sign in to comment.