C# T4模板生成实体类

  • A+

C# T4模板生成实体类

一、前言

俗话说得好，懒是第一生产力。一个数据库表的字段太多，手写实体类写得手都抽筋了，于是就想着怎么才能偷个懒。

如果你对 T4 的基础语法还不太了解，可以先参考我之前写的一篇文章《T4语法快速入门》。

二、原理

我们要做的事情是通过数据库表生成实体类。

第一步：查询出指定数据库中的所有用户表（即 sys.objects 中 type='U' 的对象）。

第二步：查询出每张表的结构，比如字段名称、字段类型、字段长度、是否允许为空等。

三、新建DbHelper.ttinclude和ModelAuto.ttinclude

DbHelper.ttinclude

  1. <#+  
  2.     public class config  
  3.     {  
  4.         public static readonly string ConnectionString = "server=.;database=Test;uid=sa;pwd=123";  
  5.         public static readonly string DbDatabase = "";  
  6.         public static readonly string TableName = "";  
  7.     }  
  8. #>  
  9. <#+   
  10.     public class DbHelper  
  11.     {  
  12.         #region GetDbTables  
  13.   
  14.          public static List<string> GetDbTablesNew(string connectionString, string database,string tables = null)  
  15.         {   
  16.              if (!string.IsNullOrEmpty(tables))  
  17.             {  
  18.                 tables = string.Format(" and obj.name in ('{0}')", tables.Replace(",""','"));  
  19.             }  
  20.             string sql = string.Format(@"SELECT 
  21.                 obj.name tablename 
  22.                 from {0}.sys.objects obj  
  23.                 inner join {0}.dbo.sysindexes idx on obj.object_id=idx.id and idx.indid<=1 
  24.                 INNER JOIN {0}.sys.schemas schem ON obj.schema_id=schem.schema_id 
  25.                 left join {0}.sys.extended_properties g ON (obj.object_id = g.major_id AND g.minor_id = 0 AND g.name= 'MS_Description') 
  26.                 where type='U' {1}  
  27.                 order by obj.name", database,tables);  
  28.             DataTable dt = GetDataTable(connectionString, sql);  
  29.             return dt.Rows.Cast<DataRow>().Select(row =>row.Field<string>("tablename")).ToList();  
  30.         }  
  31.   
  32.         public static List<DbTable> GetDbTables(string connectionString, string database, string tables = null)  
  33.         {  
  34.   
  35.             if (!string.IsNullOrEmpty(tables))  
  36.             {  
  37.                 tables = string.Format(" and obj.name in ('{0}')", tables.Replace(",""','"));  
  38.             }  
  39.             #region SQL  
  40.             string sql = string.Format(@"SELECT 
  41.                                     obj.name tablename, 
  42.                                     schem.name schemname, 
  43.                                     idx.rows, 
  44.                                     CAST 
  45.                                     ( 
  46.                                         CASE  
  47.                                             WHEN (SELECT COUNT(1) FROM sys.indexes WHERE object_id= obj.OBJECT_ID AND is_primary_key=1) >=1 THEN 1 
  48.                                             ELSE 0 
  49.                                         END  
  50.                                     AS BIT) HasPrimaryKey                                          
  51.                                     from {0}.sys.objects obj  
  52.                                     inner join {0}.dbo.sysindexes idx on obj.object_id=idx.id and idx.indid<=1 
  53.                                     INNER JOIN {0}.sys.schemas schem ON obj.schema_id=schem.schema_id 
  54.                                     where type='U' {1} 
  55.                                     order by obj.name", database, tables);  
  56.             #endregion  
  57.             DataTable dt = GetDataTable(connectionString, sql);  
  58.             return dt.Rows.Cast<DataRow>().Select(row => new DbTable  
  59.             {  
  60.                 TableName = row.Field<string>("tablename"),  
  61.                 SchemaName = row.Field<string>("schemname"),  
  62.                 Rows = row.Field<int>("rows"),  
  63.                 HasPrimaryKey = row.Field<bool>("HasPrimaryKey")  
  64.             }).ToList();  
  65.         }  
  66.         #endregion  
  67.  
  68.         #region GetDbColumns  
  69.   
  70.         public static List<DbColumn> GetDbColumns(string connectionString, string database, string tableName, string schema = "dbo")  
  71.         {  
  72.             #region SQL  
  73.             string sql = string.Format(@" 
  74.                                     WITH indexCTE AS 
  75.                                     ( 
  76.                                         SELECT  
  77.                                         ic.column_id, 
  78.                                         ic.index_column_id, 
  79.                                         ic.object_id     
  80.                                         FROM {0}.sys.indexes idx 
  81.                                         INNER JOIN {0}.sys.index_columns ic ON idx.index_id = ic.index_id AND idx.object_id = ic.object_id 
  82.                                         WHERE  idx.object_id =OBJECT_ID(@tableName) AND idx.is_primary_key=1 
  83.                                     ) 
  84.                                     select 
  85.                                     colm.column_id ColumnID, 
  86.                                     CAST(CASE WHEN indexCTE.column_id IS NULL THEN 0 ELSE 1 END AS BIT) IsPrimaryKey, 
  87.                                     colm.name ColumnName, 
  88.                                     systype.name ColumnType, 
  89.                                     colm.is_identity IsIdentity, 
  90.                                     colm.is_nullable IsNullable, 
  91.                                     cast(colm.max_length as int) ByteLength, 
  92.                                     ( 
  93.                                         case  
  94.                                             when systype.name='nvarchar' and colm.max_length>0 then colm.max_length/2  
  95.                                             when systype.name='nchar' and colm.max_length>0 then colm.max_length/2 
  96.                                             when systype.name='ntext' and colm.max_length>0 then colm.max_length/2  
  97.                                             else colm.max_length 
  98.                                         end 
  99.                                     ) CharLength, 
  100.                                     cast(colm.precision as int) Precision, 
  101.                                     cast(colm.scale as int) Scale, 
  102.                                     prop.value Remark 
  103.                                     from {0}.sys.columns colm 
  104.                                     inner join {0}.sys.types systype on colm.system_type_id=systype.system_type_id and colm.user_type_id=systype.user_type_id 
  105.                                     left join {0}.sys.extended_properties prop on colm.object_id=prop.major_id and colm.column_id=prop.minor_id 
  106.                                     LEFT JOIN indexCTE ON colm.column_id=indexCTE.column_id AND colm.object_id=indexCTE.object_id                                         
  107.                                     where colm.object_id=OBJECT_ID(@tableName) 
  108.                                     order by colm.column_id", database);  
  109.             #endregion  
  110.             SqlParameter param = new SqlParameter("@tableName", SqlDbType.NVarChar, 100) { Value = string.Format("{0}.{1}.{2}", database, schema, tableName) };  
  111.             DataTable dt = GetDataTable(connectionString, sql, param);  
  112.             return dt.Rows.Cast<DataRow>().Select(row => new DbColumn()  
  113.             {  
  114.                 ColumnID = row.Field<int>("ColumnID"),  
  115.                 IsPrimaryKey = row.Field<bool>("IsPrimaryKey"),  
  116.                 ColumnName = row.Field<string>("ColumnName"),  
  117.                 ColumnType = row.Field<string>("ColumnType"),  
  118.                 IsIdentity = row.Field<bool>("IsIdentity"),  
  119.                 IsNullable = row.Field<bool>("IsNullable"),  
  120.                 ByteLength = row.Field<int>("ByteLength"),  
  121.                 CharLength = row.Field<int>("CharLength"),  
  122.                 Precision=row.Field<int>("Precision"),  
  123.                 Scale = row.Field<int>("Scale"),  
  124.                 Remark = row["Remark"].ToString()  
  125.             }).ToList();  
  126.         }  
  127.  
  128.         #endregion  
  129.  
  130.         #region GetDataTable  
  131.   
  132.         public static DataTable GetDataTable(string connectionString, string commandText, params SqlParameter[] parms)  
  133.         {  
  134.             using (SqlConnection connection = new SqlConnection(connectionString))  
  135.             {  
  136.                 SqlCommand command = connection.CreateCommand();  
  137.                 command.CommandText = commandText;  
  138.                 command.Parameters.AddRange(parms);  
  139.                 SqlDataAdapter adapter = new SqlDataAdapter(command);  
  140.   
  141.                 DataTable dt = new DataTable();  
  142.                 adapter.Fill(dt);  
  143.   
  144.                 return dt;  
  145.             }  
  146.         }  
  147.  
  148.         #endregion  
  149.  
  150.         #region GetPrimaryKey  
  151.         public static string GetPrimaryKey(List<DbColumn> dbColumns)  
  152.         {  
  153.             string primaryKey = string.Empty;  
  154.             if (dbColumns!=null&&dbColumns.Count>0)  
  155.             {  
  156.                 foreach (var item in dbColumns)  
  157.                 {  
  158.                     if (item.IsPrimaryKey==true)  
  159.                     {  
  160.                         primaryKey = item.ColumnName;  
  161.                     }  
  162.                 }  
  163.             }  
  164.             return primaryKey;  
  165.         }  
  166.         #endregion  
  167.   
  168.         public static string StrToUpper(string Str)  
  169.         {  
  170.             string[] strArray = Str.ToLower().Split('_');  
  171.             string result = "";//定义一个空字符串  
  172.             foreach (string item in strArray)//循环处理数组里面每一个字符串  
  173.             {  
  174.                 result += item.Substring(0, 1).ToUpper() + item.Substring(1);  
  175.                 //.Substring(0, 1).ToUpper()把循环到的字符串第一个字母截取并转换为大写  
  176.                 //并用s.Substring(1)得到循环到的字符串除第一个字符后的所有字符拼装到首字母后面。  
  177.             }  
  178.             return result;  
  179.         }  
  180.     }  
  181.  
  182.     #region DbTable  
  183.     public sealed class DbTable  
  184.     {  
  185.         public string TableName { getset; }  
  186.         public string SchemaName { getset; }  
  187.         public int Rows { getset; }  
  188.   
  189.         public bool HasPrimaryKey { getset; }  
  190.     }  
  191.     #endregion  
  192.  
  193.     #region DbColumn  
  194.       
  195.     public sealed class DbColumn  
  196.     {  
  197.           
  198.         public int ColumnID { getset; }  
  199.   
  200.          
  201.         public bool IsPrimaryKey { getset; }  
  202.   
  203.           
  204.         public string ColumnName { getset; }  
  205.   
  206.          
  207.         public string ColumnType { getset; }  
  208.   
  209.          
  210.         public string CSharpType  
  211.         {  
  212.             get  
  213.             {  
  214.                 return SqlServerDbTypeMap.MapCsharpType(ColumnType);  
  215.             }  
  216.         }  
  217.   
  218.         /// <summary>  
  219.         ///   
  220.         /// </summary>  
  221.         public Type CommonType  
  222.         {  
  223.             get  
  224.             {  
  225.                 return SqlServerDbTypeMap.MapCommonType(ColumnType);  
  226.             }  
  227.         }  
  228.   
  229.         public int ByteLength { getset; }  
  230.   
  231.         public int CharLength { getset; }  
  232.   
  233.         public int Precision{get;set;}  
  234.         public int Scale { getset; }  
  235.   
  236.         public bool IsIdentity { getset; }  
  237.   
  238.         public bool IsNullable { getset; }  
  239.   
  240.         public string Remark { getset; }  
  241.     }  
  242.     #endregion  
  243.  
  244.     #region SqlServerDbTypeMap  
  245.   
  246.     public class SqlServerDbTypeMap  
  247.     {  
  248.         public static string MapCsharpType(string dbtype)  
  249.         {  
  250.             if (string.IsNullOrEmpty(dbtype)) return dbtype;  
  251.             dbtype = dbtype.ToLower();  
  252.             string csharpType = "object";  
  253.             switch (dbtype)  
  254.             {  
  255.                 case "bigint": csharpType = "long"break;  
  256.                 case "binary": csharpType = "byte[]"break;  
  257.                 case "bit": csharpType = "bool"break;  
  258.                 case "char": csharpType = "string"break;  
  259.                 case "date": csharpType = "DateTime"break;  
  260.                 case "datetime": csharpType = "DateTime"break;  
  261.                 case "datetime2": csharpType = "DateTime"break;  
  262.                 case "datetimeoffset": csharpType = "DateTimeOffset"break;  
  263.                 case "decimal": csharpType = "decimal"break;  
  264.                 case "float": csharpType = "double"break;  
  265.                 case "image": csharpType = "byte[]"break;  
  266.                 case "int": csharpType = "int"break;  
  267.                 case "money": csharpType = "decimal"break;  
  268.                 case "nchar": csharpType = "string"break;  
  269.                 case "ntext": csharpType = "string"break;  
  270.                 case "numeric": csharpType = "decimal"break;  
  271.                 case "nvarchar": csharpType = "string"break;  
  272.                 case "real": csharpType = "Single"break;  
  273.                 case "smalldatetime": csharpType = "DateTime"break;  
  274.                 case "smallint": csharpType = "short"break;  
  275.                 case "smallmoney": csharpType = "decimal"break;  
  276.                 case "sql_variant": csharpType = "object"break;  
  277.                 case "sysname": csharpType = "object"break;  
  278.                 case "text": csharpType = "string"break;  
  279.                 case "time": csharpType = "TimeSpan"break;  
  280.                 case "timestamp": csharpType = "byte[]"break;  
  281.                 case "tinyint": csharpType = "byte"break;  
  282.                 case "uniqueidentifier": csharpType = "Guid"break;  
  283.                 case "varbinary": csharpType = "byte[]"break;  
  284.                 case "varchar": csharpType = "string"break;  
  285.                 case "xml": csharpType = "string"break;  
  286.                 default: csharpType = "object"break;  
  287.             }  
  288.             return csharpType;  
  289.         }  
  290.   
  291.         public static Type MapCommonType(string dbtype)  
  292.         {  
  293.             if (string.IsNullOrEmpty(dbtype)) return Type.Missing.GetType();  
  294.             dbtype = dbtype.ToLower();  
  295.             Type commonType = typeof(object);  
  296.             switch (dbtype)  
  297.             {  
  298.                 case "bigint": commonType = typeof(long); break;  
  299.                 case "binary": commonType = typeof(byte[]); break;  
  300.                 case "bit": commonType = typeof(bool); break;  
  301.                 case "char": commonType = typeof(string); break;  
  302.                 case "date": commonType = typeof(DateTime); break;  
  303.                 case "datetime": commonType = typeof(DateTime); break;  
  304.                 case "datetime2": commonType = typeof(DateTime); break;  
  305.                 case "datetimeoffset": commonType = typeof(DateTimeOffset); break;  
  306.                 case "decimal": commonType = typeof(decimal); break;  
  307.                 case "float": commonType = typeof(double); break;  
  308.                 case "image": commonType = typeof(byte[]); break;  
  309.                 case "int": commonType = typeof(int); break;  
  310.                 case "money": commonType = typeof(decimal); break;  
  311.                 case "nchar": commonType = typeof(string); break;  
  312.                 case "ntext": commonType = typeof(string); break;  
  313.                 case "numeric": commonType = typeof(decimal); break;  
  314.                 case "nvarchar": commonType = typeof(string); break;  
  315.                 case "real": commonType = typeof(Single); break;  
  316.                 case "smalldatetime": commonType = typeof(DateTime); break;  
  317.                 case "smallint": commonType = typeof(short); break;  
  318.                 case "smallmoney": commonType = typeof(decimal); break;  
  319.                 case "sql_variant": commonType = typeof(object); break;  
  320.                 case "sysname": commonType = typeof(object); break;  
  321.                 case "text": commonType = typeof(string); break;  
  322.                 case "time": commonType = typeof(TimeSpan); break;  
  323.                 case "timestamp": commonType = typeof(byte[]); break;  
  324.                 case "tinyint": commonType = typeof(byte); break;  
  325.                 case "uniqueidentifier": commonType = typeof(Guid); break;  
  326.                 case "varbinary": commonType = typeof(byte[]); break;  
  327.                 case "varchar": commonType = typeof(string); break;  
  328.                 case "xml": commonType = typeof(string); break;  
  329.                 default: commonType = typeof(object); break;  
  330.             }  
  331.             return commonType;  
  332.         }  
  333.     }  
  334.     #endregion  
  335.  #>  

ModelAuto.ttinclude

  1. <#@ assembly name="System.Core"#>  
  2. <#@ assembly name="EnvDTE"#>  
  3. <#@ import namespace="System.Collections.Generic"#>  
  4. <#@ import namespace="System.IO"#>  
  5. <#@ import namespace="System.Text"#>  
  6. <#@ import namespace="Microsoft.VisualStudio.TextTemplating"#>  
  7. <#+  
  8. class Manager  
  9. {  
  10.     public struct Block {  
  11.         public int Start, Length;  
  12.         public String Name,OutputPath;  
  13.     }  
  14.   
  15.     public List<Block> blocks = new List<Block>();  
  16.     public Block currentBlock;  
  17.     public Block footerBlock = new Block();  
  18.     public Block headerBlock = new Block();  
  19.     public ITextTemplatingEngineHost host;  
  20.     public ManagementStrategy strategy;  
  21.     public StringBuilder template;  
  22.     public Manager(ITextTemplatingEngineHost host, StringBuilder template, bool commonHeader) {  
  23.         this.host = host;  
  24.         this.template = template;  
  25.         strategy = ManagementStrategy.Create(host);  
  26.     }  
  27.     public void StartBlock(String name,String outputPath) {  
  28.         currentBlock = new Block { Name = name, Start = template.Length ,OutputPath=outputPath};  
  29.     }  
  30.   
  31.     public void StartFooter() {  
  32.         footerBlock.Start = template.Length;  
  33.     }  
  34.   
  35.     public void EndFooter() {  
  36.         footerBlock.Length = template.Length - footerBlock.Start;  
  37.     }  
  38.   
  39.     public void StartHeader() {  
  40.         headerBlock.Start = template.Length;  
  41.     }  
  42.   
  43.     public void EndHeader() {  
  44.         headerBlock.Length = template.Length - headerBlock.Start;  
  45.     }      
  46.   
  47.     public void EndBlock() {  
  48.         currentBlock.Length = template.Length - currentBlock.Start;  
  49.         blocks.Add(currentBlock);  
  50.     }  
  51.     public void Process(bool split) {  
  52.         String header = template.ToString(headerBlock.Start, headerBlock.Length);  
  53.         String footer = template.ToString(footerBlock.Start, footerBlock.Length);  
  54.         blocks.Reverse();  
  55.         foreach(Block block in blocks) {  
  56.             String fileName = Path.Combine(block.OutputPath, block.Name);  
  57.             if (split) {  
  58.                 String content = header + template.ToString(block.Start, block.Length) + footer;  
  59.                 strategy.CreateFile(fileName, content);  
  60.                 template.Remove(block.Start, block.Length);  
  61.             } else {  
  62.                 strategy.DeleteFile(fileName);  
  63.             }  
  64.         }  
  65.     }  
  66. }  
  67. class ManagementStrategy  
  68. {  
  69.     internal static ManagementStrategy Create(ITextTemplatingEngineHost host) {  
  70.         return (host is IServiceProvider) ? new VSManagementStrategy(host) : new ManagementStrategy(host);  
  71.     }  
  72.   
  73.     internal ManagementStrategy(ITextTemplatingEngineHost host) { }  
  74.   
  75.     internal virtual void CreateFile(String fileName, String content) {  
  76.         File.WriteAllText(fileName, content);  
  77.     }  
  78.   
  79.     internal virtual void DeleteFile(String fileName) {  
  80.         if (File.Exists(fileName))  
  81.             File.Delete(fileName);  
  82.     }  
  83. }  
  84.   
  85. class VSManagementStrategy : ManagementStrategy  
  86. {  
  87.     private EnvDTE.ProjectItem templateProjectItem;  
  88.   
  89.     internal VSManagementStrategy(ITextTemplatingEngineHost host) : base(host) {  
  90.         IServiceProvider hostServiceProvider = (IServiceProvider)host;  
  91.         if (hostServiceProvider == null)  
  92.             throw new ArgumentNullException("Could not obtain hostServiceProvider");  
  93.   
  94.         EnvDTE.DTE dte = (EnvDTE.DTE)hostServiceProvider.GetService(typeof(EnvDTE.DTE));  
  95.         if (dte == null)  
  96.             throw new ArgumentNullException("Could not obtain DTE from host");  
  97.   
  98.         templateProjectItem = dte.Solution.FindProjectItem(host.TemplateFile);  
  99.     }  
  100.     internal override void CreateFile(String fileName, String content) {  
  101.         base.CreateFile(fileName, content);  
  102.         //((EventHandler)delegate { templateProjectItem.ProjectItems.AddFromFile(fileName); }).BeginInvoke(null, null, null, null);  
  103.     }  
  104.     internal override void DeleteFile(String fileName) {  
  105.         ((EventHandler)delegate { FindAndDeleteFile(fileName); }).BeginInvoke(nullnullnullnull);  
  106.     }  
  107.     private void FindAndDeleteFile(String fileName) {  
  108.         foreach(EnvDTE.ProjectItem projectItem in templateProjectItem.ProjectItems) {  
  109.             if (projectItem.get_FileNames(0) == fileName) {  
  110.                 projectItem.Delete();  
  111.                 return;  
  112.             }  
  113.         }  
  114.     }  
  115. }#>  

四、新建一个Model.tt（它通过 $(ProjectDir)/T4/ 路径引用上面两个 .ttinclude 文件，因此这两个文件需要放在项目的 T4 目录下）

  1. <#@ template debug="true" hostspecific="true" language="C#" #>  
  2. <#@ output extension=".cs" #>  
  3. <#@ assembly name="System.Core.dll" #>  
  4. <#@ assembly name="System.Data.dll" #>  
  5. <#@ assembly name="System.Data.DataSetExtensions.dll" #>  
  6. <#@ assembly name="System.Xml.dll" #>  
  7. <#@ import namespace="System" #>  
  8. <#@ import namespace="System.Xml" #>  
  9. <#@ import namespace="System.Linq" #>  
  10. <#@ import namespace="System.Data" #>  
  11. <#@ import namespace="System.Data.SqlClient" #>  
  12. <#@ import namespace="System.Collections.Generic" #>  
  13. <#@ import namespace="System.IO" #>  
  14. <#@ include file="$(ProjectDir)/T4/DbHelper.ttinclude"  #>  
  15. <#@ include file="$(ProjectDir)/T4/ModelAuto.ttinclude" #>  
  16. <# var manager = new Manager(Host, GenerationEnvironment, true); #>  
  17. <#   
  18.     var OutputPath1 =Path.GetDirectoryName(Host.TemplateFile);  
  19.     OutputPath1=Path.Combine(OutputPath1,"Template","Models");  
  20.     if (!Directory.Exists(OutputPath1))  
  21.     {  
  22.         Directory.CreateDirectory(OutputPath1);  
  23.     }  
  24. #>  
  25. <# foreach (var item in DbHelper.GetDbTablesNew(config.ConnectionString, config.DbDatabase,""))  
  26.    {  
  27.         var tableName=DbHelper.StrToUpper(item);  
  28.         manager.StartBlock(tableName+".cs",OutputPath1);//文件名  
  29. #>  
  30. using System;  
  31. using System.ComponentModel.DataAnnotations;  
  32. using System.ComponentModel.DataAnnotations.Schema;  
  33. namespace CommonLibrary.Model  
  34. {  
  35.      ///<summary>  
  36.      ///<#=tableName#>  
  37.      ///</summary>  
  38.      public class <#=tableName#>  
  39.      {  
  40.      <# foreach(DbColumn column in DbHelper.GetDbColumns(config.ConnectionString, config.DbDatabase, tableName)){#>  
  41.         /// <summary>  
  42.         /// <#= column.Remark == "" ? column.ColumnName : column.Remark.Replace("\r\n"," ") #>  
  43.         /// </summary>  
  44.         [Column("<#= column.ColumnName#>")]  
  45.         public <#= column.CSharpType#><# if(column.CommonType.IsValueType && column.IsNullable){#>?<#}#> <#=DbHelper.StrToUpper(column.ColumnName)#> { getset; }  
  46.         <# } #>   
  47.      }  
  48.       
  49.     public class <#=tableName#>List  
  50.     {  
  51.         public List<<#=tableName#>> List { getset; }  
  52.         public int? TotalSize { getset; }  
  53.     }  
  54. }      
  55. <#  
  56.    manager.EndBlock();   
  57.    }  
  58.    manager.Process(true);  
  59. #>  

效果图

C# T4模板生成实体类

C# T4模板生成实体类

总结

好了，现在就可以自动生成实体类了。

T4 不仅仅可以生成实体类——只要你够“懒”，它就能帮你干更多的事。

钰玺

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: