如何使用 gorm 进行单元测试
How to do Unit Testing with gorm
我是 Go
和 unit test
的新人。在我的项目中,我使用 Go
和 gorm
并连接 mysql
数据库。
我的问题是如何对我的代码进行单元测试:
我的代码如下(main.go):
package main
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
type Jobs struct {
JobID uint `json: "jobId" gorm:"primary_key;auto_increment"`
SourcePath string `json: "sourcePath"`
Priority int64 `json: "priority"`
InternalPriority string `json: "internalPriority"`
ExecutionEnvironmentID string `json: "executionEnvironmentID"`
}
type ExecutionEnvironment struct {
ID uint `json: "id" gorm:"primary_key;auto_increment"`
ExecutionEnvironmentId string `json: "executionEnvironmentID"`
CloudProviderType string `json: "cloudProviderType"`
InfrastructureType string `json: "infrastructureType"`
CloudRegion string `json: "cloudRegion"`
CreatedAt time.Time `json: "createdAt"`
}
var db *gorm.DB
func initDB() {
var err error
dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
db, err = gorm.Open("mysql", dataSourceName)
if err != nil {
fmt.Println(err)
panic("failed to connect database")
}
//db.Exec("CREATE DATABASE test")
db.LogMode(true)
db.Exec("USE test")
db.AutoMigrate(&Jobs{}, &ExecutionEnvironment{})
}
func GetAllJobs(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Get All Jobs function")
var jobs []Jobs
if err := db.Select("jobs.*, execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Find(&jobs).Error; err != nil {
fmt.Println(err)
}
fmt.Println()
if len(jobs) == 0 {
json.NewEncoder(w).Encode("No data found")
} else {
json.NewEncoder(w).Encode(jobs)
}
}
// create job
func createJob(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Create Jobs function")
var jobs Jobs
json.NewDecoder(r.Body).Decode(&jobs)
db.Create(&jobs)
json.NewEncoder(w).Encode(jobs)
}
// get job by id
func GetJobById(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
params := mux.Vars(r)
jobId := params["jobId"]
//var job []Jobs
//db.Preload("Items").First(&job, jobId)
var jobs []Jobs
var executionEnvironments []ExecutionEnvironment
if err := db.Table("jobs").Select("jobs.*, execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Where("job_id =?", jobId).Find(&jobs).Scan(&executionEnvironments).Error; err != nil {
fmt.Println(err)
}
if len(jobs) == 0 {
json.NewEncoder(w).Encode("No data found")
} else {
json.NewEncoder(w).Encode(jobs)
}
}
// Delete Job By Id
func DeleteJobById(w http.ResponseWriter, r *http.Request) {
params := mux.Vars(r)
jobId := params["jobId"]
// check data
var job []Jobs
db.Table("jobs").Select("jobs.*").Where("job_id=?", jobId).Find(&job)
if len(job) == 0 {
json.NewEncoder(w).Encode("Invalid JobId")
} else {
id64, _ := strconv.ParseUint(jobId, 10, 64)
idToDelete := uint(id64)
db.Where("job_id = ?", idToDelete).Delete(&Jobs{})
//db.Where("jobId = ?", idToDelete).Delete(&ExecutionEnvironment{})
json.NewEncoder(w).Encode("Job deleted successfully")
w.WriteHeader(http.StatusNoContent)
}
}
// create Execution Environments
func createEnvironments(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Create Execution Environments function")
var executionEnvironments ExecutionEnvironment
json.NewDecoder(r.Body).Decode(&executionEnvironments)
db.Create(&executionEnvironments)
json.NewEncoder(w).Encode(executionEnvironments)
}
// Get Job Cloud Region
func GetJobCloudRegion(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Get Job Cloud Region function")
params := mux.Vars(r)
jobId := params["jobId"]
//var jobs []Jobs
var executionEnvironment []ExecutionEnvironment
db.Table("jobs").Select("execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Where("jobs.job_id =?", jobId).Find(&executionEnvironment)
var pUuid []string
for _, uuid := range executionEnvironment {
pUuid = append(pUuid, uuid.CloudRegion)
}
json.NewEncoder(w).Encode(pUuid)
}
func main() {
// router
router := mux.NewRouter()
// Access URL
router.HandleFunc("/GetAllJobs", GetAllJobs).Methods("GET")
router.HandleFunc("/createJob", createJob).Methods("POST")
router.HandleFunc("/GetJobById/{jobId}", GetJobById).Methods("GET")
router.HandleFunc("/DeleteJobById/{jobId}", DeleteJobById).Methods("DELETE")
router.HandleFunc("/createEnvironments", createEnvironments).Methods("POST")
router.HandleFunc("/GetJobCloudRegion/{jobId}", GetJobCloudRegion).Methods("GET")
// Initialize db connection
initDB()
// config port
fmt.Printf("Starting server at 8000 \n")
http.ListenAndServe(":8000", router)
}
我尝试在下面创建单元测试文件,但它不是 运行 它显示如下
main_test.go:
package main
import (
"log"
"os"
"testing"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
func TestinitDB(m *testing.M) {
dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
db, err := gorm.Open("mysql", dataSourceName)
if err != nil {
log.Fatal("failed to connect database")
}
//db.Exec("CREATE DATABASE test")
db.LogMode(true)
db.Exec("USE test111")
os.Exit(m.Run())
}
请帮我写单元测试文件
“如何进行单元测试”是一个非常宽泛的问题,因为它取决于您要测试的内容。在您的示例中,您正在使用与数据库的远程连接,这通常是在单元测试中被模拟的。目前尚不清楚这是否是您要查找的内容,也不是必须这样做。通过看到您使用不同的数据库,我希望您的意图不是嘲笑。
首先查看 post that has already answered your question around how TestMain
并且 testing.M
可以正常工作。
您的代码当前所做的(如果您的测试名称被正确命名为 TestMain
)是围绕您的其他测试添加一个方法来进行设置和拆卸,但是您没有任何其他测试要做使用此设置和拆卸,您将得到结果 no tests to run
.
这不是你问题的一部分,但我建议尽量避免 testing.M
,直到你对测试 Go 代码有信心为止。使用 testing.T
并测试单独的单元可能更容易理解。您可以通过在测试中调用 initDB()
并使初始化程序接受参数来实现几乎相同的事情。
func initDB(dbToUse string) {
// ...
db.Exec("USE "+dbToUse)
}
然后您将从您的主文件调用 initDB("test")
并从您的测试调用 initDB("test111")
。
您可以在 pkg.go.dev/testing 阅读有关 Go 的测试包的信息,您还可以在其中找到 testing.T
和 testing.M
之间的差异。
这是一个较短的示例,其中包含一些不需要任何设置或拆卸的基本测试,并且使用 testing.T
而不是 testing.M
。
main.go
package main
import "fmt"
func main() {
fmt.Println(add(1, 2))
}
func add(a, b int) int {
return a + b
}
main_test.go
package main
import "testing"
func TestAdd(t *testing.T) {
t.Run("add 2 + 2", func(t *testing.T) {
want := 4
// Call the function you want to test.
got := add(2, 2)
// Assert that you got your expected response
if got != want {
t.Fail()
}
})
}
此测试将测试您的方法 add
并在您将 2, 2
作为参数传递时确保它 returns 正确的值。 t.Run
的使用是可选的,但它会为您创建一个子测试,这使得阅读输出更容易一些。
由于您在包级别进行测试,因此如果您不使用三点格式递归地包括每个包,则需要指定要测试的包。
要运行上面示例中的测试,请指定您的包和-v
以获得详细输出。
$ go test ./ -v
=== RUN TestAdd
=== RUN TestAdd/add_2_+_2
--- PASS: TestAdd (0.00s)
--- PASS: TestAdd/add_2_+_2 (0.00s)
PASS
ok x (cached)
围绕这个主题还有很多东西需要学习,例如测试框架和测试模式。例如,测试框架 testify
helps you do assertions and prints nice output when tests fail and table driven tests 是 Go 中非常常见的模式。
您还在编写一个 HTTP 服务器,它通常需要额外的测试设置才能正确测试。幸运的是,标准库中的 http
包附带了一个名为 httptest
的子包,它可以帮助您记录外部请求或为外部请求启动本地服务器。您还可以通过使用手动构造的请求直接调用处理程序来测试您的处理程序。
看起来像这样。
func TestSomeHandler(t *testing.T) {
// Create a request to pass to our handler. We don't have any query parameters for now, so we'll
// pass 'nil' as the third parameter.
req, err := http.NewRequest("GET", "/some-endpoint", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(SomeHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
现在,测试您的一些代码。我们可以 运行 init 方法并使用响应记录器调用您的任何服务。
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestGetAllJobs(t *testing.T) {
// Initialize the DB
initDB("test111")
req, err := http.NewRequest("GET", "/GetAllJobs", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := http.HandlerFunc(GetAllJobs)
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
var response []Jobs
if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil {
t.Errorf("got invalid response, expected list of jobs, got: %v", rr.Body.String())
}
if len(response) < 1 {
t.Errorf("expected at least 1 job, got %v", len(response))
}
for _, job := range response {
if job.SourcePath == "" {
t.Errorf("expected job id %d to have a source path, was empty", job.JobID)
}
}
}
你可以使用 go-sqlmock:
package main
import (
"database/sql"
"regexp"
"testing"
"gopkg.in/DATA-DOG/go-sqlmock.v1"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
type Student struct {
//*gorm.Model
Name string
ID string
}
type v2Suite struct {
db *gorm.DB
mock sqlmock.Sqlmock
student Student
}
func TestGORMV2(t *testing.T) {
s := &v2Suite{}
var (
db *sql.DB
err error
)
db, s.mock, err = sqlmock.New()
if err != nil {
t.Errorf("Failed to open mock sql db, got error: %v", err)
}
if db == nil {
t.Error("mock db is null")
}
if s.mock == nil {
t.Error("sqlmock is null")
}
dialector := postgres.New(postgres.Config{
DSN: "sqlmock_db_0",
DriverName: "postgres",
Conn: db,
PreferSimpleProtocol: true,
})
s.db, err = gorm.Open(dialector, &gorm.Config{})
if err != nil {
t.Errorf("Failed to open gorm v2 db, got error: %v", err)
}
if s.db == nil {
t.Error("gorm db is null")
}
s.student = Student{
ID: "123456",
Name: "Test 1",
}
defer db.Close()
s.mock.MatchExpectationsInOrder(false)
s.mock.ExpectBegin()
s.mock.ExpectQuery(regexp.QuoteMeta(
`INSERT INTO "students" ("id","name")
VALUES (,) RETURNING "students"."id"`)).
WithArgs(s.student.ID, s.student.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(s.student.ID))
s.mock.ExpectCommit()
if err = s.db.Create(&s.student).Error; err != nil {
t.Errorf("Failed to insert to gorm db, got error: %v", err)
}
err = s.mock.ExpectationsWereMet()
if err != nil {
t.Errorf("Failed to meet expectations, got error: %v", err)
}
}
我是 Go
和 unit test
的新人。在我的项目中,我使用 Go
和 gorm
并连接 mysql
数据库。
我的问题是如何对我的代码进行单元测试:
我的代码如下(main.go):
package main
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
type Jobs struct {
JobID uint `json: "jobId" gorm:"primary_key;auto_increment"`
SourcePath string `json: "sourcePath"`
Priority int64 `json: "priority"`
InternalPriority string `json: "internalPriority"`
ExecutionEnvironmentID string `json: "executionEnvironmentID"`
}
type ExecutionEnvironment struct {
ID uint `json: "id" gorm:"primary_key;auto_increment"`
ExecutionEnvironmentId string `json: "executionEnvironmentID"`
CloudProviderType string `json: "cloudProviderType"`
InfrastructureType string `json: "infrastructureType"`
CloudRegion string `json: "cloudRegion"`
CreatedAt time.Time `json: "createdAt"`
}
var db *gorm.DB
func initDB() {
var err error
dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
db, err = gorm.Open("mysql", dataSourceName)
if err != nil {
fmt.Println(err)
panic("failed to connect database")
}
//db.Exec("CREATE DATABASE test")
db.LogMode(true)
db.Exec("USE test")
db.AutoMigrate(&Jobs{}, &ExecutionEnvironment{})
}
func GetAllJobs(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Get All Jobs function")
var jobs []Jobs
if err := db.Select("jobs.*, execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Find(&jobs).Error; err != nil {
fmt.Println(err)
}
fmt.Println()
if len(jobs) == 0 {
json.NewEncoder(w).Encode("No data found")
} else {
json.NewEncoder(w).Encode(jobs)
}
}
// create job
func createJob(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Create Jobs function")
var jobs Jobs
json.NewDecoder(r.Body).Decode(&jobs)
db.Create(&jobs)
json.NewEncoder(w).Encode(jobs)
}
// get job by id
func GetJobById(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
params := mux.Vars(r)
jobId := params["jobId"]
//var job []Jobs
//db.Preload("Items").First(&job, jobId)
var jobs []Jobs
var executionEnvironments []ExecutionEnvironment
if err := db.Table("jobs").Select("jobs.*, execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Where("job_id =?", jobId).Find(&jobs).Scan(&executionEnvironments).Error; err != nil {
fmt.Println(err)
}
if len(jobs) == 0 {
json.NewEncoder(w).Encode("No data found")
} else {
json.NewEncoder(w).Encode(jobs)
}
}
// Delete Job By Id
func DeleteJobById(w http.ResponseWriter, r *http.Request) {
params := mux.Vars(r)
jobId := params["jobId"]
// check data
var job []Jobs
db.Table("jobs").Select("jobs.*").Where("job_id=?", jobId).Find(&job)
if len(job) == 0 {
json.NewEncoder(w).Encode("Invalid JobId")
} else {
id64, _ := strconv.ParseUint(jobId, 10, 64)
idToDelete := uint(id64)
db.Where("job_id = ?", idToDelete).Delete(&Jobs{})
//db.Where("jobId = ?", idToDelete).Delete(&ExecutionEnvironment{})
json.NewEncoder(w).Encode("Job deleted successfully")
w.WriteHeader(http.StatusNoContent)
}
}
// create Execution Environments
func createEnvironments(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Create Execution Environments function")
var executionEnvironments ExecutionEnvironment
json.NewDecoder(r.Body).Decode(&executionEnvironments)
db.Create(&executionEnvironments)
json.NewEncoder(w).Encode(executionEnvironments)
}
// Get Job Cloud Region
func GetJobCloudRegion(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Get Job Cloud Region function")
params := mux.Vars(r)
jobId := params["jobId"]
//var jobs []Jobs
var executionEnvironment []ExecutionEnvironment
db.Table("jobs").Select("execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Where("jobs.job_id =?", jobId).Find(&executionEnvironment)
var pUuid []string
for _, uuid := range executionEnvironment {
pUuid = append(pUuid, uuid.CloudRegion)
}
json.NewEncoder(w).Encode(pUuid)
}
func main() {
// router
router := mux.NewRouter()
// Access URL
router.HandleFunc("/GetAllJobs", GetAllJobs).Methods("GET")
router.HandleFunc("/createJob", createJob).Methods("POST")
router.HandleFunc("/GetJobById/{jobId}", GetJobById).Methods("GET")
router.HandleFunc("/DeleteJobById/{jobId}", DeleteJobById).Methods("DELETE")
router.HandleFunc("/createEnvironments", createEnvironments).Methods("POST")
router.HandleFunc("/GetJobCloudRegion/{jobId}", GetJobCloudRegion).Methods("GET")
// Initialize db connection
initDB()
// config port
fmt.Printf("Starting server at 8000 \n")
http.ListenAndServe(":8000", router)
}
我尝试在下面创建单元测试文件,但它不是 运行 它显示如下
main_test.go:
package main
import (
"log"
"os"
"testing"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
func TestinitDB(m *testing.M) {
dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
db, err := gorm.Open("mysql", dataSourceName)
if err != nil {
log.Fatal("failed to connect database")
}
//db.Exec("CREATE DATABASE test")
db.LogMode(true)
db.Exec("USE test111")
os.Exit(m.Run())
}
请帮我写单元测试文件
“如何进行单元测试”是一个非常宽泛的问题,因为它取决于您要测试的内容。在您的示例中,您正在使用与数据库的远程连接,这通常是在单元测试中被模拟的。目前尚不清楚这是否是您要查找的内容,也不是必须这样做。通过看到您使用不同的数据库,我希望您的意图不是嘲笑。
首先查看 TestMain
并且 testing.M
可以正常工作。
您的代码当前所做的(如果您的测试名称被正确命名为 TestMain
)是围绕您的其他测试添加一个方法来进行设置和拆卸,但是您没有任何其他测试要做使用此设置和拆卸,您将得到结果 no tests to run
.
这不是你问题的一部分,但我建议尽量避免 testing.M
,直到你对测试 Go 代码有信心为止。使用 testing.T
并测试单独的单元可能更容易理解。您可以通过在测试中调用 initDB()
并使初始化程序接受参数来实现几乎相同的事情。
func initDB(dbToUse string) {
// ...
db.Exec("USE "+dbToUse)
}
然后您将从您的主文件调用 initDB("test")
并从您的测试调用 initDB("test111")
。
您可以在 pkg.go.dev/testing 阅读有关 Go 的测试包的信息,您还可以在其中找到 testing.T
和 testing.M
之间的差异。
这是一个较短的示例,其中包含一些不需要任何设置或拆卸的基本测试,并且使用 testing.T
而不是 testing.M
。
main.go
package main
import "fmt"
func main() {
fmt.Println(add(1, 2))
}
func add(a, b int) int {
return a + b
}
main_test.go
package main
import "testing"
func TestAdd(t *testing.T) {
t.Run("add 2 + 2", func(t *testing.T) {
want := 4
// Call the function you want to test.
got := add(2, 2)
// Assert that you got your expected response
if got != want {
t.Fail()
}
})
}
此测试将测试您的方法 add
并在您将 2, 2
作为参数传递时确保它 returns 正确的值。 t.Run
的使用是可选的,但它会为您创建一个子测试,这使得阅读输出更容易一些。
由于您在包级别进行测试,因此如果您不使用三点格式递归地包括每个包,则需要指定要测试的包。
要运行上面示例中的测试,请指定您的包和-v
以获得详细输出。
$ go test ./ -v
=== RUN TestAdd
=== RUN TestAdd/add_2_+_2
--- PASS: TestAdd (0.00s)
--- PASS: TestAdd/add_2_+_2 (0.00s)
PASS
ok x (cached)
围绕这个主题还有很多东西需要学习,例如测试框架和测试模式。例如,测试框架 testify
helps you do assertions and prints nice output when tests fail and table driven tests 是 Go 中非常常见的模式。
您还在编写一个 HTTP 服务器,它通常需要额外的测试设置才能正确测试。幸运的是,标准库中的 http
包附带了一个名为 httptest
的子包,它可以帮助您记录外部请求或为外部请求启动本地服务器。您还可以通过使用手动构造的请求直接调用处理程序来测试您的处理程序。
看起来像这样。
func TestSomeHandler(t *testing.T) {
// Create a request to pass to our handler. We don't have any query parameters for now, so we'll
// pass 'nil' as the third parameter.
req, err := http.NewRequest("GET", "/some-endpoint", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(SomeHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
现在,测试您的一些代码。我们可以 运行 init 方法并使用响应记录器调用您的任何服务。
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestGetAllJobs(t *testing.T) {
// Initialize the DB
initDB("test111")
req, err := http.NewRequest("GET", "/GetAllJobs", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := http.HandlerFunc(GetAllJobs)
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
var response []Jobs
if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil {
t.Errorf("got invalid response, expected list of jobs, got: %v", rr.Body.String())
}
if len(response) < 1 {
t.Errorf("expected at least 1 job, got %v", len(response))
}
for _, job := range response {
if job.SourcePath == "" {
t.Errorf("expected job id %d to have a source path, was empty", job.JobID)
}
}
}
你可以使用 go-sqlmock:
package main
import (
"database/sql"
"regexp"
"testing"
"gopkg.in/DATA-DOG/go-sqlmock.v1"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
type Student struct {
//*gorm.Model
Name string
ID string
}
type v2Suite struct {
db *gorm.DB
mock sqlmock.Sqlmock
student Student
}
func TestGORMV2(t *testing.T) {
s := &v2Suite{}
var (
db *sql.DB
err error
)
db, s.mock, err = sqlmock.New()
if err != nil {
t.Errorf("Failed to open mock sql db, got error: %v", err)
}
if db == nil {
t.Error("mock db is null")
}
if s.mock == nil {
t.Error("sqlmock is null")
}
dialector := postgres.New(postgres.Config{
DSN: "sqlmock_db_0",
DriverName: "postgres",
Conn: db,
PreferSimpleProtocol: true,
})
s.db, err = gorm.Open(dialector, &gorm.Config{})
if err != nil {
t.Errorf("Failed to open gorm v2 db, got error: %v", err)
}
if s.db == nil {
t.Error("gorm db is null")
}
s.student = Student{
ID: "123456",
Name: "Test 1",
}
defer db.Close()
s.mock.MatchExpectationsInOrder(false)
s.mock.ExpectBegin()
s.mock.ExpectQuery(regexp.QuoteMeta(
`INSERT INTO "students" ("id","name")
VALUES (,) RETURNING "students"."id"`)).
WithArgs(s.student.ID, s.student.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(s.student.ID))
s.mock.ExpectCommit()
if err = s.db.Create(&s.student).Error; err != nil {
t.Errorf("Failed to insert to gorm db, got error: %v", err)
}
err = s.mock.ExpectationsWereMet()
if err != nil {
t.Errorf("Failed to meet expectations, got error: %v", err)
}
}