Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Dockerfile.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ RUN export DEBIAN_FRONTEND=noninteractive \
python${PYTHON_VERSION}-venv \
python3-pip \
libcudnn8 \
libcudnn8-dev \
# Make sure to install all required libcudnn components
libcudnn8-samples \
python3-pip \
&& rm -rf /var/lib/apt/lists/*

Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ docker run -d -p 9000:9000 \

- Multiple ASR engines support (OpenAI Whisper, Faster Whisper, WhisperX)
- Multiple output formats (text, JSON, VTT, SRT, TSV)
- Support for outputting all formats simultaneously with a single request
- Word-level timestamps support
- Voice activity detection (VAD) filtering
- Speaker diarization (with WhisperX)
Expand Down Expand Up @@ -90,3 +91,5 @@ After starting the service, visit `http://localhost:9000` or `http://0.0.0.0:900
## Credits

- This software uses libraries from the [FFmpeg](http://ffmpeg.org) project under the [LGPLv2.1](http://www.gnu.org/licenses/old-licenses/lgpl-2.1.html)


42 changes: 34 additions & 8 deletions app/asr_models/faster_whisper_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import os
from io import StringIO
from threading import Thread
from typing import BinaryIO, Union
Expand All @@ -8,7 +9,7 @@

from app.asr_models.asr_model import ASRModel
from app.config import CONFIG
from app.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
from app.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT, WriteAll


class FasterWhisperASR(ASRModel):
Expand Down Expand Up @@ -59,10 +60,23 @@ def transcribe(
text = text + segment.text
result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text}

# Store the output directory and audio path for the "all" option
self.output_dir = os.environ.get("OUTPUT_DIR", "/tmp")
self.audio_path = os.environ.get("AUDIO_FILENAME", "audio")

# For "all" output format, create and return the zip bytes
if output == "all":
writer = WriteAll(self.output_dir)
zip_bytes = writer.create_zip_bytes(result)
# Create a generator that yields the bytes
def bytes_generator():
yield zip_bytes
return bytes_generator()

# For other formats, write to StringIO and return that
output_file = StringIO()
self.write_result(result, output_file, output)
output_file.seek(0)

return output_file

def language_detection(self, audio):
Expand All @@ -84,13 +98,25 @@ def language_detection(self, audio):
return detected_lang_code, detected_language_confidence

def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
    """
    Write the transcription result to ``file`` in the requested output format.

    For the 'all' format this method is not used: ``transcribe`` handles that
    case itself via ``WriteAll.create_zip_bytes``.

    :param result: transcription result dict (contains at least "segments").
    :param file: file-like object the formatted transcript is written to.
    :param output: one of "srt", "vtt", "tsv", "json"; any other value
        (including ``None``) falls back to plain text.
    """
    # Map each supported format name to its writer class.
    writer_classes = {
        "srt": WriteSRT,
        "vtt": WriteVTT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }
    writer_class = writer_classes.get(output, WriteTXT)  # default to txt

    # transcribe() stores self.output_dir before calling this method; fall back
    # to the same environment default so a direct call does not raise
    # AttributeError.
    output_dir = getattr(self, "output_dir", None) or os.environ.get("OUTPUT_DIR", "/tmp")

    # Instantiate the writer exactly once and write to the provided file object.
    writer_class(output_dir).write_result(result, file=file)
18 changes: 17 additions & 1 deletion app/asr_models/mbain_whisperx_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import os
from io import StringIO
from threading import Thread
from typing import BinaryIO, Union
Expand All @@ -9,6 +10,7 @@

from app.asr_models.asr_model import ASRModel
from app.config import CONFIG
from app.utils import WriteAll


class WhisperXASR(ASRModel):
Expand Down Expand Up @@ -85,10 +87,24 @@ def transcribe(
result = whisperx.assign_word_speakers(diarize_segments, result)
result["language"] = language

# Store the output directory and audio path for the "all" option
self.output_dir = os.environ.get("OUTPUT_DIR", "/tmp")
self.audio_path = os.environ.get("AUDIO_FILENAME", "audio")

# For "all" output format, create and return the zip bytes
if output == "all":
# Import WriteAll from app.utils if needed
writer = WriteAll(self.output_dir)
zip_bytes = writer.create_zip_bytes(result)
# Create a generator that yields the bytes
def bytes_generator():
yield zip_bytes
return bytes_generator()

# For other formats, write to StringIO and return that
output_file = StringIO()
self.write_result(result, output_file, output)
output_file.seek(0)

return output_file

def language_detection(self, audio):
Expand Down
18 changes: 17 additions & 1 deletion app/asr_models/openai_whisper_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import os
from io import StringIO
from threading import Thread
from typing import BinaryIO, Union
Expand All @@ -9,6 +10,7 @@

from app.asr_models.asr_model import ASRModel
from app.config import CONFIG
from app.utils import WriteAll


class OpenAIWhisperASR(ASRModel):
Expand Down Expand Up @@ -49,10 +51,24 @@ def transcribe(
with self.model_lock:
result = self.model.transcribe(audio, **options_dict)

# Store the output directory and audio path for the "all" option
self.output_dir = os.environ.get("OUTPUT_DIR", "/tmp")
self.audio_path = os.environ.get("AUDIO_FILENAME", "audio")

# For "all" output format, create and return the zip bytes
if output == "all":
# Import WriteAll from app.utils if needed
writer = WriteAll(self.output_dir)
zip_bytes = writer.create_zip_bytes(result)
# Create a generator that yields the bytes
def bytes_generator():
yield zip_bytes
return bytes_generator()

# For other formats, write to StringIO and return that
output_file = StringIO()
self.write_result(result, output_file, output)
output_file.seek(0)

return output_file

def language_detection(self, audio):
Expand Down
126 changes: 116 additions & 10 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import io
import zipfile
from dataclasses import asdict
from typing import BinaryIO, TextIO

Expand Down Expand Up @@ -32,7 +34,9 @@ class WriteTXT(ResultWriter):

def write_result(self, result: dict, file: TextIO):
    """
    Write the transcript as plain text, one stripped segment per line.

    :param result: transcription result; ``result["segments"]`` is iterated.
    :param file: text file-like object to print to.
    """
    for segment in result["segments"]:
        # Segments may be plain dicts or objects exposing a ``text`` attribute.
        text = segment["text"] if isinstance(segment, dict) else segment.text
        print(text.strip(), file=file, flush=True)


class WriteVTT(ResultWriter):
Expand All @@ -41,9 +45,19 @@ class WriteVTT(ResultWriter):
def write_result(self, result: dict, file: TextIO):
    """
    Write the transcript as WebVTT: a "WEBVTT" header, then one cue per segment.

    Each cue is a "start --> end" timing line (formatted by
    ``format_timestamp``) followed by the stripped segment text, in which any
    literal "-->" is rewritten to "->".

    :param result: transcription result; ``result["segments"]`` is iterated.
    :param file: text file-like object to print to.
    """
    print("WEBVTT\n", file=file)
    for segment in result["segments"]:
        # Segments may be plain dicts or objects with start/end/text attributes.
        if isinstance(segment, dict):
            start, end, text = segment["start"], segment["end"], segment["text"]
        else:
            start, end, text = segment.start, segment.end, segment.text

        print(
            f"{format_timestamp(start)} --> {format_timestamp(end)}\n"
            f"{text.strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )
Expand All @@ -54,12 +68,22 @@ class WriteSRT(ResultWriter):

def write_result(self, result: dict, file: TextIO):
    """
    Write the transcript as SRT: numbered cues, one per segment.

    Cues are numbered from 1. Timestamps are produced by ``format_timestamp``
    with hours always included and "," as the decimal marker. Any literal
    "-->" in the segment text is rewritten to "->".

    :param result: transcription result; ``result["segments"]`` is iterated.
    :param file: text file-like object to print to.
    """
    for i, segment in enumerate(result["segments"], start=1):
        # Segments may be plain dicts or objects with start/end/text attributes.
        if isinstance(segment, dict):
            start, end, text = segment["start"], segment["end"], segment["text"]
        else:
            start, end, text = segment.start, segment.end, segment.text

        start_stamp = format_timestamp(start, always_include_hours=True, decimal_marker=",")
        end_stamp = format_timestamp(end, always_include_hours=True, decimal_marker=",")

        # write srt lines
        print(
            f"{i}\n"
            f"{start_stamp} --> {end_stamp}\n"
            f"{text.strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )
Expand All @@ -80,20 +104,102 @@ class WriteTSV(ResultWriter):
def write_result(self, result: dict, file: TextIO):
    """
    Write the transcript as tab-separated values.

    The first row is the header "start", "end", "text". Each following row
    holds the segment start and end in integer milliseconds and the stripped
    text, with any tab inside the text replaced by a space so it cannot break
    the column layout.

    :param result: transcription result; ``result["segments"]`` is iterated.
    :param file: text file-like object to print to.
    """
    print("start", "end", "text", sep="\t", file=file)
    for segment in result["segments"]:
        # Segments may be plain dicts or objects with start/end/text attributes.
        if isinstance(segment, dict):
            start, end, text = segment["start"], segment["end"], segment["text"]
        else:
            start, end, text = segment.start, segment.end, segment.text

        print(
            round(1000 * start),
            round(1000 * end),
            text.strip().replace("\t", " "),
            sep="\t",
            file=file,
            flush=True,
        )


class WriteJSON(ResultWriter):
    """Writer that serialises the whole transcription result as JSON."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO):
        """
        Dump ``result`` to ``file`` as JSON.

        Segments that are not already dicts are converted with
        ``dataclasses.asdict`` first. The check is made per segment, and the
        conversion happens on a shallow copy so the caller's ``result`` is
        left untouched.

        :param result: transcription result; may contain a "segments" list.
        :param file: text file-like object to write the JSON document to.
        """
        segments = result.get("segments")
        if segments:
            # Rebinding the local name keeps the caller's dict unmodified while
            # preserving the original key order in the emitted JSON.
            result = {
                **result,
                "segments": [
                    segment if isinstance(segment, dict) else asdict(segment)
                    for segment in segments
                ],
            }
        json.dump(result, file)


class WriteAll:
    """
    Write a transcript in all supported formats (txt, vtt, srt, tsv, json).

    Calling the instance delegates to every format writer in turn;
    ``create_zip_bytes`` instead bundles all formats into an in-memory zip.
    """

    def __init__(self, output_dir: str):
        """
        :param output_dir: directory passed to every underlying format writer.
        """
        self.output_dir = output_dir
        # One writer per supported format, keyed by file extension. This is the
        # single source of truth for the set of formats.
        self.writers = {
            "txt": WriteTXT(output_dir),
            "vtt": WriteVTT(output_dir),
            "srt": WriteSRT(output_dir),
            "tsv": WriteTSV(output_dir),
            "json": WriteJSON(output_dir),
        }

    def __call__(self, result: dict, audio_path: str):
        """
        Invoke every format writer with ``result`` and ``audio_path``.

        A failure in one writer is reported and does not stop the others.
        """
        for format_name, writer in self.writers.items():
            try:
                writer(result, audio_path)
            except Exception as e:
                # Best effort: continue with other formats even if one fails.
                print(f"Error in {format_name} writer: {e}")

    def create_zip_bytes(self, result: dict):
        """
        Create a zip archive in memory and return its bytes.

        The archive holds one ``transcript.<ext>`` entry per format. A format
        whose writer raises is reported and skipped. If the archive itself
        cannot be created, an empty ``bytes`` object is returned.

        :param result: transcription result passed to each writer.
        :return: the zip archive as ``bytes`` (``b""`` on failure).
        """
        buffer = io.BytesIO()

        try:
            with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
                for format_name, writer in self.writers.items():
                    try:
                        # Render this format into its own text buffer.
                        output = io.StringIO()
                        writer.write_result(result, output)

                        # Zip entries are bytes, so encode the text as UTF-8.
                        content = output.getvalue().encode("utf-8")
                        zip_file.writestr(f"transcript.{format_name}", content)
                    except Exception as e:
                        # Best effort: continue with other formats.
                        print(f"Error adding {format_name} to zip: {e}")
        except Exception as e:
            print(f"Error creating zip file: {e}")
            # Return an empty buffer if zip creation fails.
            return b""

        # Read the bytes only after the ZipFile context has closed: closing is
        # what writes the central directory that makes the archive valid.
        return buffer.getvalue()


def load_audio(file: BinaryIO, encode=True, sr: int = CONFIG.SAMPLE_RATE):
"""
Open an audio file object and read as mono waveform, resampling as necessary.
Expand Down
Loading