-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlocal_chatbot.py
More file actions
177 lines (154 loc) · 6.99 KB
/
local_chatbot.py
File metadata and controls
177 lines (154 loc) · 6.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import sys
import requests
import argparse
import riva.client
import riva.client.audio_io
class ArgsASR():
    """Default configuration for the Riva streaming ASR client.

    Values mirror the CLI options of the Riva ASR examples; they are plain
    attributes so callers can override them after construction.
    """

    def __init__(self) -> None:
        # Audio capture
        self.input_device = 25             # sound-device index of the microphone
        self.list_devices = False
        self.sample_rate_hz = 16000
        self.file_streaming_chunk = 1600   # samples pushed per streamed chunk

        # Recognition behaviour
        self.language_code = "en-US"
        self.profanity_filter = False
        self.automatic_punctuation = False
        self.no_verbatim_transcripts = False
        self.speaker_diarization = False

        # Word boosting (disabled by default)
        self.boosted_lm_words = None
        self.boosted_lm_score = 4.0

        # Server connection
        self.server = "localhost:50051"
        self.use_ssl = False
        self.ssl_cert = None
        self.metadata = None
class ArgsTTS():
    """Default configuration for the Riva streaming TTS client."""

    def __init__(self) -> None:
        self.language_code = 'en-US'
        self.sample_rate_hz = 48000   # synthesis output rate (Hz)
        self.stream = True            # play chunks as they arrive
        self.output_device = 30       # sound-device index of the speaker
