grpcurl/internal/testing/cmd/bankdemo/db.go

185 lines
4.8 KiB
Go

package main
import (
"fmt"
"sync"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
)
// In-memory database that is periodically saved to a JSON file.
type accounts struct {
AccountNumbersByCustomer map[string][]uint64
AccountsByNumber map[uint64]*account
AccountNumbers []uint64
Customers []string
LastAccountNum uint64
mu sync.RWMutex
}
func (a *accounts) openAccount(customer string, accountType Account_Type, initialBalanceCents int32) *Account {
a.mu.Lock()
defer a.mu.Unlock()
accountNums, ok := a.AccountNumbersByCustomer[customer]
if !ok {
// no accounts for this customer? it's a new customer
a.Customers = append(a.Customers, customer)
}
num := a.LastAccountNum + 1
a.LastAccountNum = num
a.AccountNumbers = append(a.AccountNumbers, num)
accountNums = append(accountNums, num)
a.AccountNumbersByCustomer[customer] = accountNums
var acct account
acct.AccountNumber = num
acct.Type = accountType
acct.BalanceCents = initialBalanceCents
acct.Transactions = append(acct.Transactions, &Transaction{
AccountNumber: num,
SeqNumber: 1,
Date: timestamppb.Now(),
AmountCents: initialBalanceCents,
Desc: "initial deposit",
})
a.AccountsByNumber[num] = &acct
return &acct.Account
}
func (a *accounts) closeAccount(customer string, accountNumber uint64) error {
a.mu.Lock()
defer a.mu.Unlock()
acctNums := a.AccountNumbersByCustomer[customer]
found := -1
for i, num := range acctNums {
if num == accountNumber {
found = i
break
}
}
if found == -1 {
return status.Errorf(codes.NotFound, "you have no account numbered %d", accountNumber)
}
acct := a.AccountsByNumber[accountNumber]
if acct.BalanceCents != 0 {
return status.Errorf(codes.FailedPrecondition, "account %d cannot be closed because it has a non-zero balance: %s", accountNumber, dollars(acct.BalanceCents))
}
for i, num := range a.AccountNumbers {
if num == accountNumber {
a.AccountNumbers = append(a.AccountNumbers[:i], a.AccountNumbers[i+1:]...)
break
}
}
a.AccountNumbersByCustomer[customer] = append(acctNums[:found], acctNums[found+1:]...)
delete(a.AccountsByNumber, accountNumber)
return nil
}
func (a *accounts) getAccount(customer string, accountNumber uint64) (*account, error) {
a.mu.RLock()
defer a.mu.RUnlock()
acctNums := a.AccountNumbersByCustomer[customer]
for _, num := range acctNums {
if num == accountNumber {
return a.AccountsByNumber[num], nil
}
}
return nil, status.Errorf(codes.NotFound, "you have no account numbered %d", accountNumber)
}
func (a *accounts) getAllAccounts(customer string) []*Account {
a.mu.RLock()
defer a.mu.RUnlock()
accountNums := a.AccountNumbersByCustomer[customer]
var accounts []*Account
for _, num := range accountNums {
accounts = append(accounts, &a.AccountsByNumber[num].Account)
}
return accounts
}
type account struct {
Account
Transactions []*Transaction
mu sync.RWMutex
}
func (a *account) getTransactions() []*Transaction {
a.mu.RLock()
defer a.mu.RUnlock()
return a.Transactions
}
func (a *account) newTransaction(amountCents int32, desc string) (newBalance int32, err error) {
a.mu.Lock()
defer a.mu.Unlock()
bal := a.BalanceCents + amountCents
if bal < 0 {
return 0, status.Errorf(codes.FailedPrecondition, "insufficient funds: cannot withdraw %s when balance is %s", dollars(amountCents), dollars(a.BalanceCents))
}
a.BalanceCents = bal
a.Transactions = append(a.Transactions, &Transaction{
AccountNumber: a.AccountNumber,
Date: timestamppb.Now(),
AmountCents: amountCents,
SeqNumber: uint64(len(a.Transactions) + 1),
Desc: desc,
})
return bal, nil
}
func (t *Transaction) MarshalJSON() ([]byte, error) {
return protojson.Marshal(t)
}
func (t *Transaction) UnmarshalJSON(b []byte) error {
return protojson.Unmarshal(b, t)
}
func (a *accounts) clone() *accounts {
var clone accounts
clone.AccountNumbersByCustomer = map[string][]uint64{}
clone.AccountsByNumber = map[uint64]*account{}
a.mu.RLock()
clone.Customers = a.Customers
a.mu.RUnlock()
for _, cust := range clone.Customers {
var acctNums []uint64
a.mu.RLock()
acctNums = a.AccountNumbersByCustomer[cust]
a.mu.RUnlock()
clone.AccountNumbersByCustomer[cust] = acctNums
clone.AccountNumbers = append(clone.AccountNumbers, acctNums...)
for _, acctNum := range acctNums {
a.mu.RLock()
acct := a.AccountsByNumber[acctNum]
a.mu.RUnlock()
acct.mu.RLock()
txns := acct.Transactions
acct.mu.RUnlock()
clone.AccountsByNumber[acctNum] = &account{Transactions: txns}
}
}
return &clone
}
func dollars(amountCents int32) string {
return fmt.Sprintf("$%02f", float64(amountCents)/100)
}