diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..e980bba --- /dev/null +++ b/auth.go @@ -0,0 +1,53 @@ +package email + +import ( + "bytes" + "errors" + "fmt" + "net/smtp" +) + +func LoginAuth(username, password, host string) smtp.Auth { + return &loginAuth{username, password, host} +} + +// loginAuth is an smtp.Auth that implements the LOGIN authentication mechanism. +type loginAuth struct { + username string + password string + host string +} + +func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + if !server.TLS { + advertised := false + for _, mechanism := range server.Auth { + if mechanism == "LOGIN" { + advertised = true + break + } + } + if !advertised { + return "", nil, errors.New("gomail: unencrypted connection") + } + } + if server.Name != a.host { + return "", nil, errors.New("gomail: wrong host name") + } + return "LOGIN", nil, nil +} + +func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if !more { + return nil, nil + } + + switch { + case bytes.Equal(fromServer, []byte("Username:")): + return []byte(a.username), nil + case bytes.Equal(fromServer, []byte("Password:")): + return []byte(a.password), nil + default: + return nil, fmt.Errorf("gomail: unexpected server challenge: %s", fromServer) + } +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..0dfc98c --- /dev/null +++ b/auth_test.go @@ -0,0 +1,100 @@ +package email + +import ( + "net/smtp" + "testing" +) + +const ( + testUser = "user" + testPwd = "pwd" + testHost = "smtp.example.com" +) + +type authTest struct { + auths []string + challenges []string + tls bool + wantData []string + wantError bool +} + +func TestNoAdvertisement(t *testing.T) { + testLoginAuth(t, &authTest{ + auths: []string{}, + tls: false, + wantError: true, + }) +} + +func TestNoAdvertisementTLS(t *testing.T) { + testLoginAuth(t, &authTest{ + auths: []string{}, + challenges: []string{"Username:", "Password:"}, + tls: true, + wantData: []string{"", testUser, testPwd}, + }) +} + +func TestLogin(t *testing.T) { + testLoginAuth(t, &authTest{ + auths: []string{"PLAIN", "LOGIN"}, + challenges: []string{"Username:", "Password:"}, + tls: false, + wantData: []string{"", testUser, testPwd}, + }) +} + +func TestLoginTLS(t *testing.T) { + testLoginAuth(t, &authTest{ + auths: []string{"LOGIN"}, + challenges: []string{"Username:", "Password:"}, + tls: true, + wantData: []string{"", testUser, testPwd}, + }) +} + +func testLoginAuth(t *testing.T, test *authTest) { + auth := &loginAuth{ + username: testUser, + password: testPwd, + host: testHost, + } + server := &smtp.ServerInfo{ + Name: testHost, + TLS: test.tls, + Auth: test.auths, + } + proto, toServer, err := auth.Start(server) + if err != nil && !test.wantError { + t.Fatalf("loginAuth.Start(): %v", err) + } + if err != nil && test.wantError { + return + } + if proto != "LOGIN" { + t.Errorf("invalid protocol, got %q, want LOGIN", proto) + } + + i := 0 + got := string(toServer) + if got != test.wantData[i] { + t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i]) + } + + for _, challenge := range test.challenges { + i++ + if i >= len(test.wantData) { + t.Fatalf("unexpected challenge: %q", challenge) + } + + toServer, err = auth.Next([]byte(challenge), true) + if err != nil { + t.Fatalf("loginAuth.Auth(): %v", err) + } + got = string(toServer) + if got != test.wantData[i] { + t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i]) + } + } +}