This commit is contained in:
Roy Han 2024-08-09 11:04:26 -07:00
parent ad7e822883
commit 89f3bae306
7 changed files with 228 additions and 6 deletions

View File

@ -37,7 +37,7 @@ func (e StatusError) Error() string {
type ImageData []byte
type WhisperRequest struct {
Model string `json:"model"`
Model string `json:"model,omitempty"`
Audio string `json:"audio,omitempty"`
Transcribe bool `json:"transcribe,omitempty"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
@ -116,6 +116,8 @@ type ChatRequest struct {
Options map[string]interface{} `json:"options"`
Speech *WhisperRequest `json:"speech,omitempty"`
RunSpeech bool `json:"run_speech,omitempty"`
}
type Tools []Tool

View File

@ -38,6 +38,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/recorder"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@ -380,6 +381,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
speech, err := cmd.Flags().GetBool("speech")
if err != nil {
return err
}
if speech {
return generateInteractiveAudio(cmd, opts)
}
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
@ -862,6 +871,7 @@ type runOptions struct {
Options map[string]interface{}
MultiModal bool
KeepAlive *api.Duration
Audio bool
}
type displayResponseState struct {
@ -970,6 +980,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req.KeepAlive = opts.KeepAlive
}
if opts.Audio {
req.RunSpeech = true
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
@ -1055,6 +1069,30 @@ func generate(cmd *cobra.Command, opts runOptions) error {
KeepAlive: opts.KeepAlive,
}
speech, err := cmd.Flags().GetBool("speech")
if err != nil {
return err
}
// create temp wav file with the recorder package
if speech {
tempFile, err := os.CreateTemp("", "recording-*.wav")
if err != nil {
return err
}
defer os.Remove(tempFile.Name())
fmt.Print("Speech Mode\n\n")
err = recorder.RecordAudio(tempFile)
if err != nil {
return err
}
request.Speech = &api.WhisperRequest{
Audio: tempFile.Name(),
}
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
@ -1262,6 +1300,7 @@ func NewCLI() *cobra.Command {
RunE: RunHandler,
}
runCmd.Flags().Bool("speech", false, "Speech to text mode")
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")

View File

@ -20,6 +20,7 @@ import (
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/recorder"
"github.com/ollama/ollama/types/errtypes"
)
@ -51,6 +52,40 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
}
func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error {
for {
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
// create temp wav file with the recorder package
tempFile, err := os.CreateTemp("", "recording-*.wav")
if err != nil {
return err
}
defer os.Remove(tempFile.Name())
err = recorder.RecordAudio(tempFile)
if err != nil {
return err
}
p.StopAndClear()
newMessage := api.Message{Role: "user", Audio: tempFile.Name()}
opts.Audio = true
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
}
}
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")

1
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

2
go.sum
View File

@ -115,6 +115,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=

137
recorder/recorder.go Normal file
View File

@ -0,0 +1,137 @@
package recorder
import (
"encoding/binary"
"fmt"
"os"
"os/signal"
"syscall"
"golang.org/x/sys/unix"
"golang.org/x/term"
"github.com/gordonklaus/portaudio"
)
const (
sampleRate = 16000
numChannels = 1
bitsPerSample = 16
)
func RecordAudio(f *os.File) error {
fmt.Print("Recording. Press any key to stop.\n\n")
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
portaudio.Initialize()
defer portaudio.Terminate()
in := make([]int16, 64)
stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in)
if err != nil {
return err
}
defer stream.Close()
err = stream.Start()
if err != nil {
return err
}
// Write WAV header with placeholder sizes
writeWavHeader(f, sampleRate, numChannels, bitsPerSample)
var totalSamples uint32
// Set up terminal input reading
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return err
}
defer term.Restore(int(os.Stdin.Fd()), oldState)
// Create a channel to handle the stop signal
stop := make(chan struct{})
go func() {
_, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1))
if err != nil {
fmt.Println("Error reading from stdin:", err)
return
}
// Send signal to stop recording
stop <- struct{}{}
}()
loop:
for {
err = stream.Read()
if err != nil {
return err
}
err = binary.Write(f, binary.LittleEndian, in)
if err != nil {
return err
}
totalSamples += uint32(len(in))
select {
case <-stop:
break loop
case <-sig:
break loop
default:
}
}
err = stream.Stop()
if err != nil {
return err
}
// Update WAV header with actual sizes
updateWavHeader(f, totalSamples, numChannels, bitsPerSample)
return nil
}
func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) {
subchunk1Size := 16
audioFormat := 1
byteRate := sampleRate * numChannels * (bitsPerSample / 8)
blockAlign := numChannels * (bitsPerSample / 8)
// Write the RIFF header
f.Write([]byte("RIFF"))
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for file size
f.Write([]byte("WAVE"))
// Write the fmt subchunk
f.Write([]byte("fmt "))
binary.Write(f, binary.LittleEndian, uint32(subchunk1Size))
binary.Write(f, binary.LittleEndian, uint16(audioFormat))
binary.Write(f, binary.LittleEndian, uint16(numChannels))
binary.Write(f, binary.LittleEndian, uint32(sampleRate))
binary.Write(f, binary.LittleEndian, uint32(byteRate))
binary.Write(f, binary.LittleEndian, uint16(blockAlign))
binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))
// Write the data subchunk header
f.Write([]byte("data"))
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for data size
}
func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) {
fileSize := 36 + (totalSamples * uint32(numChannels) * uint32(bitsPerSample/8))
dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)
// Seek to the start of the file and write updated sizes
f.Seek(4, 0)
binary.Write(f, binary.LittleEndian, uint32(fileSize))
f.Seek(40, 0)
binary.Write(f, binary.LittleEndian, uint32(dataSize))
}

View File

@ -110,7 +110,12 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
}
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
modelPath := speech.Model
var modelPath string
if speech.Model == "" {
modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin"
} else {
modelPath = speech.Model
}
// default to 5 minutes
var sessionDuration time.Duration
@ -130,7 +135,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan er
return
}
whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server"
whisperServer := "/Users/royhan-ollama/.ollama/server"
// Find an available port for whisper
port := 0
@ -1510,8 +1515,9 @@ func (s *Server) ProcessHandler(c *gin.Context) {
}
func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error {
if req.Model == "" {
return nil
slog.Info("processing audio")
if req == nil {
req = &api.WhisperRequest{}
}
portCh := make(chan int, 1)
errCh := make(chan error, 1)
@ -1583,7 +1589,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
if req.Speech != nil {
if req.Speech != nil || req.RunSpeech {
if err := processAudio(c, s, msgs, req.Speech); err != nil {
slog.Error("failed to process audio", "error", err)
return