mirror of
https://github.com/tcsenpai/ollama.git
synced 2025-06-07 11:45:21 +00:00
cli
This commit is contained in:
parent
ad7e822883
commit
89f3bae306
@ -37,7 +37,7 @@ func (e StatusError) Error() string {
|
||||
type ImageData []byte
|
||||
|
||||
type WhisperRequest struct {
|
||||
Model string `json:"model"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
Transcribe bool `json:"transcribe,omitempty"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
@ -116,6 +116,8 @@ type ChatRequest struct {
|
||||
Options map[string]interface{} `json:"options"`
|
||||
|
||||
Speech *WhisperRequest `json:"speech,omitempty"`
|
||||
|
||||
RunSpeech bool `json:"run_speech,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
|
39
cmd/cmd.go
39
cmd/cmd.go
@ -38,6 +38,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/recorder"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@ -380,6 +381,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
speech, err := cmd.Flags().GetBool("speech")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if speech {
|
||||
return generateInteractiveAudio(cmd, opts)
|
||||
}
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
@ -862,6 +871,7 @@ type runOptions struct {
|
||||
Options map[string]interface{}
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Audio bool
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@ -970,6 +980,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if opts.Audio {
|
||||
req.RunSpeech = true
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
@ -1055,6 +1069,30 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
KeepAlive: opts.KeepAlive,
|
||||
}
|
||||
|
||||
speech, err := cmd.Flags().GetBool("speech")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create temp wav file with the recorder package
|
||||
if speech {
|
||||
tempFile, err := os.CreateTemp("", "recording-*.wav")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
fmt.Print("Speech Mode\n\n")
|
||||
|
||||
err = recorder.RecordAudio(tempFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Speech = &api.WhisperRequest{
|
||||
Audio: tempFile.Name(),
|
||||
}
|
||||
}
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
@ -1262,6 +1300,7 @@ func NewCLI() *cobra.Command {
|
||||
RunE: RunHandler,
|
||||
}
|
||||
|
||||
runCmd.Flags().Bool("speech", false, "Speech to text mode")
|
||||
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/recorder"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
@ -51,6 +52,40 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
|
||||
}
|
||||
|
||||
func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error {
|
||||
for {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
// create temp wav file with the recorder package
|
||||
tempFile, err := os.CreateTemp("", "recording-*.wav")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
err = recorder.RecordAudio(tempFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.StopAndClear()
|
||||
|
||||
newMessage := api.Message{Role: "user", Audio: tempFile.Name()}
|
||||
opts.Audio = true
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
opts.Messages = append(opts.Messages, *assistant)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
usage := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
|
1
go.mod
1
go.mod
@ -19,6 +19,7 @@ require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
|
2
go.sum
2
go.sum
@ -115,6 +115,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
|
137
recorder/recorder.go
Normal file
137
recorder/recorder.go
Normal file
@ -0,0 +1,137 @@
|
||||
package recorder
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/gordonklaus/portaudio"
|
||||
)
|
||||
|
||||
// Audio format used by RecordAudio: sampleRate and numChannels configure the
// input stream, and all three values are written to the WAV header.
const (
|
||||
sampleRate = 16000
|
||||
numChannels = 1
|
||||
bitsPerSample = 16
|
||||
)
|
||||
|
||||
func RecordAudio(f *os.File) error {
|
||||
fmt.Print("Recording. Press any key to stop.\n\n")
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
portaudio.Initialize()
|
||||
defer portaudio.Terminate()
|
||||
|
||||
in := make([]int16, 64)
|
||||
stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
err = stream.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write WAV header with placeholder sizes
|
||||
writeWavHeader(f, sampleRate, numChannels, bitsPerSample)
|
||||
|
||||
var totalSamples uint32
|
||||
|
||||
// Set up terminal input reading
|
||||
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer term.Restore(int(os.Stdin.Fd()), oldState)
|
||||
|
||||
// Create a channel to handle the stop signal
|
||||
stop := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
_, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1))
|
||||
if err != nil {
|
||||
fmt.Println("Error reading from stdin:", err)
|
||||
return
|
||||
}
|
||||
// Send signal to stop recording
|
||||
stop <- struct{}{}
|
||||
}()
|
||||
|
||||
loop:
|
||||
for {
|
||||
err = stream.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
totalSamples += uint32(len(in))
|
||||
|
||||
select {
|
||||
case <-stop:
|
||||
break loop
|
||||
case <-sig:
|
||||
break loop
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
err = stream.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update WAV header with actual sizes
|
||||
updateWavHeader(f, totalSamples, numChannels, bitsPerSample)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeWavHeader writes a 44-byte PCM WAV header to f at its current offset.
// The RIFF chunk size (bytes 4-7) and the data chunk size (bytes 40-43) are
// written as zero placeholders to be filled in once the amount of audio data
// is known.
//
// sampleRate, numChannels and bitsPerSample describe the sample format; the
// byte rate and block alignment fields are derived from them. The header is
// assembled in memory and written with a single call, and any write error is
// returned.
func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) error {
	const (
		headerSize    = 44
		subchunk1Size = 16 // size of the fmt subchunk body
		audioFormat   = 1  // written to the audio format field
	)

	blockAlign := numChannels * (bitsPerSample / 8)
	byteRate := sampleRate * blockAlign

	var hdr [headerSize]byte
	le := binary.LittleEndian

	// RIFF header
	copy(hdr[0:4], "RIFF")
	le.PutUint32(hdr[4:8], 0) // placeholder for file size
	copy(hdr[8:12], "WAVE")

	// fmt subchunk
	copy(hdr[12:16], "fmt ")
	le.PutUint32(hdr[16:20], subchunk1Size)
	le.PutUint16(hdr[20:22], audioFormat)
	le.PutUint16(hdr[22:24], uint16(numChannels))
	le.PutUint32(hdr[24:28], uint32(sampleRate))
	le.PutUint32(hdr[28:32], uint32(byteRate))
	le.PutUint16(hdr[32:34], uint16(blockAlign))
	le.PutUint16(hdr[34:36], uint16(bitsPerSample))

	// data subchunk header
	copy(hdr[36:40], "data")
	le.PutUint32(hdr[40:44], 0) // placeholder for data size

	if _, err := f.Write(hdr[:]); err != nil {
		return fmt.Errorf("writing wav header: %w", err)
	}
	return nil
}
|
||||
|
||||
// updateWavHeader patches the two size fields of a WAV header previously
// written at the start of f: the RIFF chunk size at byte offset 4 and the
// data chunk size at byte offset 40.
//
// The data size is totalSamples * numChannels * (bitsPerSample / 8) bytes and
// the RIFF size is that value plus 36. Both fields are written at absolute
// offsets, so the current file offset of f is left untouched. Any write error
// is returned.
func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) error {
	const (
		riffSizeOffset = 4
		dataSizeOffset = 40
	)

	dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)
	riffSize := 36 + dataSize

	var field [4]byte

	binary.LittleEndian.PutUint32(field[:], riffSize)
	if _, err := f.WriteAt(field[:], riffSizeOffset); err != nil {
		return fmt.Errorf("updating wav riff size: %w", err)
	}

	binary.LittleEndian.PutUint32(field[:], dataSize)
	if _, err := f.WriteAt(field[:], dataSizeOffset); err != nil {
		return fmt.Errorf("updating wav data size: %w", err)
	}
	return nil
}
|
@ -110,7 +110,12 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
|
||||
}
|
||||
|
||||
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
|
||||
modelPath := speech.Model
|
||||
var modelPath string
|
||||
if speech.Model == "" {
|
||||
modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin"
|
||||
} else {
|
||||
modelPath = speech.Model
|
||||
}
|
||||
|
||||
// default to 5 minutes
|
||||
var sessionDuration time.Duration
|
||||
@ -130,7 +135,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan er
|
||||
return
|
||||
}
|
||||
|
||||
whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server"
|
||||
whisperServer := "/Users/royhan-ollama/.ollama/server"
|
||||
|
||||
// Find an available port for whisper
|
||||
port := 0
|
||||
@ -1510,8 +1515,9 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error {
|
||||
if req.Model == "" {
|
||||
return nil
|
||||
slog.Info("processing audio")
|
||||
if req == nil {
|
||||
req = &api.WhisperRequest{}
|
||||
}
|
||||
portCh := make(chan int, 1)
|
||||
errCh := make(chan error, 1)
|
||||
@ -1583,7 +1589,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
|
||||
}
|
||||
|
||||
if req.Speech != nil {
|
||||
if req.Speech != nil || req.RunSpeech {
|
||||
if err := processAudio(c, s, msgs, req.Speech); err != nil {
|
||||
slog.Error("failed to process audio", "error", err)
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user