Skip to content

Commit

Permalink
Undo rename of pipelines -> taskPipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
RJKeevil committed Dec 3, 2024
1 parent 8730bd1 commit 84ebc2c
Show file tree
Hide file tree
Showing 16 changed files with 187 additions and 187 deletions.
6 changes: 3 additions & 3 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"sync"

"github.com/knights-analytics/hugot/options"
"github.com/knights-analytics/hugot/pipelines"
"github.com/knights-analytics/hugot/pipelineBackends"

"github.com/mattn/go-isatty"
"github.com/urfave/cli/v2"
Expand Down Expand Up @@ -166,7 +166,7 @@ var runCommand = &cli.Command{
}
}

var pipe pipelines.Pipeline
var pipe pipelineBackends.Pipeline
switch pipelineType {
case "tokenClassification":
config := hugot.TokenClassificationConfig{
Expand Down Expand Up @@ -344,7 +344,7 @@ func writeOutputs(wg *sync.WaitGroup, processedChannel chan []byte, errorChannel
wg.Done()
}

func processWithPipeline(wg *sync.WaitGroup, inputChannel chan []input, processedChannel chan []byte, errorsChannel chan error, p pipelines.Pipeline) {
func processWithPipeline(wg *sync.WaitGroup, inputChannel chan []input, processedChannel chan []byte, errorsChannel chan error, p pipelineBackends.Pipeline) {
for inputBatch := range inputChannel {
inputStrings := make([]string, len(inputBatch))
for i := 0; i < len(inputBatch); i++ {
Expand Down
82 changes: 41 additions & 41 deletions hugot.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ import (
"slices"

"github.com/knights-analytics/hugot/options"
"github.com/knights-analytics/hugot/pipelineBackends"
"github.com/knights-analytics/hugot/pipelines"
"github.com/knights-analytics/hugot/taskPipelines"
)

// Session allows for the creation of new pipelines and holds the pipeline already created.
type Session struct {
featureExtractionPipelines pipelineMap[*taskPipelines.FeatureExtractionPipeline]
tokenClassificationPipelines pipelineMap[*taskPipelines.TokenClassificationPipeline]
textClassificationPipelines pipelineMap[*taskPipelines.TextClassificationPipeline]
zeroShotClassificationPipelines pipelineMap[*taskPipelines.ZeroShotClassificationPipeline]
models map[string]*pipelines.Model
featureExtractionPipelines pipelineMap[*pipelines.FeatureExtractionPipeline]
tokenClassificationPipelines pipelineMap[*pipelines.TokenClassificationPipeline]
textClassificationPipelines pipelineMap[*pipelines.TextClassificationPipeline]
zeroShotClassificationPipelines pipelineMap[*pipelines.ZeroShotClassificationPipeline]
models map[string]*pipelineBackends.Model
options *options.Options
environmentDestroy func() error
}
Expand All @@ -34,11 +34,11 @@ func newSession(runtime string, opts ...options.WithOption) (*Session, error) {
}

session := &Session{
featureExtractionPipelines: map[string]*taskPipelines.FeatureExtractionPipeline{},
textClassificationPipelines: map[string]*taskPipelines.TextClassificationPipeline{},
tokenClassificationPipelines: map[string]*taskPipelines.TokenClassificationPipeline{},
zeroShotClassificationPipelines: map[string]*taskPipelines.ZeroShotClassificationPipeline{},
models: map[string]*pipelines.Model{},
featureExtractionPipelines: map[string]*pipelines.FeatureExtractionPipeline{},
textClassificationPipelines: map[string]*pipelines.TextClassificationPipeline{},
tokenClassificationPipelines: map[string]*pipelines.TokenClassificationPipeline{},
zeroShotClassificationPipelines: map[string]*pipelines.ZeroShotClassificationPipeline{},
models: map[string]*pipelineBackends.Model{},
options: parsedOptions,
environmentDestroy: func() error {
return nil
Expand All @@ -48,7 +48,7 @@ func newSession(runtime string, opts ...options.WithOption) (*Session, error) {
return session, nil
}

type pipelineMap[T pipelines.Pipeline] map[string]T
type pipelineMap[T pipelineBackends.Pipeline] map[string]T

func (m pipelineMap[T]) GetStats() []string {
var stats []string
Expand All @@ -59,31 +59,31 @@ func (m pipelineMap[T]) GetStats() []string {
}

// FeatureExtractionConfig is the configuration for a feature extraction pipeline
type FeatureExtractionConfig = pipelines.PipelineConfig[*taskPipelines.FeatureExtractionPipeline]
type FeatureExtractionConfig = pipelineBackends.PipelineConfig[*pipelines.FeatureExtractionPipeline]

// FeatureExtractionOption is an option for a feature extraction pipeline
type FeatureExtractionOption = pipelines.PipelineOption[*taskPipelines.FeatureExtractionPipeline]
type FeatureExtractionOption = pipelineBackends.PipelineOption[*pipelines.FeatureExtractionPipeline]

// TextClassificationConfig is the configuration for a text classification pipeline
type TextClassificationConfig = pipelines.PipelineConfig[*taskPipelines.TextClassificationPipeline]
type TextClassificationConfig = pipelineBackends.PipelineConfig[*pipelines.TextClassificationPipeline]

// type ZSCConfig = pipelines.PipelineConfig[*pipelines.ZeroShotClassificationPipeline]

type ZeroShotClassificationConfig = pipelines.PipelineConfig[*taskPipelines.ZeroShotClassificationPipeline]
type ZeroShotClassificationConfig = pipelineBackends.PipelineConfig[*pipelines.ZeroShotClassificationPipeline]

// TextClassificationOption is an option for a text classification pipeline
type TextClassificationOption = pipelines.PipelineOption[*taskPipelines.TextClassificationPipeline]
type TextClassificationOption = pipelineBackends.PipelineOption[*pipelines.TextClassificationPipeline]

// TokenClassificationConfig is the configuration for a token classification pipeline
type TokenClassificationConfig = pipelines.PipelineConfig[*taskPipelines.TokenClassificationPipeline]
type TokenClassificationConfig = pipelineBackends.PipelineConfig[*pipelines.TokenClassificationPipeline]

// TokenClassificationOption is an option for a token classification pipeline
type TokenClassificationOption = pipelines.PipelineOption[*taskPipelines.TokenClassificationPipeline]
type TokenClassificationOption = pipelineBackends.PipelineOption[*pipelines.TokenClassificationPipeline]

// NewPipeline can be used to create a new pipeline of type T. The initialised pipeline will be returned and it
// will also be stored in the session object so that all created pipelines can be destroyed with session.Destroy()
// at once.
func NewPipeline[T pipelines.Pipeline](s *Session, pipelineConfig pipelines.PipelineConfig[T]) (T, error) {
func NewPipeline[T pipelineBackends.Pipeline](s *Session, pipelineConfig pipelineBackends.PipelineConfig[T]) (T, error) {
var pipeline T
if pipelineConfig.Name == "" {
return pipeline, errors.New("a name for the pipeline is required")
Expand All @@ -100,22 +100,22 @@ func NewPipeline[T pipelines.Pipeline](s *Session, pipelineConfig pipelines.Pipe
// Load model if it has not been loaded already
model, ok := s.models[pipelineConfig.ModelPath]
if !ok {
model = &pipelines.Model{
model = &pipelineBackends.Model{
Path: pipelineConfig.ModelPath,
OnnxFilename: pipelineConfig.OnnxFilename,
}

err := pipelines.LoadOnnxModelBytes(model)
err := pipelineBackends.LoadOnnxModelBytes(model)
if err != nil {
return pipeline, err
}

err = pipelines.CreateModelBackend(model, s.options)
err = pipelineBackends.CreateModelBackend(model, s.options)
if err != nil {
return pipeline, err
}

tkErr := pipelines.LoadTokenizer(model, s.options)
tkErr := pipelineBackends.LoadTokenizer(model, s.options)
if tkErr != nil {
return pipeline, tkErr
}
Expand All @@ -135,33 +135,33 @@ func NewPipeline[T pipelines.Pipeline](s *Session, pipelineConfig pipelines.Pipe
}

switch any(pipeline).(type) {
case *taskPipelines.TokenClassificationPipeline:
config := any(pipelineConfig).(pipelines.PipelineConfig[*taskPipelines.TokenClassificationPipeline])
pipelineInitialised, err := taskPipelines.NewTokenClassificationPipeline(config, s.options, model)
case *pipelines.TokenClassificationPipeline:
config := any(pipelineConfig).(pipelineBackends.PipelineConfig[*pipelines.TokenClassificationPipeline])
pipelineInitialised, err := pipelines.NewTokenClassificationPipeline(config, s.options, model)
if err != nil {
return pipeline, err
}
s.tokenClassificationPipelines[config.Name] = pipelineInitialised
pipeline = any(pipelineInitialised).(T)
case *taskPipelines.TextClassificationPipeline:
config := any(pipelineConfig).(pipelines.PipelineConfig[*taskPipelines.TextClassificationPipeline])
pipelineInitialised, err := taskPipelines.NewTextClassificationPipeline(config, s.options, model)
case *pipelines.TextClassificationPipeline:
config := any(pipelineConfig).(pipelineBackends.PipelineConfig[*pipelines.TextClassificationPipeline])
pipelineInitialised, err := pipelines.NewTextClassificationPipeline(config, s.options, model)
if err != nil {
return pipeline, err
}
s.textClassificationPipelines[config.Name] = pipelineInitialised
pipeline = any(pipelineInitialised).(T)
case *taskPipelines.FeatureExtractionPipeline:
config := any(pipelineConfig).(pipelines.PipelineConfig[*taskPipelines.FeatureExtractionPipeline])
pipelineInitialised, err := taskPipelines.NewFeatureExtractionPipeline(config, s.options, model)
case *pipelines.FeatureExtractionPipeline:
config := any(pipelineConfig).(pipelineBackends.PipelineConfig[*pipelines.FeatureExtractionPipeline])
pipelineInitialised, err := pipelines.NewFeatureExtractionPipeline(config, s.options, model)
if err != nil {
return pipeline, err
}
s.featureExtractionPipelines[config.Name] = pipelineInitialised
pipeline = any(pipelineInitialised).(T)
case *taskPipelines.ZeroShotClassificationPipeline:
config := any(pipelineConfig).(pipelines.PipelineConfig[*taskPipelines.ZeroShotClassificationPipeline])
pipelineInitialised, err := taskPipelines.NewZeroShotClassificationPipeline(config, s.options, model)
case *pipelines.ZeroShotClassificationPipeline:
config := any(pipelineConfig).(pipelineBackends.PipelineConfig[*pipelines.ZeroShotClassificationPipeline])
pipelineInitialised, err := pipelines.NewZeroShotClassificationPipeline(config, s.options, model)
if err != nil {
return pipeline, err
}
Expand All @@ -174,28 +174,28 @@ func NewPipeline[T pipelines.Pipeline](s *Session, pipelineConfig pipelines.Pipe
}

// GetPipeline can be used to retrieve a pipeline of type T with the given name from the session
func GetPipeline[T pipelines.Pipeline](s *Session, name string) (T, error) {
func GetPipeline[T pipelineBackends.Pipeline](s *Session, name string) (T, error) {
var pipeline T
switch any(pipeline).(type) {
case *taskPipelines.TokenClassificationPipeline:
case *pipelines.TokenClassificationPipeline:
p, ok := s.tokenClassificationPipelines[name]
if !ok {
return pipeline, &pipelineNotFoundError{pipelineName: name}
}
return any(p).(T), nil
case *taskPipelines.TextClassificationPipeline:
case *pipelines.TextClassificationPipeline:
p, ok := s.textClassificationPipelines[name]
if !ok {
return pipeline, &pipelineNotFoundError{pipelineName: name}
}
return any(p).(T), nil
case *taskPipelines.FeatureExtractionPipeline:
case *pipelines.FeatureExtractionPipeline:
p, ok := s.featureExtractionPipelines[name]
if !ok {
return pipeline, &pipelineNotFoundError{pipelineName: name}
}
return any(p).(T), nil
case *taskPipelines.ZeroShotClassificationPipeline:
case *pipelines.ZeroShotClassificationPipeline:
p, ok := s.zeroShotClassificationPipelines[name]
if !ok {
return pipeline, &pipelineNotFoundError{pipelineName: name}
Expand Down
Loading

0 comments on commit 84ebc2c

Please sign in to comment.