Skip to content

Commit

Permalink
Implement premium account status for TTS rate improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
Rosuav committed Jan 10, 2025
1 parent fc25129 commit 8d4e35b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
16 changes: 10 additions & 6 deletions modules/http/chan_alertbox.pike
Original file line number Diff line number Diff line change
Expand Up @@ -557,10 +557,12 @@ __async__ mapping(string:mixed) http_request(Protocols.HTTP.Server.Request req)
"personals": cfg->personals || ({ }),
]));
}
mapping premium = await(G->G->DB->load_config(0, "premium_accounts"));
mapping prem = premium[(string)req->misc->channel->userid] || ([]);
return render(req, ([
"vars": (["ws_group": "control",
"maxfilesize": MAX_PER_FILE, "maxtotsize": MAX_TOTAL_STORAGE,
"avail_voices": tts_config->avail_voices[?RATE_STANDARD] || ({ }), //For now, only expose standard rate
"avail_voices": tts_config->avail_voices[?prem->tts_rate] || ({ }),
"follower_alert_scopes": req->misc->channel->name != "#!demo" && ensure_bcaster_token(req, "moderator:read:followers"),
]),
]) | req->misc->chaninfo);
Expand Down Expand Up @@ -1267,12 +1269,14 @@ __async__ string filter_bad_words(string text, string mode) {
return words * " ";
}

__async__ string|zero text_to_speech(string text, string voice, string origin) {
__async__ string|zero text_to_speech(string text, string voice, int|void userid) {
string token = tts_config->?access_token;
if (!token) return 0;
array v = voice / "/";
//TODO: Allow different whitelists for different origins
if (!tts_config->voices[RATE_STANDARD][v[1]]) return 0;
//Different whitelists for different userids (default to rate 0 aka Standard if not recognized)
mapping premium = await(G->G->DB->load_config(0, "premium_accounts"));
mapping prem = premium[(string)userid] || ([]);
if (!tts_config->voices[prem->tts_rate][v[1]]) return 0;
object reqargs = Protocols.HTTP.Promise.Arguments((["headers": ([
"Authorization": "Bearer " + token,
"Content-Type": "application/json; charset=utf-8",
Expand All @@ -1288,7 +1292,7 @@ __async__ string|zero text_to_speech(string text, string voice, string origin) {
System.Timer tm = System.Timer();
object res = await(Protocols.HTTP.Promise.post_url("https://texttospeech.googleapis.com/v1/text:synthesize", reqargs));
float delay = tm->get();
Stdio.append_file("tts_stats.log", sprintf("%s text %O fetch time %.3f\n", origin, text, delay));
Stdio.append_file("tts_stats.log", sprintf("User %d text %O fetch time %.3f\n", userid, text, delay));
mixed data; catch {data = Standards.JSON.decode_utf8(res->get());};
if (mappingp(data) && data->error->?details[?0]->?reason == "ACCESS_TOKEN_EXPIRED") {
Stdio.append_file("tts_error.log", sprintf("%sTTS access key expired after %d seconds\n",
Expand Down Expand Up @@ -1334,7 +1338,7 @@ __async__ void send_with_tts(object channel, mapping args, string|void destgroup
text += fmt;
string voice = inh->tts_voice || "";
if (sizeof(voice / "/") != 3) voice = tts_config->default_voice;
if (string tts = text != "" && await(text_to_speech(text, voice, sprintf("Channel %O", channel->name)))) args->tts = tts;
if (string tts = text != "" && await(text_to_speech(text, voice, channel->userid))) args->tts = tts;
send_updates_all((destgroup || cfg->authkey) + "#" + channel->userid, args);
}

Expand Down
16 changes: 12 additions & 4 deletions modules/http/tts_hack.pike
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,16 @@ constant markdown = #"# TTS hack
<form id=send><label>Enter stuff: <input size=80 id=stuff></label> <button>Send</button></form>
";

mapping(string:mixed) http_request(Protocols.HTTP.Server.Request req) {
return render(req, (["vars": (["ws_group": req->variables->key || ""])]));
//Map a group name to the last-sighted TTS rate schedule
//This won't be perfect, but it's only for the drop-down, so it's not that big a deal if it's wrong.
mapping tts_rate = ([]);
__async__ mapping(string:mixed) http_request(Protocols.HTTP.Server.Request req) {
string key = req->variables->key || "";
if (key != "") {
mapping premium = await(G->G->DB->load_config(0, "premium_accounts"));
tts_rate[key] = premium[(string)req->misc->session->user->?id]->?tts_rate;
}
return render(req, (["vars": (["ws_group": key])]));
}

@retain: multiset tts_hack_valid_keys = (<>);
Expand All @@ -19,7 +27,7 @@ string websocket_validate(mapping(string:mixed) conn, mapping(string:mixed) msg)
}

mapping get_state(string group) {
return (["voices": G->G->tts_config->avail_voices[?0] || ({ })]); //0 == RATE_STANDARD
return (["voices": G->G->tts_config->avail_voices[?tts_rate[group]] || ({ })]); //If no TTS rate set, use RATE_STANDARD (0)
}

__async__ mapping|zero websocket_cmd_speak(mapping(string:mixed) conn, mapping(string:mixed) msg) {
Expand All @@ -28,7 +36,7 @@ __async__ mapping|zero websocket_cmd_speak(mapping(string:mixed) conn, mapping(s
object alertbox = G->G->websocket_types->chan_alertbox;
text = await(alertbox->filter_bad_words(text, "replace"));
werror("TTS Hack %O -> %O\n", msg->voice, msg->text);
string tts = await(alertbox->text_to_speech(text, msg->voice || "en-GB/en-GB-Standard-A/FEMALE", "tts_hack"));
string tts = await(alertbox->text_to_speech(text, msg->voice || "en-GB/en-GB-Standard-A/FEMALE", (int)conn->session->user->?id));
return (["cmd": "speak", "text": text, "tts": tts]);
}

Expand Down

0 comments on commit 8d4e35b

Please sign in to comment.