class ChatBot():
    """Voice assistant glueing Riva ASR + TTS to a remote text-generation API.

    Listens on a microphone, wakes on the word "hello", forwards subsequent
    utterances to an HTTP LLM endpoint, speaks the reply, and goes back to
    sleep on the word "stop".
    """

    def __init__(self, args) -> None:
        """Build the ASR and TTS service clients.

        Args:
            args: argparse namespace with ``list_input_devices``,
                ``list_output_devices``, ``input_device`` and
                ``output_device`` attributes (see ``parse_args``).
        """
        # Device-listing modes print the device table and exit immediately.
        if args.list_input_devices:
            riva.client.audio_io.list_input_devices()
            sys.exit()
        if args.list_output_devices:
            riva.client.audio_io.list_output_devices()
            sys.exit()
        self.args_asr = ArgsASR()
        self.args_tts = ArgsTTS()
        self.args_asr.input_device = args.input_device
        self.args_tts.output_device = args.output_device
        # ASR and TTS share one connection to the Riva server.
        auth = riva.client.Auth(uri=self.args_asr.server)
        self.asr_service = riva.client.ASRService(auth)
        self.tts_service = riva.client.SpeechSynthesisService(auth)
        self.config_asr = riva.client.StreamingRecognitionConfig(
            config=riva.client.RecognitionConfig(
                encoding=riva.client.AudioEncoding.LINEAR_PCM,
                language_code=self.args_asr.language_code,
                max_alternatives=1,
                profanity_filter=self.args_asr.profanity_filter,
                enable_automatic_punctuation=self.args_asr.automatic_punctuation,
                verbatim_transcripts=not self.args_asr.no_verbatim_transcripts,
                sample_rate_hertz=self.args_asr.sample_rate_hz,
                audio_channel_count=1,
            ),
            interim_results=True,
        )
        riva.client.add_word_boosting_to_config(
            self.config_asr,
            self.args_asr.boosted_lm_words,
            self.args_asr.boosted_lm_score,
        )
        # True once the wake word has been heard; gates LLM queries.
        self.flag_wakeup = False

    def output_audio(self, answer):
        """Synthesize *answer* with Riva TTS and play it on the output device.

        Args:
            answer: text to speak.
        """
        sound_stream = None
        try:
            if self.args_tts.output_device is not None:
                # Audio chunks are handed to SoundCallBack as they arrive.
                sound_stream = riva.client.audio_io.SoundCallBack(
                    self.args_tts.output_device, nchannels=1, sampwidth=2,
                    framerate=self.args_tts.sample_rate_hz
                )
            if self.args_tts.stream:
                responses = self.tts_service.synthesize_online(
                    answer, None, self.args_tts.language_code,
                    sample_rate_hz=self.args_tts.sample_rate_hz
                )
                for resp in responses:
                    if sound_stream is not None:
                        sound_stream(resp.audio)
        finally:
            # Always release the audio device, even if synthesis fails.
            if sound_stream is not None:
                sound_stream.close()

    def _query_llm(self, prompt):
        """POST *prompt* to the generation endpoint and return the reply text.

        Raises:
            requests.RequestException: on connection failure or timeout.
            KeyError: if the response JSON lacks 'generated_text'.
        """
        # NOTE(review): endpoint address is hard-coded; consider making it a
        # CLI option.
        headers = {"Content-Type": "application/json"}
        data = {
            'inputs': prompt,
            'parameters': {
                'max_new_tokens': 50,
            },
        }
        # bugfix: the original call had no timeout, so a stalled LLM server
        # would hang the chat loop forever.
        response = requests.post(
            'http://192.168.49.74:8899/generate',
            headers=headers, json=data, timeout=60,
        )
        return response.json()['generated_text']

    def run(self):
        """Main loop: listen, detect wake/sleep words, query the LLM, speak."""
        output_asr = ""
        while True:
            print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
            with riva.client.audio_io.MicrophoneStream(
                self.args_asr.sample_rate_hz,
                self.args_asr.file_streaming_chunk,
                device=self.args_asr.input_device,
            ) as stream_mic:
                print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
                # Continuous listening loop over streaming ASR responses.
                # (The original wrapped this in a no-op try/finally: pass,
                # removed here.)
                for response in self.asr_service.streaming_response_generator(
                    audio_chunks=stream_mic,
                    streaming_config=self.config_asr,
                ):
                    for result in response.results:
                        if result.is_final:
                            output_asr = result.alternatives[0].transcript
                    if output_asr != "":
                        if "hello" in output_asr:
                            # Wake word heard: acknowledge and restart the mic
                            # stream so the wake phrase is not sent to the LLM.
                            self.flag_wakeup = True
                            self.output_audio('Here!')
                            output_asr = ""
                            stream_mic.close()
                        if "stop" in output_asr and self.flag_wakeup:
                            # Sleep word: say goodbye and go back to waiting.
                            self.flag_wakeup = False
                            self.output_audio('Bye! Have a great day!')
                            output_asr = ""
                            stream_mic.close()
                        if self.flag_wakeup and self.isinstance(output_asr):
                            print(f'User Input: >>>>>\n {output_asr} \n')
                            stream_mic.close()
                            result = self._query_llm(output_asr)
                            print(f'ChatBot Output: >>>>>\n {result} \n')
                            self.output_audio(result)
                            output_asr = ""

    @staticmethod
    def isinstance(text, min_length=5, min_unique_words=1):
        """Heuristic filter deciding whether *text* is a "real" utterance.

        NOTE(review): the name shadows the ``isinstance`` builtin inside this
        class; kept unchanged for backward compatibility with callers.

        Args:
            text: candidate transcript.
            min_length: minimum character count after stripping.
            min_unique_words: minimum number of distinct words.

        Returns:
            False if the stripped text is shorter than *min_length* or has
            fewer than *min_unique_words* distinct words; True otherwise.
        """
        text = text.strip()
        if len(text) < min_length:
            return False
        if len(set(text.split())) < min_unique_words:
            return False
        return True
def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for the chatbot.

    Args:
        argv: optional explicit argument list; defaults to ``sys.argv[1:]``
            (new, backward-compatible parameter — existing no-arg callers are
            unaffected, and it makes the function unit-testable).

    Returns:
        argparse.Namespace with ``list_input_devices``, ``list_output_devices``,
        ``input_device`` and ``output_device``.
    """
    parser = argparse.ArgumentParser(description="",)
    parser.add_argument("--list-input-devices", action="store_true", help="List input audio device indices.")
    # bugfix: help text previously said "input" (copy-paste error).
    parser.add_argument("--list-output-devices", action="store_true", help="List output audio device indices.")
    parser.add_argument("--input-device", type=int, default=25, help="Set input audio device.")
    parser.add_argument("--output-device", type=int, default=30, help="Set output audio device.")
    args = parser.parse_args(argv)
    return args
if __name__ == '__main__':
    # Script entry point: build the bot from CLI options and start listening.
    chatbot = ChatBot(parse_args())
    chatbot.run()