From 89f3bae306e6c64afd2be2020a7ef47d7826470e Mon Sep 17 00:00:00 2001 From: Roy Han Date: Fri, 9 Aug 2024 11:04:26 -0700 Subject: [PATCH] cli --- api/types.go | 4 +- cmd/cmd.go | 39 ++++++++++++ cmd/interactive.go | 35 +++++++++++ go.mod | 1 + go.sum | 2 + recorder/recorder.go | 137 +++++++++++++++++++++++++++++++++++++++++++ server/routes.go | 16 +++-- 7 files changed, 228 insertions(+), 6 deletions(-) create mode 100644 recorder/recorder.go diff --git a/api/types.go b/api/types.go index 3cad38ed..82aa3ea0 100644 --- a/api/types.go +++ b/api/types.go @@ -37,7 +37,7 @@ func (e StatusError) Error() string { type ImageData []byte type WhisperRequest struct { - Model string `json:"model"` + Model string `json:"model,omitempty"` Audio string `json:"audio,omitempty"` Transcribe bool `json:"transcribe,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"` @@ -116,6 +116,8 @@ type ChatRequest struct { Options map[string]interface{} `json:"options"` Speech *WhisperRequest `json:"speech,omitempty"` + + RunSpeech bool `json:"run_speech,omitempty"` } type Tools []Tool diff --git a/cmd/cmd.go b/cmd/cmd.go index d47db65b..409dbc3f 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -38,6 +38,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" + "github.com/ollama/ollama/recorder" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" @@ -380,6 +381,14 @@ func RunHandler(cmd *cobra.Command, args []string) error { } } + speech, err := cmd.Flags().GetBool("speech") + if err != nil { + return err + } + + if speech { + return generateInteractiveAudio(cmd, opts) + } return generateInteractive(cmd, opts) } return generate(cmd, opts) @@ -862,6 +871,7 @@ type runOptions struct { Options map[string]interface{} MultiModal bool KeepAlive *api.Duration + Audio bool } type displayResponseState struct { @@ -970,6 +980,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { req.KeepAlive = opts.KeepAlive } + if opts.Audio { + req.RunSpeech = true + } + if err := client.Chat(cancelCtx, req, fn); err != nil { if errors.Is(err, context.Canceled) { return nil, nil @@ -1055,6 +1069,30 @@ func generate(cmd *cobra.Command, opts runOptions) error { KeepAlive: opts.KeepAlive, } + speech, err := cmd.Flags().GetBool("speech") + if err != nil { + return err + } + + // create temp wav file with the recorder package + if speech { + tempFile, err := os.CreateTemp("", "recording-*.wav") + if err != nil { + return err + } + defer os.Remove(tempFile.Name()) + + fmt.Print("Speech Mode\n\n") + + err = recorder.RecordAudio(tempFile) + if err != nil { + return err + } + + request.Speech = &api.WhisperRequest{ + Audio: tempFile.Name(), + } + } if err := client.Generate(ctx, &request, fn); err != nil { if errors.Is(err, context.Canceled) { return nil @@ -1262,6 +1300,7 @@ func NewCLI() *cobra.Command { RunE: RunHandler, } + runCmd.Flags().Bool("speech", false, "Speech to text mode") runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)") runCmd.Flags().Bool("verbose", false, "Show timings for response") runCmd.Flags().Bool("insecure", false, "Use an insecure registry") diff --git a/cmd/interactive.go b/cmd/interactive.go index 4462cf29..f8f25178 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -20,6 +20,7 @@ import ( "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/readline" + "github.com/ollama/ollama/recorder" "github.com/ollama/ollama/types/errtypes" ) @@ -51,6 +52,40 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error { return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil }) } +func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error { + for { + p := progress.NewProgress(os.Stderr) + spinner := progress.NewSpinner("") + p.Add("", spinner) + + // create temp wav file with the recorder package + tempFile, err := os.CreateTemp("", "recording-*.wav") + if err != nil { + return err + } + defer os.Remove(tempFile.Name()) + + err = recorder.RecordAudio(tempFile) + if err != nil { + return err + } + + p.StopAndClear() + + newMessage := api.Message{Role: "user", Audio: tempFile.Name()} + opts.Audio = true + opts.Messages = append(opts.Messages, newMessage) + + assistant, err := chat(cmd, opts) + if err != nil { + return err + } + if assistant != nil { + opts.Messages = append(opts.Messages, *assistant) + } + } +} + func generateInteractive(cmd *cobra.Command, opts runOptions) error { usage := func() { fmt.Fprintln(os.Stderr, "Available Commands:") diff --git a/go.mod b/go.mod index 2e0c6614..8ceb12de 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/agnivade/levenshtein v1.1.1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/google/go-cmp v0.6.0 + github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 github.com/mattn/go-runewidth v0.0.14 github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c diff --git a/go.sum b/go.sum index 926ed26d..3f58e163 100644 --- a/go.sum +++ b/go.sum @@ -115,6 +115,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc= +github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= diff --git a/recorder/recorder.go b/recorder/recorder.go new file mode 100644 index 00000000..71db1e13 --- /dev/null +++ b/recorder/recorder.go @@ -0,0 +1,137 @@ +package recorder + +import ( + "encoding/binary" + "fmt" + "os" + "os/signal" + "syscall" + + "golang.org/x/sys/unix" + "golang.org/x/term" + + "github.com/gordonklaus/portaudio" +) + +const ( + sampleRate = 16000 + numChannels = 1 + bitsPerSample = 16 +) + +func RecordAudio(f *os.File) error { + fmt.Print("Recording. Press any key to stop.\n\n") + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGTERM) + + portaudio.Initialize() + defer portaudio.Terminate() + + in := make([]int16, 64) + stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in) + if err != nil { + return err + } + defer stream.Close() + + err = stream.Start() + if err != nil { + return err + } + + // Write WAV header with placeholder sizes + writeWavHeader(f, sampleRate, numChannels, bitsPerSample) + + var totalSamples uint32 + + // Set up terminal input reading + oldState, err := term.MakeRaw(int(os.Stdin.Fd())) + if err != nil { + return err + } + defer term.Restore(int(os.Stdin.Fd()), oldState) + + // Create a channel to handle the stop signal + stop := make(chan struct{}) + + go func() { + _, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1)) + if err != nil { + fmt.Println("Error reading from stdin:", err) + return + } + // Send signal to stop recording + stop <- struct{}{} + }() + +loop: + for { + err = stream.Read() + if err != nil { + return err + } + + err = binary.Write(f, binary.LittleEndian, in) + if err != nil { + return err + } + totalSamples += uint32(len(in)) + + select { + case <-stop: + break loop + case <-sig: + break loop + default: + } + } + + err = stream.Stop() + if err != nil { + return err + } + + // Update WAV header with actual sizes + updateWavHeader(f, totalSamples, numChannels, bitsPerSample) + + return nil +} + +func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) { + subchunk1Size := 16 + audioFormat := 1 + byteRate := sampleRate * numChannels * (bitsPerSample / 8) + blockAlign := numChannels * (bitsPerSample / 8) + + // Write the RIFF header + f.Write([]byte("RIFF")) + binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for file size + f.Write([]byte("WAVE")) + + // Write the fmt subchunk + f.Write([]byte("fmt ")) + binary.Write(f, binary.LittleEndian, uint32(subchunk1Size)) + binary.Write(f, binary.LittleEndian, uint16(audioFormat)) + binary.Write(f, binary.LittleEndian, uint16(numChannels)) + binary.Write(f, binary.LittleEndian, uint32(sampleRate)) + binary.Write(f, binary.LittleEndian, uint32(byteRate)) + binary.Write(f, binary.LittleEndian, uint16(blockAlign)) + binary.Write(f, binary.LittleEndian, uint16(bitsPerSample)) + + // Write the data subchunk header + f.Write([]byte("data")) + binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for data size +} + +func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) { + fileSize := 36 + (totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)) + dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8) + + // Seek to the start of the file and write updated sizes + f.Seek(4, 0) + binary.Write(f, binary.LittleEndian, uint32(fileSize)) + + f.Seek(40, 0) + binary.Write(f, binary.LittleEndian, uint32(dataSize)) +} diff --git a/server/routes.go b/server/routes.go index 8972b55f..c8407e56 100644 --- a/server/routes.go +++ b/server/routes.go @@ -110,7 +110,12 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil } func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) { - modelPath := speech.Model + var modelPath string + if speech.Model == "" { + modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin" + } else { + modelPath = speech.Model + } // default to 5 minutes var sessionDuration time.Duration @@ -130,7 +135,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan er return } - whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server" + whisperServer := "/Users/royhan-ollama/.ollama/server" // Find an available port for whisper port := 0 @@ -1510,8 +1515,9 @@ func (s *Server) ProcessHandler(c *gin.Context) { } func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error { - if req.Model == "" { - return nil + slog.Info("processing audio") + if req == nil { + req = &api.WhisperRequest{} } portCh := make(chan int, 1) errCh := make(chan error, 1) @@ -1583,7 +1589,7 @@ func (s *Server) ChatHandler(c *gin.Context) { msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) } - if req.Speech != nil { + if req.Speech != nil || req.RunSpeech { if err := processAudio(c, s, msgs, req.Speech); err != nil { slog.Error("failed to process audio", "error", err) return