grpcurl/internal/testing/cmd/bankdemo/chat.go

467 lines
9.8 KiB
Go

package main
import (
"context"
"fmt"
"io"
"sync"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
)
// chatServer implements the Support gRPC service, for providing
// a capability to connect customers and support agents in real-time
// chat.
type chatServer struct {
UnimplementedSupportServer
chatsBySession map[string]*session
chatsAwaitingAgent []string
lastSession int32
mu sync.RWMutex
}
type session struct {
Session
active bool
cust *listener
agents map[string]*listener
mu sync.RWMutex
}
type listener struct {
ch chan<- *ChatEntry
ctx context.Context
}
func (l *listener) send(e *ChatEntry) {
select {
case l.ch <- e:
case <-l.ctx.Done():
}
}
func (s *session) copySession() *Session {
s.mu.RLock()
defer s.mu.RUnlock()
return &Session{
SessionId: s.SessionId,
CustomerName: s.Session.CustomerName,
History: s.Session.History,
}
}
func (s *chatServer) ChatCustomer(stream Support_ChatCustomerServer) error {
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
cust := getCustomer(ctx)
if cust == "" {
return status.Error(codes.Unauthenticated, codes.Unauthenticated.String())
}
var sess *session
var ch chan *ChatEntry
var chCancel context.CancelFunc
cleanup := func() {
if sess != nil {
sess.mu.Lock()
sess.cust = nil
sess.mu.Unlock()
chCancel()
close(ch)
go func() {
// drain channel to prevent deadlock
for range ch {
}
}()
}
}
defer cleanup()
for {
req, err := stream.Recv()
if err != nil {
if err == io.EOF {
return nil
}
return err
}
switch req := req.Req.(type) {
case *ChatCustomerRequest_Init:
if sess != nil {
return status.Errorf(codes.FailedPrecondition, "already called init, currently in chat session %q", sess.SessionId)
}
sessionID := req.Init.ResumeSessionId
if sessionID == "" {
sess, ch, chCancel = s.newSession(ctx, cust)
} else if sess, ch, chCancel = s.resumeSession(ctx, cust, sessionID); sess == nil {
return status.Errorf(codes.FailedPrecondition, "cannot resume session %q; it is not an open session", sessionID)
}
err := stream.Send(&ChatCustomerResponse{
Resp: &ChatCustomerResponse_Session{
Session: sess.copySession(),
},
})
if err != nil {
return err
}
// monitor the returned channel, sending incoming agent messages down the pipe
go func() {
for {
select {
case entry, ok := <-ch:
if !ok {
return
}
if e, ok := entry.Entry.(*ChatEntry_AgentMsg); ok {
stream.Send(&ChatCustomerResponse{
Resp: &ChatCustomerResponse_Msg{
Msg: e.AgentMsg,
},
})
}
case <-ctx.Done():
return
}
}
}()
case *ChatCustomerRequest_Msg:
if sess == nil {
return status.Errorf(codes.FailedPrecondition, "never called init, no chat session for message")
}
entry := &ChatEntry{
Date: timestamppb.Now(),
Entry: &ChatEntry_CustomerMsg{
CustomerMsg: req.Msg,
},
}
func() {
sess.mu.Lock()
sess.Session.History = append(sess.Session.History, entry)
sess.mu.Unlock()
sess.mu.RLock()
defer sess.mu.RUnlock()
for _, l := range sess.agents {
l.send(entry)
}
}()
case *ChatCustomerRequest_HangUp:
if sess == nil {
return status.Errorf(codes.FailedPrecondition, "never called init, no chat session to hang up")
}
s.closeSession(sess)
cleanup()
sess = nil
default:
return status.Error(codes.InvalidArgument, "unknown request type")
}
}
}
func (s *chatServer) ChatAgent(stream Support_ChatAgentServer) error {
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
agent := getAgent(ctx)
if agent == "" {
return status.Error(codes.Unauthenticated, codes.Unauthenticated.String())
}
var sess *session
var ch chan *ChatEntry
var chCancel context.CancelFunc
cleanup := func() {
if sess != nil {
sess.mu.Lock()
delete(sess.agents, agent)
if len(sess.agents) == 0 {
s.mu.Lock()
s.chatsAwaitingAgent = append(s.chatsAwaitingAgent, sess.SessionId)
s.mu.Unlock()
}
sess.mu.Unlock()
chCancel()
close(ch)
go func() {
// drain channel to prevent deadlock
for range ch {
}
}()
}
}
defer cleanup()
checkSession := func() {
// see if session was concurrently closed
if sess != nil {
sess.mu.RLock()
active := sess.active
sess.mu.RUnlock()
if !active {
cleanup()
sess = nil
}
}
}
for {
req, err := stream.Recv()
if err != nil {
if err == io.EOF {
return nil
}
return err
}
checkSession()
switch req := req.Req.(type) {
case *ChatAgentRequest_Accept:
if sess != nil {
return status.Errorf(codes.FailedPrecondition, "already called accept, currently in chat session %q", sess.SessionId)
}
sess, ch, chCancel = s.acceptSession(ctx, agent, req.Accept.SessionId)
if sess == nil {
return status.Errorf(codes.FailedPrecondition, "no session to accept")
}
err := stream.Send(&ChatAgentResponse{
Resp: &ChatAgentResponse_AcceptedSession{
AcceptedSession: sess.copySession(),
},
})
if err != nil {
return err
}
// monitor the returned channel, sending incoming agent messages down the pipe
go func() {
for {
select {
case entry, ok := <-ch:
if !ok {
return
}
if entry == nil {
stream.Send(&ChatAgentResponse{
Resp: &ChatAgentResponse_SessionEnded{
SessionEnded: Void_VOID,
},
})
continue
}
if agentMsg, ok := entry.Entry.(*ChatEntry_AgentMsg); ok {
if agentMsg.AgentMsg.AgentName == agent {
continue
}
}
stream.Send(&ChatAgentResponse{
Resp: &ChatAgentResponse_Msg{
Msg: entry,
},
})
case <-ctx.Done():
return
}
}
}()
case *ChatAgentRequest_Msg:
if sess == nil {
return status.Errorf(codes.FailedPrecondition, "never called accept, no chat session for message")
}
entry := &ChatEntry{
Date: timestamppb.Now(),
Entry: &ChatEntry_AgentMsg{
AgentMsg: &AgentMessage{
AgentName: agent,
Msg: req.Msg,
},
},
}
active := true
func() {
sess.mu.Lock()
active = sess.active
if active {
sess.Session.History = append(sess.Session.History, entry)
}
sess.mu.Unlock()
if !active {
return
}
sess.mu.RLock()
defer sess.mu.RUnlock()
if sess.cust != nil {
sess.cust.send(entry)
}
for otherAgent, l := range sess.agents {
if otherAgent == agent {
continue
}
l.send(entry)
}
}()
if !active {
return status.Errorf(codes.FailedPrecondition, "customer hung up on chat session %s", sess.SessionId)
}
case *ChatAgentRequest_LeaveSession:
if sess == nil {
return status.Errorf(codes.FailedPrecondition, "never called init, no chat session to hang up")
}
s.closeSession(sess)
cleanup()
sess = nil
default:
return status.Error(codes.InvalidArgument, "unknown request type")
}
}
}
func (s *chatServer) newSession(ctx context.Context, cust string) (*session, chan *ChatEntry, context.CancelFunc) {
s.mu.Lock()
defer s.mu.Unlock()
s.lastSession++
id := fmt.Sprintf("%06d", s.lastSession)
s.chatsAwaitingAgent = append(s.chatsAwaitingAgent, id)
ch := make(chan *ChatEntry, 1)
ctx, cancel := context.WithCancel(ctx)
l := &listener{
ch: ch,
ctx: ctx,
}
sess := session{
active: true,
Session: Session{
SessionId: id,
CustomerName: cust,
},
cust: l,
}
s.chatsBySession[id] = &sess
return &sess, ch, cancel
}
func (s *chatServer) resumeSession(ctx context.Context, cust, sessionID string) (*session, chan *ChatEntry, context.CancelFunc) {
s.mu.Lock()
defer s.mu.Unlock()
sess := s.chatsBySession[sessionID]
if sess.CustomerName != cust {
// customer cannot join chat that they did not start
return nil, nil, nil
}
if !sess.active {
// chat has been closed
return nil, nil, nil
}
if sess.cust != nil {
// customer is active in the chat in another stream!
return nil, nil, nil
}
ch := make(chan *ChatEntry, 1)
ctx, cancel := context.WithCancel(ctx)
l := &listener{
ch: ch,
ctx: ctx,
}
sess.cust = l
return sess, ch, cancel
}
func (s *chatServer) closeSession(sess *session) {
active := true
func() {
sess.mu.Lock()
active = sess.active
sess.active = false
sess.mu.Unlock()
if !active {
// already closed
return
}
sess.mu.RLock()
defer sess.mu.RUnlock()
for _, l := range sess.agents {
l.send(nil)
}
}()
if !active {
// already closed
return
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.chatsBySession, sess.SessionId)
for i, id := range s.chatsAwaitingAgent {
if id == sess.SessionId {
s.chatsAwaitingAgent = append(s.chatsAwaitingAgent[:i], s.chatsAwaitingAgent[i+1:]...)
break
}
}
}
func (s *chatServer) acceptSession(ctx context.Context, agent, sessionID string) (*session, chan *ChatEntry, context.CancelFunc) {
var sess *session
func() {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.chatsAwaitingAgent) == 0 {
return
}
if sessionID == "" {
sessionID = s.chatsAwaitingAgent[0]
s.chatsAwaitingAgent = s.chatsAwaitingAgent[1:]
} else {
found := false
for i, id := range s.chatsAwaitingAgent {
if id == sessionID {
found = true
s.chatsAwaitingAgent = append(s.chatsAwaitingAgent[:i], s.chatsAwaitingAgent[i+1:]...)
break
}
}
if !found {
return
}
}
sess = s.chatsBySession[sessionID]
}()
if sess == nil {
return nil, nil, nil
}
ch := make(chan *ChatEntry, 1)
ctx, cancel := context.WithCancel(ctx)
l := &listener{
ch: ch,
ctx: ctx,
}
sess.mu.Lock()
if sess.agents == nil {
sess.agents = map[string]*listener{}
}
sess.agents[agent] = l
sess.mu.Unlock()
return sess, ch, cancel
}