diff --git a/attestation/factory.go b/attestation/factory.go index abc0bb4d..a78bfabb 100644 --- a/attestation/factory.go +++ b/attestation/factory.go @@ -70,6 +70,12 @@ func (e ErrAttestationNotFound) Error() string { return fmt.Sprintf("attestation not found: %v", string(e)) } +type ErrAttestorNotFound string + +func (e ErrAttestorNotFound) Error() string { + return fmt.Sprintf("attestor not found: %v", string(e)) +} + func RegisterAttestation(name, predicateType string, run RunType, factoryFunc registry.FactoryFunc[Attestor], opts ...registry.Configurer) { registrationEntry := attestorRegistry.Register(name, factoryFunc, opts...) attestationsByType[predicateType] = registrationEntry @@ -86,14 +92,32 @@ func FactoryByName(name string) (registry.FactoryFunc[Attestor], bool) { return registrationEntry.Factory, ok } +func GetAttestor(nameOrType string) (Attestor, error) { + attestors, err := GetAttestors([]string{nameOrType}) + if err != nil { + return nil, err + } + + if len(attestors) == 0 { + return nil, ErrAttestorNotFound(nameOrType) + } + + return attestors[0], nil +} + +// Deprecated: use AddAttestors instead func Attestors(nameOrTypes []string) ([]Attestor, error) { + return GetAttestors(nameOrTypes) +} + +func GetAttestors(nameOrTypes []string) ([]Attestor, error) { attestors := make([]Attestor, 0) for _, nameOrType := range nameOrTypes { factory, ok := FactoryByName(nameOrType) if !ok { factory, ok = FactoryByType(nameOrType) if !ok { - return nil, ErrAttestationNotFound(nameOrType) + return nil, ErrAttestorNotFound(nameOrType) } }